Skip to content

Commit 6bf12af

Browse files
committed
make all caches behave the same
1 parent fa15514 commit 6bf12af

File tree

9 files changed

+220
-152
lines changed

9 files changed

+220
-152
lines changed

src/TensorKit.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ export scalar, add!, contract!
8888
# truncation schemes
8989
export notrunc, truncerr, truncdim, truncspace, truncbelow
9090

91+
# cache management
92+
export empty_globalcaches!
93+
9194
# Imports
9295
#---------
9396
using TupleTools
@@ -134,6 +137,7 @@ using PackageExtensionCompat
134137
# Auxiliary files
135138
#-----------------
136139
include("auxiliary/auxiliary.jl")
140+
include("auxiliary/caches.jl")
137141
include("auxiliary/dicts.jl")
138142
include("auxiliary/iterators.jl")
139143
include("auxiliary/linalg.jl")

src/auxiliary/caches.jl

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
const GLOBAL_CACHES = Any[]
2+
function empty_globalcaches!()
3+
foreach(empty!, GLOBAL_CACHES)
4+
return nothing
5+
end
6+
7+
abstract type CacheStyle end
8+
struct NoCache <: CacheStyle end
9+
struct TaskLocalCache{D<:AbstractDict} <: CacheStyle end
10+
struct GlobalLRUCache <: CacheStyle end
11+
12+
const DEFAULT_GLOBALCACHE_SIZE = Ref(10^5)
13+
14+
function CacheStyle(args...)
15+
return GlobalLRUCache()
16+
end
17+
18+
macro cached(ex)
19+
Meta.isexpr(ex, :function) ||
20+
error("cached macro can only be used on function definitions")
21+
fcall = ex.args[1]
22+
if Meta.isexpr(fcall, :where)
23+
hasparams = true
24+
params = fcall.args[2:end]
25+
fcall = fcall.args[1]
26+
else
27+
hasparams = false
28+
end
29+
if Meta.isexpr(fcall, :(::))
30+
typed = true
31+
typeex = fcall.args[2]
32+
fcall = fcall.args[1]
33+
else
34+
typed = false
35+
end
36+
Meta.isexpr(fcall, :call) ||
37+
error("cached macro can only be used on function definitions")
38+
fname = fcall.args[1]
39+
fargs = fcall.args[2:end]
40+
fargnames = map(fargs) do arg
41+
if Meta.isexpr(arg, :(::))
42+
return arg.args[1]
43+
else
44+
return arg
45+
end
46+
end
47+
_fbody = ex.args[2]
48+
49+
# actual implenetation, with underscore name
50+
_fname = Symbol(:_, fname)
51+
_fcall = Expr(:call, _fname, fargs...)
52+
if hasparams
53+
_fcall = Expr(:where, _fcall, params...)
54+
end
55+
_fex = Expr(:function, _fcall, _fbody)
56+
57+
# implementation that chooses the cache style
58+
newfcall = fcall
59+
if hasparams
60+
newfcall = Expr(:where, newfcall, params...)
61+
end
62+
cachestylevar = gensym(:cachestyle)
63+
cachestyleex = Expr(:(=), cachestylevar,
64+
Expr(:call, :CacheStyle, fname, fargnames...))
65+
newfbody = Expr(:block,
66+
cachestyleex,
67+
Expr(:call, fname, fargnames..., cachestylevar))
68+
newfex = Expr(:function, newfcall, newfbody)
69+
70+
# nocache implementation
71+
fnocachecall = Expr(:call, fname, fargs..., :(::NoCache))
72+
if hasparams
73+
fnocachecall = Expr(:where, fnocachecall, params...)
74+
end
75+
fnocachebody = Expr(:call, _fname, fargnames...)
76+
if typed
77+
T = gensym(:T)
78+
fnocachebody = Expr(:block, Expr(:(=), T, typeex), Expr(:(::), fnocachebody, T))
79+
end
80+
fnocacheex = Expr(:function, fnocachecall, fnocachebody)
81+
82+
# tasklocal cache implementation
83+
Dvar = gensym(:D)
84+
flocalcachecall = Expr(:call, fname, fargs..., :(::TaskLocalCache{$Dvar}))
85+
if hasparams
86+
flocalcachecall = Expr(:where, flocalcachecall, params..., Dvar)
87+
else
88+
flocalcachecall = Expr(:where, flocalcachecall, Dvar)
89+
end
90+
localcachename = Symbol(:_tasklocal_, fname, :_cache)
91+
cachevar = gensym(:cache)
92+
getlocalcacheex = :($cachevar::$Dvar = get!(task_local_storage(), $localcachename) do
93+
return $Dvar()
94+
end)
95+
valvar = gensym(:val)
96+
if length(fargnames) == 1
97+
key = fargnames[1]
98+
else
99+
key = Expr(:tuple, fargnames...)
100+
end
101+
getvalex = :(get!($cachevar, $key) do
102+
return $_fname($(fargnames...))
103+
end)
104+
if typed
105+
T = gensym(:T)
106+
flocalcachebody = Expr(:block,
107+
getlocalcacheex,
108+
Expr(:(=), T, typeex),
109+
Expr(:(=), Expr(:(::), valvar, T), getvalex),
110+
Expr(:return, valvar))
111+
else
112+
flocalcachebody = Expr(:block,
113+
getlocalcacheex,
114+
Expr(:(=), valvar, getvalex),
115+
Expr(:return, valvar))
116+
end
117+
flocalcacheex = Expr(:function, flocalcachecall, flocalcachebody)
118+
119+
# # global cache implementation
120+
fglobalcachecall = Expr(:call, fname, fargs..., :(::GlobalLRUCache))
121+
if hasparams
122+
fglobalcachecall = Expr(:where, fglobalcachecall, params...)
123+
end
124+
globalcachename = Symbol(:GLOBAL_, uppercase(string(fname)), :_CACHE)
125+
getglobalcachex = Expr(:(=), cachevar, globalcachename)
126+
if typed
127+
T = gensym(:T)
128+
fglobalcachebody = Expr(:block,
129+
getglobalcachex,
130+
Expr(:(=), T, typeex),
131+
Expr(:(=), Expr(:(::), valvar, T), getvalex),
132+
Expr(:return, valvar))
133+
else
134+
fglobalcachebody = Expr(:block,
135+
getglobalcachex,
136+
Expr(:(=), valvar, getvalex),
137+
Expr(:return, valvar))
138+
end
139+
fglobalcacheex = Expr(:function, fglobalcachecall, fglobalcachebody)
140+
fglobalcachedef = Expr(:const,
141+
Expr(:(=), globalcachename,
142+
:(LRU{Any,Any}(; maxsize=DEFAULT_GLOBALCACHE_SIZE[]))))
143+
fglobalcacheregister = Expr(:call, :push!, :GLOBAL_CACHES, globalcachename)
144+
145+
# # total expression
146+
return esc(Expr(:block, _fex, newfex, fnocacheex, flocalcacheex,
147+
fglobalcachedef, fglobalcacheregister, fglobalcacheex))
148+
end

