diff --git a/Project.toml b/Project.toml index c0098d3..5da2630 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "RectangularFullPacked" uuid = "27983f2f-6524-42ba-a408-2b5a31c238e4" -version = "0.1.0" +version = "0.2.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/RectangularFullPacked.jl b/src/RectangularFullPacked.jl index c1031c1..d1b227f 100644 --- a/src/RectangularFullPacked.jl +++ b/src/RectangularFullPacked.jl @@ -6,6 +6,8 @@ using LinearAlgebra using LinearAlgebra: BlasFloat, checksquare import Base: \ +import LinearAlgebra.BLAS: syrk! +import LinearAlgebra: Hermitian abstract type AbstractRFP{T} <: AbstractMatrix{T} end diff --git a/src/cholesky.jl b/src/cholesky.jl index 66b6e4a..f4289a3 100644 --- a/src/cholesky.jl +++ b/src/cholesky.jl @@ -4,8 +4,13 @@ struct CholeskyRFP{T<:BlasFloat} <: Factorization{T} uplo::Char end -LinearAlgebra.cholesky!(A::HermitianRFP{T}) where {T<:BlasFloat} = - CholeskyRFP(LAPACK_RFP.pftrf!(A.transr, A.uplo, A.data), A.transr, A.uplo) +function LinearAlgebra.cholesky!(A::HermitianRFP{T}) where {T<:BlasFloat} + return CholeskyRFP( + LAPACK_RFP.pftrf!(A.transr, A.uplo, A.data), + A.transr, + A.uplo, + ) +end LinearAlgebra.cholesky(A::HermitianRFP{T}) where {T<:BlasFloat} = cholesky!(copy(A)) LinearAlgebra.factorize(A::HermitianRFP) = cholesky(A) diff --git a/src/hermitian.jl b/src/hermitian.jl index 428a850..1c910aa 100644 --- a/src/hermitian.jl +++ b/src/hermitian.jl @@ -9,6 +9,18 @@ end #HermitianRFP(A::TriangularRFP) = HermitianRFP(A.data, A.transr, A.uplo) +function Hermitian(A::TriangularRFP{<:LinearAlgebra.BlasReal}, uplo::Symbol) + Symbol(A.uplo) == uplo || + throw(ArgumentError("A.uplo = $(A.uplo) conflicts with argument uplo = $uplo")) + return Hermitian(A) +end + +function Hermitian(A::TriangularRFP{<:LinearAlgebra.BlasReal}) + return HermitianRFP(A.data, A.transr, A.uplo) +end + +Base.copy(A::HermitianRFP{T}) where {T} = HermitianRFP{T}(copy(A.data), A.transr, A.uplo) + function Base.getindex(A::HermitianRFP, i::Integer, j::Integer) (A.uplo == 'L' ? i < j : i > j) && return conj(getindex(A, j, i)) n, k, l = checkbounds(A, i, j) @@ -28,4 +40,16 @@ function Ac_mul_A_RFP(A::Matrix{T}, uplo = :U) where {T<:BlasFloat} return HermitianRFP(LAPACK_RFP.sfrk!('N', ul, tr, 1.0, A, 0.0, par), 'N', ul) end -Base.copy(A::HermitianRFP) = HermitianRFP(copy(A.data), A.transr, A.uplo) +function syrk!( + trans::AbstractChar, + α::Real, + A::StridedMatrix{T}, + β::Real, + C::HermitianRFP{T}, +) where {T} + return HermitianRFP( + LAPACK_RFP.sfrk!(C.transr, C.uplo, Char(trans), α, A, β, C.data), + C.transr, + C.uplo, + ) +end diff --git a/src/triangular.jl b/src/triangular.jl index a1681b7..2fe91f5 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -4,7 +4,7 @@ struct TriangularRFP{T<:BlasFloat} <: AbstractRFP{T} uplo::Char end -function TriangularRFP(A::Matrix{T}, uplo::Symbol = :U; transr::Symbol=:N) where {T} +function TriangularRFP(A::Matrix{T}, uplo::Symbol = :U; transr::Symbol = :N) where {T} n = checksquare(A) ul = first(string(uplo)) if ul ∉ "UL" @@ -21,7 +21,7 @@ function TriangularRFP(A::Matrix{T}, uplo::Symbol = :U; transr::Symbol=:N) where ul, ) end - + function Base.Array(A::TriangularRFP{T}) where {T} n, k, l = _rfpsize(A) C = Array{T}(undef, (n, n)) @@ -36,7 +36,7 @@ function Base.getindex(A::TriangularRFP{T}, i::Integer, j::Integer) where {T} (A.uplo == 'L' ? i < j : i > j) && return zero(T) rs, doconj = _packedinds(A, Int(i), Int(j), iseven(n), l) val = A.data[first(rs), last(rs)] - return doconj ? conj(val) : val + return doconj ? conj(val) : val end function Base.setindex!(A::TriangularRFP{T}, x::T, i::Integer, j::Integer) where {T} diff --git a/src/utilities.jl b/src/utilities.jl index f14035b..40064cf 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -23,7 +23,7 @@ end function _packedinds(i::Int, j::Int, lower::Bool, neven::Bool, tr::Bool, l::Int) if lower conj = l < j - inds = conj ? (j - l, i + !neven - l) : (i + neven, j) + inds = conj ? (j - l, i + !neven - l) : (i + neven, j) else conj = (j + !neven) ≤ l inds = conj ? (l + neven + j, i) : (i, j + !neven - l) @@ -55,7 +55,8 @@ function _rfpsize(A::AbstractRFP) dsz = size(A.data) k, l = A.transr == 'N' ? dsz : reverse(dsz) L = 2l - isone(abs(k - L)) || throw(ArgumentError("size(A.data) = $dsz is not consistent with RFP")) + isone(abs(k - L)) || + throw(ArgumentError("size(A.data) = $dsz is not consistent with RFP")) return k - (L < k), k, l end @@ -73,3 +74,13 @@ function Base.size(A::AbstractRFP) n, k, l = _rfpsize(A) return (n, n) end + +function LinearAlgebra.rmul!(A::AbstractRFP, B::Number) + rmul!(A.data, B) + return A +end + +function LinearAlgebra.lmul!(A::Number, B::AbstractRFP) + lmul!(A, B.data) + return B +end diff --git a/test/runtests.jl b/test/runtests.jl index 88e809f..dea31bc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -69,6 +69,7 @@ import RectangularFullPacked: Ac_mul_A_RFP, TriangularRFP A = rand(elty, 10, n) AcA = A'A AcA_RFP = Ac_mul_A_RFP(A, uplo) + @test AcA_RFP ≈ BLAS.syrk!(elty <: Complex ? 'C' : 'T', 1.0, A, 0.0, copy(AcA_RFP)) o = ones(elty, n) @test AcA ≈ AcA_RFP @@ -97,4 +98,18 @@ import RectangularFullPacked: Ac_mul_A_RFP, TriangularRFP @test A \ o ≈ A_RFP \ o @test inv(A) ≈ Array(inv(A_RFP)) end + + @testset "In-place scalar multiplication" begin + U = lu(rand(7, 7)).U + B = sqrt(π) + @test rmul!(copy(U), B) ≈ rmul!(TriangularRFP(U, :U), B) + @test lmul!(B, copy(U)) ≈ lmul!(B, TriangularRFP(U, :U; transr=:T)) + end + + @testset "Hermitian from Triangular" begin + U = lu(rand(7,7)).U + @test Hermitian(TriangularRFP(U, :U)) ≈ Hermitian(U, :U) + @test Hermitian(TriangularRFP(U, :U), :U) ≈ Hermitian(U, :U) + @test_throws ArgumentError Hermitian(TriangularRFP(U, :U), :L) + end end