1
1
using LoopVectorization
2
2
using TriangularSolve: ldiv!
3
3
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS,
4
- LinearAlgebra, Adjoint, Transpose, UpperTriangular, AbstractVecOrMat
4
+ LinearAlgebra, Adjoint, Transpose, UpperTriangular, AbstractVecOrMat
5
5
using StrideArraysCore
6
6
using Polyester: @batch
7
7
@@ -41,32 +41,35 @@ init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn)
41
41
42
42
if CUSTOMIZABLE_PIVOT && isdefined (LinearAlgebra, :_ipiv_cols! )
43
43
function LinearAlgebra. _ipiv_cols! (:: LU{<:Any, <:Any, NotIPIV} , :: OrdinalRange ,
44
- B:: StridedVecOrMat )
44
+ B:: StridedVecOrMat )
45
45
return B
46
46
end
47
47
end
48
48
if CUSTOMIZABLE_PIVOT && isdefined (LinearAlgebra, :_ipiv_rows! )
49
49
function LinearAlgebra. _ipiv_rows! (:: (LU{T, <:AbstractMatrix{T}, NotIPIV} where {T}) ,
50
- :: OrdinalRange ,
51
- B:: StridedVecOrMat )
50
+ :: OrdinalRange ,
51
+ B:: StridedVecOrMat )
52
52
return B
53
53
end
54
54
end
55
55
if CUSTOMIZABLE_PIVOT
56
56
function LinearAlgebra. ldiv! (A:: LU{T, <:StridedMatrix, <:NotIPIV} ,
57
- B:: StridedVecOrMat{T} ) where {T <: BlasFloat }
57
+ B:: StridedVecOrMat{T} ) where {T <: BlasFloat }
58
58
ldiv! (UpperTriangular (A. factors), ldiv! (UnitLowerTriangular (A. factors), B))
59
59
end
60
60
end
61
61
62
- function lu! (A, pivot = Val (true ), thread = Val (false ); check = true , kwargs... )
62
+ function lu! (A, pivot = Val (true ), thread = Val (false );
63
+ check:: Union{Bool, Val{true}, Val{false}} = Val (true ), kwargs... )
63
64
m, n = size (A)
64
65
minmn = min (m, n)
65
66
npivot = normalize_pivot (pivot)
66
67
# we want the type on both branches to match. When pivot = Val(false), we construct
67
68
# a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not.
68
69
F = if pivot === Val (true ) && minmn < 10 # avx introduces small performance degradation
69
- LinearAlgebra. generic_lufact! (A, to_stdlib_pivot (pivot); check = check)
70
+ LinearAlgebra. generic_lufact! (A, to_stdlib_pivot (pivot);
71
+ check = ((check isa Bool && check) || (check === Val (true )))
72
+ )
70
73
else
71
74
lu! (A, init_pivot (npivot, minmn), npivot, thread; check = check,
72
75
kwargs... )
@@ -87,11 +90,11 @@ recurse(_) = false
87
90
_ptrarray (ipiv) = PtrArray (ipiv)
88
91
_ptrarray (ipiv:: NotIPIV ) = ipiv
89
92
function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
90
- pivot = Val (true ), thread = Val (false );
91
- check:: Bool = true ,
92
- # the performance is not sensitive wrt blocksize, and 8 is a good default
93
- blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
94
- threshold:: Integer = pick_threshold ()) where {T}
93
+ pivot = Val (true ), thread = Val (false );
94
+ check:: Union{ Bool, Val{true}, Val{false}} = Val ( true ) ,
95
+ # the performance is not sensitive wrt blocksize, and 8 is a good default
96
+ blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
97
+ threshold:: Integer = pick_threshold ()) where {T}
95
98
pivot = normalize_pivot (pivot)
96
99
info = zero (BlasInt)
97
100
m, n = size (A)
@@ -113,12 +116,12 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
113
116
else # generic fallback
114
117
info = _generic_lufact! (A, pivot, ipiv, info)
115
118
end
116
- check && checknonsingular (info)
119
+ (( check isa Bool && check) || (check === Val ( true ))) && checknonsingular (info)
117
120
LU (A, ipiv, info)
118
121
end
119
122
120
123
@inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
121
- :: Val{true} ) where {Pivot}
124
+ :: Val{true} ) where {Pivot}
122
125
if length (A) * _sizeof (eltype (A)) >
123
126
0.92 * LoopVectorization. VectorizationBase. cache_size (Val (2 ))
124
127
_recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (true ))
@@ -127,11 +130,11 @@ end
127
130
end
128
131
end
129
132
@inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
130
- :: Val{false} ) where {Pivot}
133
+ :: Val{false} ) where {Pivot}
131
134
_recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (false ))
132
135
end
133
136
@inline function _recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
134
- :: Val{Thread} ) where {Pivot, Thread}
137
+ :: Val{Thread} ) where {Pivot, Thread}
135
138
info = reckernel! (A, Val (Pivot), m, mnmin, ipiv, info, blocksize, Val (Thread)):: Int
136
139
@inbounds if m < n # fat matrix
137
140
# [AL AR]
@@ -175,7 +178,7 @@ Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false})
175
178
nothing
176
179
end
177
180
function reckernel! (A:: AbstractMatrix{T} , pivot:: Val{Pivot} , m, n, ipiv, info, blocksize,
178
- thread):: BlasInt where {T, Pivot}
181
+ thread):: BlasInt where {T, Pivot}
179
182
@inbounds begin
180
183
if n <= max (blocksize, 1 )
181
184
info = _generic_lufact! (A, Val (Pivot), ipiv, info)
0 commit comments