Skip to content

Add copy, lmul! and rmul! methods #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "RectangularFullPacked"
uuid = "27983f2f-6524-42ba-a408-2b5a31c238e4"
version = "0.1.0"
version = "0.2.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
2 changes: 2 additions & 0 deletions src/RectangularFullPacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ using LinearAlgebra
using LinearAlgebra: BlasFloat, checksquare

import Base: \
import LinearAlgebra.BLAS: syrk!
import LinearAlgebra: Hermitian

abstract type AbstractRFP{T} <: AbstractMatrix{T} end

Expand Down
9 changes: 7 additions & 2 deletions src/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@ struct CholeskyRFP{T<:BlasFloat} <: Factorization{T}
uplo::Char
end

LinearAlgebra.cholesky!(A::HermitianRFP{T}) where {T<:BlasFloat} =
CholeskyRFP(LAPACK_RFP.pftrf!(A.transr, A.uplo, A.data), A.transr, A.uplo)
function LinearAlgebra.cholesky!(A::HermitianRFP{T}) where {T<:BlasFloat}
return CholeskyRFP(
LAPACK_RFP.pftrf!(A.transr, A.uplo, A.data),
A.transr,
A.uplo,
)
end
LinearAlgebra.cholesky(A::HermitianRFP{T}) where {T<:BlasFloat} = cholesky!(copy(A))
LinearAlgebra.factorize(A::HermitianRFP) = cholesky(A)

Expand Down
26 changes: 25 additions & 1 deletion src/hermitian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@ end

#HermitianRFP(A::TriangularRFP) = HermitianRFP(A.data, A.transr, A.uplo)

function Hermitian(A::TriangularRFP{<:LinearAlgebra.BlasReal}, uplo::Symbol)
Symbol(A.uplo) == uplo ||
throw(ArgumentError("A.uplo = $(A.uplo) conflicts with argument uplo = $uplo"))
return Hermitian(A)
end

function Hermitian(A::TriangularRFP{<:LinearAlgebra.BlasReal})
return HermitianRFP(A.data, A.transr, A.uplo)
end

Base.copy(A::HermitianRFP{T}) where {T} = HermitianRFP{T}(copy(A.data), A.transr, A.uplo)

function Base.getindex(A::HermitianRFP, i::Integer, j::Integer)
(A.uplo == 'L' ? i < j : i > j) && return conj(getindex(A, j, i))
n, k, l = checkbounds(A, i, j)
Expand All @@ -28,4 +40,16 @@ function Ac_mul_A_RFP(A::Matrix{T}, uplo = :U) where {T<:BlasFloat}
return HermitianRFP(LAPACK_RFP.sfrk!('N', ul, tr, 1.0, A, 0.0, par), 'N', ul)
end

Base.copy(A::HermitianRFP) = HermitianRFP(copy(A.data), A.transr, A.uplo)
function syrk!(
trans::AbstractChar,
α::Real,
A::StridedMatrix{T},
β::Real,
C::HermitianRFP{T},
) where {T}
return HermitianRFP(
LAPACK_RFP.sfrk!(C.transr, C.uplo, Char(trans), α, A, β, C.data),
C.transr,
C.uplo,
)
end
6 changes: 3 additions & 3 deletions src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ struct TriangularRFP{T<:BlasFloat} <: AbstractRFP{T}
uplo::Char
end

function TriangularRFP(A::Matrix{T}, uplo::Symbol = :U; transr::Symbol=:N) where {T}
function TriangularRFP(A::Matrix{T}, uplo::Symbol = :U; transr::Symbol = :N) where {T}
n = checksquare(A)
ul = first(string(uplo))
if ul ∉ "UL"
Expand All @@ -21,7 +21,7 @@ function TriangularRFP(A::Matrix{T}, uplo::Symbol = :U; transr::Symbol=:N) where
ul,
)
end

function Base.Array(A::TriangularRFP{T}) where {T}
n, k, l = _rfpsize(A)
C = Array{T}(undef, (n, n))
Expand All @@ -36,7 +36,7 @@ function Base.getindex(A::TriangularRFP{T}, i::Integer, j::Integer) where {T}
(A.uplo == 'L' ? i < j : i > j) && return zero(T)
rs, doconj = _packedinds(A, Int(i), Int(j), iseven(n), l)
val = A.data[first(rs), last(rs)]
return doconj ? conj(val) : val
return doconj ? conj(val) : val
end

function Base.setindex!(A::TriangularRFP{T}, x::T, i::Integer, j::Integer) where {T}
Expand Down
15 changes: 13 additions & 2 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ end
function _packedinds(i::Int, j::Int, lower::Bool, neven::Bool, tr::Bool, l::Int)
if lower
conj = l < j
inds = conj ? (j - l, i + !neven - l) : (i + neven, j)
inds = conj ? (j - l, i + !neven - l) : (i + neven, j)
else
conj = (j + !neven) ≤ l
inds = conj ? (l + neven + j, i) : (i, j + !neven - l)
Expand Down Expand Up @@ -55,7 +55,8 @@ function _rfpsize(A::AbstractRFP)
dsz = size(A.data)
k, l = A.transr == 'N' ? dsz : reverse(dsz)
L = 2l
isone(abs(k - L)) || throw(ArgumentError("size(A.data) = $dsz is not consistent with RFP"))
isone(abs(k - L)) ||
throw(ArgumentError("size(A.data) = $dsz is not consistent with RFP"))
return k - (L < k), k, l
end

Expand All @@ -73,3 +74,13 @@ function Base.size(A::AbstractRFP)
n, k, l = _rfpsize(A)
return (n, n)
end

function LinearAlgebra.rmul!(A::AbstractRFP, B::Number)
rmul!(A.data, B)
return A
end

function LinearAlgebra.lmul!(A::Number, B::AbstractRFP)
lmul!(A, B.data)
return B
end
15 changes: 15 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ import RectangularFullPacked: Ac_mul_A_RFP, TriangularRFP
A = rand(elty, 10, n)
AcA = A'A
AcA_RFP = Ac_mul_A_RFP(A, uplo)
@test AcA_RFP ≈ BLAS.syrk!(elty <: Complex ? 'C' : 'T', 1.0, A, 0.0, copy(AcA_RFP))
o = ones(elty, n)

@test AcA ≈ AcA_RFP
Expand Down Expand Up @@ -97,4 +98,18 @@ import RectangularFullPacked: Ac_mul_A_RFP, TriangularRFP
@test A \ o ≈ A_RFP \ o
@test inv(A) ≈ Array(inv(A_RFP))
end

@testset "In-place scalar multiplication" begin
U = lu(rand(7, 7)).U
B = sqrt(π)
@test rmul!(copy(U), B) ≈ rmul!(TriangularRFP(U, :U), B)
@test lmul!(B, copy(U)) ≈ lmul!(B, TriangularRFP(U, :U; transr=:T))
end

@testset "Hermitian from Triangular" begin
U = lu(rand(7,7)).U
@test Hermitian(TriangularRFP(U, :U)) ≈ Hermitian(U, :U)
@test Hermitian(TriangularRFP(U, :U), :U) ≈ Hermitian(U, :U)
@test_throws ArgumentError Hermitian(TriangularRFP(U, :U), :L)
end
end