Skip to content

Add (fast) full update #222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/PEPSKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ include("algorithms/truncation/bond_truncation.jl")
include("algorithms/time_evolution/evoltools.jl")
include("algorithms/time_evolution/simpleupdate.jl")
include("algorithms/time_evolution/simpleupdate3site.jl")
include("algorithms/time_evolution/fullupdate.jl")

include("algorithms/toolbox.jl")
include("algorithms/correlators.jl")
Expand Down Expand Up @@ -99,6 +100,7 @@ export fixedpoint
export absorb_weight
export ALSTruncation, FullEnvTruncation
export su_iter, su3site_iter, simpleupdate, SimpleUpdate
export fu_iter, fu_iter2, FullUpdate

export InfiniteSquareNetwork
export InfinitePartitionFunction
Expand Down
10 changes: 10 additions & 0 deletions src/algorithms/contractions/bondenv/als_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,13 @@ function _solve_ab(
x1, info = linsolve(f, Sx, x0, 0, 1)
return x1, info
end

function _solve_ab_pinv!(
Rx::AbstractTensorMap{T,S,2,2}, Sx::AbstractTensorMap{T,S,2,1}; kwargs...
) where {T<:Number,S<:ElementarySpace}
Rx_inv, ϵ = _pinv!(copy(Rx); kwargs...)
is = filter(i -> isdual(codomain(Rx_inv, i)), 1:numout(Rx_inv))
x = Rx_inv * Sx
twist!(x, is)
return x, ϵ
end
29 changes: 25 additions & 4 deletions src/algorithms/ctmrg/sequential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@
CTMRG_SYMBOLS[:sequential] = SequentialCTMRG

"""
ctmrg_leftmove(col::Int, network, env::CTMRGEnv, alg::SequentialCTMRG)
ctmrg_leftmove(col::Int, network, env::CTMRGEnv, alg::ProjectorAlgorithm)

Perform sequential CTMRG left move on the `col`-th column.
"""
function ctmrg_leftmove(col::Int, network, env::CTMRGEnv, alg::SequentialCTMRG)
function ctmrg_leftmove(col::Int, network, env::CTMRGEnv, alg::ProjectorAlgorithm)
#=
----> left move
C1 ← T1 ← r-1
Expand All @@ -52,17 +52,38 @@
C4 → T3 → r+1
c-1 c
=#
projectors, info = sequential_projectors(col, network, env, alg.projector_alg)
projectors, info = sequential_projectors(col, network, env, alg)
env = renormalize_sequentially(col, projectors, network, env)
return env, info
end

"""
ctmrg_rightmove(col::Int, network, env::CTMRGEnv, alg::ProjectorAlgorithm)

Perform sequential CTMRG right move on the `col`-th column.
"""
function ctmrg_rightmove(col::Int, network, env::CTMRGEnv, alg::ProjectorAlgorithm)

Check warning on line 65 in src/algorithms/ctmrg/sequential.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/ctmrg/sequential.jl#L65

Added line #L65 was not covered by tests
#=
right move <---
←-- T1 ← C2 r-1
‖ ↑
=== M' = T2 r
‖ ↑
--→ T3 → C3 r+1
c c+1
=#
Nc = size(network)[2]
@assert 1 <= col <= Nc
env, info = ctmrg_leftmove(Nc + 1 - col, rot180(network), rot180(env), alg)
return rot180(env), info

Check warning on line 78 in src/algorithms/ctmrg/sequential.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/ctmrg/sequential.jl#L75-L78

Added lines #L75 - L78 were not covered by tests
end

function ctmrg_iteration(network, env::CTMRGEnv, alg::SequentialCTMRG)
truncation_error = zero(real(scalartype(network)))
condition_number = zero(real(scalartype(network)))
for _ in 1:4 # rotate
for col in 1:size(network, 2) # left move column-wise
env, info = ctmrg_leftmove(col, network, env, alg)
env, info = ctmrg_leftmove(col, network, env, alg.projector_alg)
truncation_error = max(truncation_error, info.truncation_error)
condition_number = max(condition_number, info.condition_number)
end
Expand Down
8 changes: 3 additions & 5 deletions src/algorithms/time_evolution/evoltools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,18 @@ to get the reduced tensors
```
2 1
| |
5 - A 3 ====> 4 - X ← 2 1 ← a 3
5 - A - 3 ====> 4 - X ← 2 1 ← a - 3
| ↘ | ↘
4 1 3 2

2 1
| |
5 B - 3 ====> 1 b → 3 4 → Y - 2
5 - B - 3 ====> 1 - b → 3 4 → Y - 2
| ↘ ↘ |
4 1 2 3
```
"""
function _qr_bond(A::PEPSTensor, B::PEPSTensor)
# TODO: relax dual requirement on the bonds
@assert isdual(space(A, 3)) # currently only allow A ← B
X, a = leftorth(A, ((2, 4, 5), (1, 3)))
Y, b = leftorth(B, ((2, 3, 4), (1, 5)))
@assert !isdual(space(a, 1))
Expand Down Expand Up @@ -125,7 +123,7 @@ $(SIGNATURES)

Apply 2-site `gate` on the reduced matrices `a`, `b`
```
-1← a - 3 - b ← -4
-1← a -- 3 -- b ← -4
↓ ↓
1 2
↓ ↓
Expand Down
155 changes: 155 additions & 0 deletions src/algorithms/time_evolution/fullupdate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""
$(TYPEDEF)

Algorithm struct for full update (FU) of infinite PEPS.

## Fields

$(TYPEDFIELDS)
"""
@kwdef struct FullUpdate
"Time evolution step, such that the Trotter gate is exp(-dt * Hᵢⱼ).
Use imaginary `dt` for real time evolution."
dt::Number
"Number of evolution steps without fully reconverging the environment."
niter::Int
"Fix gauge of bond environment."
fixgauge::Bool = true
"Bond truncation algorithm after applying time evolution gate."
opt_alg::Union{ALSTruncation,FullEnvTruncation} = ALSTruncation(;
trscheme=truncerr(1e-10)
)
"CTMRG algorithm to reconverge environment.
Its `projector_alg` is also used for the fast update
of the environment after each FU iteration."
ctm_alg::CTMRGAlgorithm = SequentialCTMRG(;
tol=1e-9,
maxiter=20,
verbosity=1,
trscheme=truncerr(1e-10),
projector_alg=:fullinfinite,
)
end

"""
Full update for the bond between `[row, col]` and `[row, col+1]`.
"""
function _fu_xbond!(

Check warning on line 37 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L37

Added line #L37 was not covered by tests
row::Int,
col::Int,
gate::AbstractTensorMap{T,S,2,2},
peps::InfinitePEPS,
env::CTMRGEnv,
alg::FullUpdate,
) where {T<:Number,S<:ElementarySpace}
cp1 = _next(col, size(peps, 2))
A, B = peps[row, col], peps[row, cp1]
X, a, b, Y = _qr_bond(A, B)

Check warning on line 47 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L45-L47

Added lines #L45 - L47 were not covered by tests
# positive/negative-definite approximant: benv = ± Z Z†
benv = bondenv_fu(row, col, X, Y, env)
Z = positive_approx(benv)
@debug "cond(benv) before gauge fix: $(LinearAlgebra.cond(Z' * Z))"

Check warning on line 51 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L49-L51

Added lines #L49 - L51 were not covered by tests
# fix gauge
if alg.fixgauge
Z, a, b, (Linv, Rinv) = fixgauge_benv(Z, a, b)
X, Y = _fixgauge_benvXY(X, Y, Linv, Rinv)
@debug "cond(L) = $(LinearAlgebra.cond(Linv)); cond(R): $(LinearAlgebra.cond(Rinv))"
@debug "cond(benv) after gauge fix: $(LinearAlgebra.cond(Z' * Z))"

Check warning on line 57 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L53-L57

Added lines #L53 - L57 were not covered by tests
end
benv = Z' * Z

Check warning on line 59 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L59

Added line #L59 was not covered by tests
# apply gate
need_flip = isdual(space(b, 1))
a, s, b, = _apply_gate(a, b, gate, truncerr(1e-15))
a, b = absorb_s(a, s, b)

Check warning on line 63 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L61-L63

Added lines #L61 - L63 were not covered by tests
# optimize a, b
a, s, b, info = bond_truncate(a, b, benv, alg.opt_alg)
a, b = absorb_s(a, s, b)

Check warning on line 66 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L65-L66

Added lines #L65 - L66 were not covered by tests
# bond truncation is done with arrow `a ← b`.
# now revert back to `a → b` when needed.
if need_flip
a, b = flip(a, 3), flip(b, 1)

Check warning on line 70 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L69-L70

Added lines #L69 - L70 were not covered by tests
end
a /= norm(a, Inf)
b /= norm(b, Inf)
A, B = _qr_bond_undo(X, a, b, Y)
peps.A[row, col] = A / norm(A, Inf)
peps.A[row, cp1] = B / norm(B, Inf)
return s, info

Check warning on line 77 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L72-L77

Added lines #L72 - L77 were not covered by tests
end

"""
Update all horizontal bonds in the c-th column
(i.e. `(r,c) (r,c+1)` for all `r = 1, ..., Nr`).
To update rows, rotate the network clockwise by 90 degrees.
The iPEPS `peps` is modified in place.
"""
function _fu_column!(

Check warning on line 86 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L86

Added line #L86 was not covered by tests
col::Int, gate::LocalOperator, peps::InfinitePEPS, env::CTMRGEnv, alg::FullUpdate
)
Nr, Nc = size(peps)
@assert 1 <= col <= Nc
fid = 1.0
wts_col = Vector{PEPSWeight}(undef, Nr)
for row in 1:Nr
term = get_gateterm(gate, (CartesianIndex(row, col), CartesianIndex(row, col + 1)))
wts_col[row], info = _fu_xbond!(row, col, term, peps, env, alg)
fid = min(fid, info.fid)
end

Check warning on line 97 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L89-L97

Added lines #L89 - L97 were not covered by tests
# update CTMRGEnv
network = InfiniteSquareNetwork(peps)
env2, info = ctmrg_leftmove(col, network, env, alg.ctm_alg.projector_alg)
env2, info = ctmrg_rightmove(_next(col, Nc), network, env2, alg.ctm_alg.projector_alg)
for c in [col, _next(col, Nc)]
env.corners[:, :, c] = env2.corners[:, :, c]
env.edges[:, :, c] = env2.edges[:, :, c]
end
return wts_col, fid

Check warning on line 106 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L99-L106

Added lines #L99 - L106 were not covered by tests
end

"""
One round of full update on the input InfinitePEPS `peps` and its CTMRGEnv `env`.

Reference: Physical Review B 92, 035142 (2015)
"""
function fu_iter(gate::LocalOperator, peps::InfinitePEPS, env::CTMRGEnv, alg::FullUpdate)
Nr, Nc = size(peps)
fidmin = 1.0
peps2, env2 = deepcopy(peps), deepcopy(env)
wts = Array{PEPSWeight}(undef, 2, Nr, Nc)
for i in 1:4
N = size(peps2, 2)
for col in 1:N
wts_col, fid_col = _fu_column!(col, gate, peps2, env2, alg)
fidmin = min(fidmin, fid_col)

Check warning on line 123 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L114-L123

Added lines #L114 - L123 were not covered by tests
# assign the weights to the un-rotated `wts`
if i == 1
wts[1, :, col] = wts_col
elseif i == 2
wts[2, _next(col, N), :] = reverse(wts_col)
elseif i == 3
wts[1, :, mod1(N - col, N)] = reverse(wts_col)

Check warning on line 130 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L125-L130

Added lines #L125 - L130 were not covered by tests
else
wts[2, N + 1 - col, :] = wts_col

Check warning on line 132 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L132

Added line #L132 was not covered by tests
end
end
gate, peps2, env2 = rotl90(gate), rotl90(peps2), rotl90(env2)
end
return peps2, env2, SUWeight(collect(wt for wt in wts)), fidmin

Check warning on line 137 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L134-L137

Added lines #L134 - L137 were not covered by tests
end

"""
Full update an infinite PEPS with nearest neighbor Hamiltonian.
"""
function fu_iter2(ham::LocalOperator, peps::InfinitePEPS, env::CTMRGEnv, alg::FullUpdate)

Check warning on line 143 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L143

Added line #L143 was not covered by tests
# Each NN bond is updated twice in _fu_iter,
# thus `dt` is divided by 2 when exponentiating `ham`.
gate = get_expham(ham, alg.dt / 2)
wts, fidmin = nothing, 1.0
for it in 1:(alg.niter)
peps, env, wts, fid = fu_iter(gate, peps, env, alg)
fidmin = min(fidmin, fid)
end

Check warning on line 151 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L146-L151

Added lines #L146 - L151 were not covered by tests
# reconverge environment
env, = leading_boundary(env, peps, alg.ctm_alg)
return peps, env, wts, fidmin

Check warning on line 154 in src/algorithms/time_evolution/fullupdate.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/time_evolution/fullupdate.jl#L153-L154

Added lines #L153 - L154 were not covered by tests
end
15 changes: 13 additions & 2 deletions src/algorithms/truncation/bond_truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ The truncation algorithm can be constructed from the following keyword arguments
* `trscheme::TruncationScheme`: SVD truncation scheme when initilizing the truncated tensors connected by the bond.
* `maxiter::Int=50` : Maximal number of ALS iterations.
* `tol::Float64=1e-15` : ALS converges when fidelity change between two FET iterations is smaller than `tol`.
* `use_pinv::Bool=true`: Use pseudo-inverse (instead of `KrylovKit.linsolve`) to solve linear equations in ALS itertions.
* `check_interval::Int=0` : Set number of iterations to print information. Output is suppressed when `check_interval <= 0`.
"""
@kwdef struct ALSTruncation
trscheme::TruncationScheme
maxiter::Int = 50
tol::Float64 = 1e-15
use_pinv::Bool = true
check_interval::Int = 0
end

Expand Down Expand Up @@ -103,11 +105,20 @@ function bond_truncate(
=#
Ra = _tensor_Ra(benv, b)
Sa = _tensor_Sa(benv, b, a2b2)
a, info_a = _solve_ab(Ra, Sa, a)
a, info_a = if alg.use_pinv
_solve_ab_pinv!(Ra, Sa; trunc=truncerr(1e-10))
else
_solve_ab(Ra, Sa, a)
end
# Fixing `a`, solve for `b` from `Rb b = Sb`
Rb = _tensor_Rb(benv, a)
Sb = _tensor_Sb(benv, a, a2b2)
b, info_b = _solve_ab(Rb, Sb, b)
b, info_b = if alg.use_pinv
_solve_ab_pinv!(Rb, Sb; trunc=truncerr(1e-10))
else
_solve_ab(Rb, Sb, b)
end
@debug "Bond truncation info" info_a info_b
ab = _combine_ab(a, b)
cost = cost_function_als(benv, ab, a2b2)
fid = fidelity(benv, ab, a2b2)
Expand Down
6 changes: 6 additions & 0 deletions src/utility/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -613,3 +613,9 @@ function svd_pullback!(
end
return ΔA
end

# Calculate the pseudo-inverse using SVD
function _pinv!(a::AbstractTensorMap; kwargs...)
u, s, vh, ϵ = tsvd!(a; kwargs...)
return vh' * sdiag_pow(s, -1) * u', ϵ
end
3 changes: 2 additions & 1 deletion test/bondenv/bond_truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ for Vbondl in (Vint, Vint'), Vbondr in (Vint, Vint')
@info "Fidelity of simple SVD truncation = $fid0.\n"
ss = Dict{String,DiagonalTensorMap}()
for (label, alg) in (
("ALS", ALSTruncation(; trscheme, maxiter, check_interval)),
("ALS", ALSTruncation(; trscheme, maxiter, check_interval, use_pinv=false)),
("ALS (pinv)", ALSTruncation(; trscheme, maxiter, check_interval, use_pinv=true)),
("FET", FullEnvTruncation(; trscheme, maxiter, check_interval, trunc_init=false)),
)
a1, ss[label], b1, info = PEPSKit.bond_truncate(a2, b2, benv, alg)
Expand Down
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ end
include("timeevol/sitedep_truncation.jl")
end
end
if GROUP == "ALL" || GROUP == "TIMEEVOL"
@time @safetestset "Cluster truncation with projectors" begin
include("timeevol/cluster_projectors.jl")
end
@time @safetestset "Transverse Field Ising model: real-time full update " begin
include("timeevol/tf_ising_fu.jl")
end
end
if GROUP == "ALL" || GROUP == "UTILITY"
@time @safetestset "LocalOperator" begin
include("utility/localoperator.jl")
Expand Down
Loading