Skip to content

Commit 26de024

Browse files
authored
Merge pull request #93 from JuliaLinearAlgebra/checkdoesnotconstprop
Allow passing check as a `Val`, because constprop fails.
2 parents 425e436 + 58fdb2f commit 26de024

File tree

2 files changed

+21
-18
lines changed

2 files changed

+21
-18
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveFactorization"
22
uuid = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
33
authors = ["Yingbo Ma <[email protected]>"]
4-
version = "0.2.21"
4+
version = "0.2.22"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/lu.jl

+20-17
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using LoopVectorization
22
using TriangularSolve: ldiv!
33
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS,
4-
LinearAlgebra, Adjoint, Transpose, UpperTriangular, AbstractVecOrMat
4+
LinearAlgebra, Adjoint, Transpose, UpperTriangular, AbstractVecOrMat
55
using StrideArraysCore
66
using Polyester: @batch
77

@@ -41,32 +41,35 @@ init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn)
4141

4242
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_cols!)
4343
function LinearAlgebra._ipiv_cols!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
44-
B::StridedVecOrMat)
44+
B::StridedVecOrMat)
4545
return B
4646
end
4747
end
4848
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_rows!)
4949
function LinearAlgebra._ipiv_rows!(::(LU{T, <:AbstractMatrix{T}, NotIPIV} where {T}),
50-
::OrdinalRange,
51-
B::StridedVecOrMat)
50+
::OrdinalRange,
51+
B::StridedVecOrMat)
5252
return B
5353
end
5454
end
5555
if CUSTOMIZABLE_PIVOT
5656
function LinearAlgebra.ldiv!(A::LU{T, <:StridedMatrix, <:NotIPIV},
57-
B::StridedVecOrMat{T}) where {T <: BlasFloat}
57+
B::StridedVecOrMat{T}) where {T <: BlasFloat}
5858
ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), B))
5959
end
6060
end
6161

62-
function lu!(A, pivot = Val(true), thread = Val(false); check = true, kwargs...)
62+
function lu!(A, pivot = Val(true), thread = Val(false);
63+
check::Union{Bool, Val{true}, Val{false}} = Val(true), kwargs...)
6364
m, n = size(A)
6465
minmn = min(m, n)
6566
npivot = normalize_pivot(pivot)
6667
# we want the type on both branches to match. When pivot = Val(false), we construct
6768
# a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not.
6869
F = if pivot === Val(true) && minmn < 10 # avx introduces small performance degradation
69-
LinearAlgebra.generic_lufact!(A, to_stdlib_pivot(pivot); check = check)
70+
LinearAlgebra.generic_lufact!(A, to_stdlib_pivot(pivot);
71+
check = ((check isa Bool && check) || (check === Val(true)))
72+
)
7073
else
7174
lu!(A, init_pivot(npivot, minmn), npivot, thread; check = check,
7275
kwargs...)
@@ -87,11 +90,11 @@ recurse(_) = false
8790
_ptrarray(ipiv) = PtrArray(ipiv)
8891
_ptrarray(ipiv::NotIPIV) = ipiv
8992
function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
90-
pivot = Val(true), thread = Val(false);
91-
check::Bool = true,
92-
# the performance is not sensitive wrt blocksize, and 8 is a good default
93-
blocksize::Integer = length(A) 40_000 ? 8 : 16,
94-
threshold::Integer = pick_threshold()) where {T}
93+
pivot = Val(true), thread = Val(false);
94+
check::Union{Bool, Val{true}, Val{false}} = Val(true),
95+
# the performance is not sensitive wrt blocksize, and 8 is a good default
96+
blocksize::Integer = length(A) 40_000 ? 8 : 16,
97+
threshold::Integer = pick_threshold()) where {T}
9598
pivot = normalize_pivot(pivot)
9699
info = zero(BlasInt)
97100
m, n = size(A)
@@ -113,12 +116,12 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
113116
else # generic fallback
114117
info = _generic_lufact!(A, pivot, ipiv, info)
115118
end
116-
check && checknonsingular(info)
119+
((check isa Bool && check) || (check === Val(true))) && checknonsingular(info)
117120
LU(A, ipiv, info)
118121
end
119122

120123
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
121-
::Val{true}) where {Pivot}
124+
::Val{true}) where {Pivot}
122125
if length(A) * _sizeof(eltype(A)) >
123126
0.92 * LoopVectorization.VectorizationBase.cache_size(Val(2))
124127
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(true))
@@ -127,11 +130,11 @@ end
127130
end
128131
end
129132
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
130-
::Val{false}) where {Pivot}
133+
::Val{false}) where {Pivot}
131134
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(false))
132135
end
133136
@inline function _recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
134-
::Val{Thread}) where {Pivot, Thread}
137+
::Val{Thread}) where {Pivot, Thread}
135138
info = reckernel!(A, Val(Pivot), m, mnmin, ipiv, info, blocksize, Val(Thread))::Int
136139
@inbounds if m < n # fat matrix
137140
# [AL AR]
@@ -175,7 +178,7 @@ Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false})
175178
nothing
176179
end
177180
function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, blocksize,
178-
thread)::BlasInt where {T, Pivot}
181+
thread)::BlasInt where {T, Pivot}
179182
@inbounds begin
180183
if n <= max(blocksize, 1)
181184
info = _generic_lufact!(A, Val(Pivot), ipiv, info)

0 commit comments

Comments
 (0)