Skip to content

Rewrite special-purpose constructors to use in-place operations #232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ export blocksectors, blockdim, block, blocks
export randisometry, randisometry!, rand, rand!, randn, randn!

# special purpose constructors
export zero, one, one!, id, isomorphism, unitary, isometry
export zero, one, one!, id, id!, isomorphism, isomorphism!, unitary, unitary!, isometry,
isometry!

# reexport most of VectorInterface and some more tensor algebra
export zerovector, zerovector!, zerovector!!, scale, scale!, scale!!, add, add!, add!!
Expand Down
59 changes: 37 additions & 22 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,24 +64,31 @@
return t
end

"""
@doc """
id([T::Type=Float64,] V::TensorSpace) -> TensorMap
id!(t::AbstractTensorMap) -> AbstractTensorMap

Construct the identity endomorphism on space `V`, i.e. return a `t::TensorMap` with
`domain(t) == codomain(t) == V`, where either `scalartype(t) = T` if `T` is a `Number` type
or `storagetype(t) = T` if `T` is a `DenseVector` type.
"""

See also [`one!`](@ref).
""" id, id!

id(V::TensorSpace) = id(Float64, V)
function id(A::Type, V::TensorSpace{S}) where {S}
W = V ← V
N = length(codomain(W))
return one!(tensormaptype(S, N, N, A)(undef, W))
dst = tensormaptype(S, N, N, A)(undef, W)
return id!(dst)
end
const id! = one!

"""
@doc """
isomorphism([T::Type=Float64,] codomain::TensorSpace, domain::TensorSpace) -> TensorMap
isomorphism([T::Type=Float64,] codomain ← domain) -> TensorMap
isomorphism([T::Type=Float64,] domain → codomain) -> TensorMap
isomorphism!(t::AbstractTensorMap) -> AbstractTensorMap

Construct a specific isomorphism between the codomain and the domain, i.e. return a
`t::TensorMap` where either `scalartype(t) = T` if `T` is a `Number` type or
Expand All @@ -93,21 +100,22 @@
that `isomorphism(cod, dom) == inv(isomorphism(dom, cod))`.

See also [`unitary`](@ref) when `InnerProductStyle(cod) === EuclideanInnerProduct()`.
"""
function isomorphism(A::Type, V::TensorMapSpace{S,N₁,N₂}) where {S,N₁,N₂}
codomain(V) ≅ domain(V) ||
throw(SpaceMismatch("codomain and domain are not isomorphic: $V"))
t = tensormaptype(S, N₁, N₂, A)(undef, V)
""" isomorphism, isomorphism!

function isomorphism!(t::AbstractTensorMap)
domain(t) ≅ codomain(t) ||
throw(SpaceMismatch(lazy"domain and codomain are not isomorphic: $(space(t))"))
for (_, b) in blocks(t)
MatrixAlgebra.one!(b)
end
return t
end

"""
@doc """
unitary([T::Type=Float64,] codomain::TensorSpace, domain::TensorSpace) -> TensorMap
unitary([T::Type=Float64,] codomain ← domain) -> TensorMap
unitary([T::Type=Float64,] domain → codomain) -> TensorMap
unitary!(t::AbstractTensorMap) -> AbstractTensorMap

Construct a specific unitary morphism between the codomain and the domain, i.e. return a
`t::TensorMap` where either `scalartype(t) = T` if `T` is a `Number` type or
Expand All @@ -119,16 +127,18 @@
`unitary(cod, dom) == inv(unitary(dom, cod)) = adjoint(unitary(dom, cod))`.

See also [`isomorphism`](@ref) and [`isometry`](@ref).
"""
function unitary(A::Type, V::TensorMapSpace{S,N₁,N₂}) where {S,N₁,N₂}
InnerProductStyle(S) === EuclideanInnerProduct() || throw_invalid_innerproduct(:unitary)
return isomorphism(A, V)
""" unitary, unitary!

function unitary!(t::AbstractTensorMap)
InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:unitary)
return isomorphism!(t)
end

"""
@doc """
isometry([T::Type=Float64,] codomain::TensorSpace, domain::TensorSpace) -> TensorMap
isometry([T::Type=Float64,] codomain ← domain) -> TensorMap
isometry([T::Type=Float64,] domain → codomain) -> TensorMap
isometry!(t::AbstractTensorMap) -> AbstractTensorMap

Construct a specific isometry between the codomain and the domain, i.e. return a
`t::TensorMap` where either `scalartype(t) = T` if `T` is a `Number` type or
Expand All @@ -137,13 +147,13 @@
isometric inclusion, an error will be thrown.

See also [`isomorphism`](@ref) and [`unitary`](@ref).
"""
function isometry(A::Type, V::TensorMapSpace{S,N₁,N₂}) where {S,N₁,N₂}
InnerProductStyle(S) === EuclideanInnerProduct() ||
""" isometry, isometry!

function isometry!(t::AbstractTensorMap)
InnerProductStyle(t) === EuclideanInnerProduct() ||
throw_invalid_innerproduct(:isometry)
domain(V) ≾ codomain(V) ||
throw(SpaceMismatch("$V does not allow for an isometric inclusion"))
t = tensormaptype(S, N₁, N₂, A)(undef, V)
domain(t) ≾ codomain(t) ||
throw(SpaceMismatch(lazy"domain and codomain are not isometrically embeddable: $(space(t))"))
for (_, b) in blocks(t)
MatrixAlgebra.one!(b)
end
Expand All @@ -152,13 +162,18 @@

# expand methods with default arguments
for morphism in (:isomorphism, :unitary, :isometry)
morphism! = Symbol(morphism, :!)
@eval begin
$morphism(V::TensorMapSpace) = $morphism(Float64, V)
$morphism(codomain::TensorSpace, domain::TensorSpace) = $morphism(codomain ← domain)
function $morphism(A::Type, codomain::TensorSpace, domain::TensorSpace)
return $morphism(A, codomain ← domain)
end
$morphism(t::AbstractTensorMap) = $morphism(storagetype(t), space(t))
function $morphism(A::Type, V::TensorMapSpace{S,N₁,N₂}) where {S,N₁,N₂}
t = tensormaptype(S, N₁, N₂, A)(undef, V)
return $morphism!(t)
end
$morphism(t::AbstractTensorMap) = $morphism!(similar(t))

Check warning on line 176 in src/tensors/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/linalg.jl#L176

Added line #L176 was not covered by tests
end
end

Expand Down