src/fusiontrees/manipulations.jl

Lines changed: 37 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -526,10 +526,6 @@ function _recursive_repartition(f₁::FusionTree{I,N₁},
526526
end
527527
end
528528

529-
# transpose double fusion tree
530-
const transposecache = LRU{Any,Any}(; maxsize=10^5)
531-
const usetransposecache = Ref{Bool}(true)
532-
533529
"""
534530
transpose(f₁::FusionTree{I}, f₂::FusionTree{I},
535531
p1::NTuple{N₁, Int}, p2::NTuple{N₂, Int}) where {I, N₁, N₂}
@@ -548,28 +544,24 @@ function Base.transpose(f₁::FusionTree{I}, f₂::FusionTree{I},
548544
@assert length(f₁) + length(f₂) == N
549545
p = linearizepermutation(p1, p2, length(f₁), length(f₂))
550546
@assert iscyclicpermutation(p)
551-
if usetransposecache[]
552-
T = sectorscalartype(I)
553-
F₁ = fusiontreetype(I, N₁)
554-
F₂ = fusiontreetype(I, N₂)
555-
D = fusiontreedict(I){Tuple{F₁,F₂},T}
556-
return _get_transpose(D, (f₁, f₂, p1, p2))
557-
else
558-
return _transpose((f₁, f₂, p1, p2))
559-
end
547+
return fstranspose((f₁, f₂, p1, p2))
560548
end
561549

562-
@noinline function _get_transpose(::Type{D}, @nospecialize(key)) where {D}
563-
d::D = get!(transposecache, key) do
564-
return _transpose(key)
565-
end
566-
return d
567-
end
550+
const FSTransposeKey{I<:Sector,N₁,N₂} = Tuple{<:FusionTree{I},<:FusionTree{I},
551+
IndexTuple{N₁},IndexTuple{N₂}}
568552

569-
const TransposeKey{I<:Sector,N₁,N₂} = Tuple{<:FusionTree{I},<:FusionTree{I},
570-
IndexTuple{N₁},IndexTuple{N₂}}
553+
Base.@pure function _fsdicttype(I, N₁, N₂)
554+
F₁ = fusiontreetype(I, N₁)
555+
F₂ = fusiontreetype(I, N₂)
556+
T = sectorscalartype(I)
557+
return fusiontreedict(I){Tuple{F₁,F₂},T}
558+
end
571559

572-
function _transpose((f₁, f₂, p1, p2)::TransposeKey{I,N₁,N₂}) where {I<:Sector,N₁,N₂}
560+
@cached function fstranspose(key::FSTransposeKey{I,N₁,N₂})::_fsdicttype(I, N₁,
561+
N₂) where {I<:Sector,
562+
N₁,
563+
N₂}
564+
f₁, f₂, p1, p2 = key
573565
N = N₁ + N₂
574566
p = linearizepermutation(p1, p2, length(f₁), length(f₂))
575567
newtrees = repartition(f₁, f₂, N₁)
@@ -611,6 +603,14 @@ function _transpose((f₁, f₂, p1, p2)::TransposeKey{I,N₁,N₂}) where {I<:S
611603
return newtrees
612604
end
613605

606+
function CacheStyle(::typeof(fstranspose), k::FSTransposeKey{I}) where {I<:Sector}
607+
if FusionStyle(I) isa UniqueFusion
608+
return NoCache()
609+
else
610+
return GlobalLRUCache()
611+
end
612+
end
613+
614614
# COMPOSITE DUALITY MANIPULATIONS PART 2: Planar traces
615615
#-------------------------------------------------------------------
616616
# -> composite manipulations that depend on the duality (rigidity) and pivotal structure
@@ -1015,10 +1015,6 @@ function permute(f::FusionTree{I,N}, p::NTuple{N,Int}) where {I<:Sector,N}
10151015
end
10161016

10171017
# braid double fusion tree
1018-
const braidcache = LRU{Any,Any}(; maxsize=10^5)
1019-
const usebraidcache_abelian = Ref{Bool}(false)
1020-
const usebraidcache_nonabelian = Ref{Bool}(true)
1021-
10221018
"""
10231019
braid(f₁::FusionTree{I}, f₂::FusionTree{I},
10241020
levels1::IndexTuple, levels2::IndexTuple,
@@ -1043,42 +1039,15 @@ function braid(f₁::FusionTree{I}, f₂::FusionTree{I},
10431039
@assert length(f₁) + length(f₂) == N₁ + N₂
10441040
@assert length(f₁) == length(levels1) && length(f₂) == length(levels2)
10451041
@assert TupleTools.isperm((p1..., p2...))
1046-
if FusionStyle(f₁) isa UniqueFusion &&
1047-
BraidingStyle(f₁) isa SymmetricBraiding
1048-
if usebraidcache_abelian[]
1049-
T = Int # do we hardcode this ?
1050-
F₁ = fusiontreetype(I, N₁)
1051-
F₂ = fusiontreetype(I, N₂)
1052-
D = SingletonDict{Tuple{F₁,F₂},T}
1053-
return _get_braid(D, (f₁, f₂, levels1, levels2, p1, p2))
1054-
else
1055-
return _braid((f₁, f₂, levels1, levels2, p1, p2))
1056-
end
1057-
else
1058-
if usebraidcache_nonabelian[]
1059-
T = sectorscalartype(I)
1060-
F₁ = fusiontreetype(I, N₁)
1061-
F₂ = fusiontreetype(I, N₂)
1062-
D = FusionTreeDict{Tuple{F₁,F₂},T}
1063-
return _get_braid(D, (f₁, f₂, levels1, levels2, p1, p2))
1064-
else
1065-
return _braid((f₁, f₂, levels1, levels2, p1, p2))
1066-
end
1067-
end
1042+
return fsbraid((f₁, f₂, levels1, levels2, p1, p2))
10681043
end
1044+
const FSBraidKey{I<:Sector,N₁,N₂} = Tuple{<:FusionTree{I},<:FusionTree{I},
1045+
IndexTuple,IndexTuple,
1046+
IndexTuple{N₁},IndexTuple{N₂}}
10691047

1070-
@noinline function _get_braid(::Type{D}, @nospecialize(key)) where {D}
1071-
d::D = get!(braidcache, key) do
1072-
return _braid(key)
1073-
end
1074-
return d
1075-
end
1076-
1077-
const BraidKey{I<:Sector,N₁,N₂} = Tuple{<:FusionTree{I},<:FusionTree{I},
1078-
IndexTuple,IndexTuple,
1079-
IndexTuple{N₁},IndexTuple{N₂}}
1080-
1081-
function _braid((f₁, f₂, l1, l2, p1, p2)::BraidKey{I,N₁,N₂}) where {I<:Sector,N₁,N₂}
1048+
@cached function fsbraid(key::FSBraidKey{I,N₁,N₂})::_fsdicttype(I, N₁,
1049+
N₂) where {I<:Sector,N₁,N₂}
1050+
(f₁, f₂, l1, l2, p1, p2) = key
10821051
p = linearizepermutation(p1, p2, length(f₁), length(f₂))
10831052
levels = (l1..., reverse(l2)...)
10841053
local newtrees
@@ -1097,6 +1066,14 @@ function _braid((f₁, f₂, l1, l2, p1, p2)::BraidKey{I,N₁,N₂}) where {I<:S
10971066
return newtrees
10981067
end
10991068

1069+
function CacheStyle(::typeof(fsbraid), k::FSBraidKey{I}) where {I<:Sector}
1070+
if FusionStyle(I) isa UniqueFusion
1071+
return NoCache()
1072+
else
1073+
return GlobalLRUCache()
1074+
end
1075+
end
1076+
11001077
"""
11011078
permute(f₁::FusionTree{I}, f₂::FusionTree{I},
11021079
p1::NTuple{N₁, Int}, p2::NTuple{N₂, Int}) where {I, N₁, N₂}

src/spaces/homspace.jl

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -239,18 +239,17 @@ struct FusionBlockStructure{I,N,F₁,F₂}
239239
fusiontreeindices::FusionTreeDict{Tuple{F₁,F₂},Int}
240240
end
241241

242-
abstract type CacheStyle end
243-
struct NoCache <: CacheStyle end
244-
struct TaskLocalCache{D<:AbstractDict} <: CacheStyle end
245-
struct GlobalLRUCache <: CacheStyle end
246-
247-
function CacheStyle(I::Type{<:Sector})
248-
return GlobalLRUCache()
242+
function fusionblockstructuretype(W::HomSpace)
243+
N₁ = length(codomain(W))
244+
N₂ = length(domain(W))
245+
N = N₁ + N₂
246+
I = sectortype(W)
247+
F₁ = fusiontreetype(I, N₁)
248+
F₂ = fusiontreetype(I, N₂)
249+
return FusionBlockStructure{I,N,F₁,F₂}
249250
end
250251

251-
fusionblockstructure(W::HomSpace) = fusionblockstructure(W, CacheStyle(sectortype(W)))
252-
253-
function fusionblockstructure(W::HomSpace, ::NoCache)
252+
@cached function fusionblockstructure(W::HomSpace)::fusionblockstructuretype(W)
254253
codom = codomain(W)
255254
dom = domain(W)
256255
N₁ = length(codom)
@@ -323,36 +322,8 @@ function _subblock_strides(subsz, sz, str)
323322
return Strided.StridedViews._computereshapestrides(subsz, sz_simplify...)
324323
end
325324

326-
function fusionblockstructure(W::HomSpace, ::TaskLocalCache{D}) where {D}
327-
cache::D = get!(task_local_storage(), :_local_tensorstructure_cache) do
328-
return D()
329-
end
330-
N₁ = length(codomain(W))
331-
N₂ = length(domain(W))
332-
N = N₁ + N₂
333-
I = sectortype(W)
334-
F₁ = fusiontreetype(I, N₁)
335-
F₂ = fusiontreetype(I, N₂)
336-
structure::FusionBlockStructure{I,N,F₁,F₂} = get!(cache, W) do
337-
return fusionblockstructure(W, NoCache())
338-
end
339-
return structure
340-
end
341-
342-
const GLOBAL_FUSIONBLOCKSTRUCTURE_CACHE = LRU{Any,Any}(; maxsize=10^4)
343-
# 10^4 different tensor spaces should be enough for most purposes
344-
function fusionblockstructure(W::HomSpace, ::GlobalLRUCache)
345-
cache = GLOBAL_FUSIONBLOCKSTRUCTURE_CACHE
346-
N₁ = length(codomain(W))
347-
N₂ = length(domain(W))
348-
N = N₁ + N₂
349-
I = sectortype(W)
350-
F₁ = fusiontreetype(I, N₁)
351-
F₂ = fusiontreetype(I, N₂)
352-
structure::FusionBlockStructure{I,N,F₁,F₂} = get!(cache, W) do
353-
return fusionblockstructure(W, NoCache())
354-
end
355-
return structure
325+
function CacheStyle(::typeof(fusionblockstructure), W::HomSpace)
326+
return GlobalLRUCache()
356327
end
357328

358329
# Diagonal ranges

0 commit comments

Comments
 (0)