Skip to content

Commit d3f42d2

Browse files
Fe-r-ozKrastanov
andauthored
better bit-wrangling abstraction using less boilerplate (#367)
Co-authored-by: Stefan Krastanov <[email protected]>
1 parent 4cc664e commit d3f42d2

File tree

7 files changed

+68
-75
lines changed

7 files changed

+68
-75
lines changed

ext/QuantumCliffordGPUExt/apply_noise.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
1-
using QuantumClifford: _div, _mod
1+
using QuantumClifford: get_bitmask_idxs
22

33
#according to https://github.com/JuliaGPU/CUDA.jl/blob/ac1bc29a118e7be56d9edb084a4dea4224c1d707/test/core/device/random.jl#L33
44
#CUDA.jl supports calling rand() inside kernel
55
function applynoise!(frame::PauliFrameGPU{T},noise::UnbiasedUncorrelatedNoise,i::Int) where {T <: Unsigned}
66
p = noise.p
7-
lowbit = T(1)
8-
ibig = _div(T,i-1)+1
9-
ismall = _mod(T,i-1)
10-
ismallm = lowbit<<(ismall)
11-
12-
stab = frame.frame
13-
xzs = tab(stab).xzs
7+
xzs = tab(frame.frame).xzs
8+
lowbit, ibig, ismall, ismallm = get_bitmask_idxs(xzs,i)
149
rows = size(stab, 1)
1510

1611
@run_cuda applynoise_kernel(xzs, p, ibig, ismallm, rows) rows

ext/QuantumCliffordGPUExt/pauli_frames.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using QuantumClifford: get_bitmask_idxs
2+
13
##############################
24
# sMZ
35
##############################
@@ -21,10 +23,7 @@ function apply!(frame::PauliFrameGPU{T}, op::QuantumClifford.sMZ) where {T <: Un
2123
op.bit == 0 && return frame
2224
i = op.qubit
2325
xzs = frame.frame.tab.xzs
24-
lowbit = T(1)
25-
ibig = QuantumClifford._div(T,i-1)+1
26-
ismall = QuantumClifford._mod(T,i-1)
27-
ismallm = lowbit<<(ismall)
26+
lowbit, ibig, ismall, ismallm = get_bitmask_idxs(xzs,i)
2827
(@run_cuda apply_sMZ_kernel!(xzs, frame.measurements, op, ibig, ismallm, length(frame)) length(frame))
2928
return frame
3029
end
@@ -55,10 +54,7 @@ end
5554
function apply!(frame::PauliFrameGPU{T}, op::QuantumClifford.sMRZ) where {T <: Unsigned} # TODO sMRX, sMRY
5655
i = op.qubit
5756
xzs = frame.frame.tab.xzs
58-
lowbit = T(1)
59-
ibig = QuantumClifford._div(T,i-1)+1
60-
ismall = QuantumClifford._mod(T,i-1)
61-
ismallm = lowbit<<(ismall)
57+
lowbit, ibig, ismall, ismallm = get_bitmask_idxs(xzs,i)
6258
(@run_cuda apply_sMRZ_kernel!(xzs, frame.measurements, op, ibig, ismallm, length(frame)) length(frame))
6359
return frame
6460
end

src/QuantumClifford.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,30 @@ function unsafe_bitfindnext_(chunks::AbstractVector{T}, start::Int) where T<:Uns
912912
return nothing
913913
end
914914

915+
"""
916+
$(TYPEDSIGNATURES)
917+
918+
Computes bitmask indices for an unsigned integer at index `i`
919+
within the binary structure of a `Tableau` or `PauliOperator`.
920+
921+
For `Tableau`, the method operates on the `.xzs` field, while
922+
for `PauliOperator`, it uses the `.xz` field. It calculates
923+
the following values based on the index `i`:
924+
925+
- `lowbit`, the lowest bit.
926+
- `ibig`, the index of the word containing the bit.
927+
- `ismall`, the position of the bit within the word.
928+
- `ismallm`, a bitmask isolating the specified bit.
929+
"""
930+
@inline function get_bitmask_idxs(xzs::AbstractArray{<:Unsigned}, i::Int)
931+
T = eltype(xzs)
932+
lowbit = T(1)
933+
ibig = _div(T, i-1) + 1
934+
ismall = _mod(T, i-1)
935+
ismallm = lowbit << ismall
936+
return lowbit, ibig, ismall, ismallm
937+
end
938+
915939
"""Permute the qubits (i.e., columns) of the tableau in place."""
916940
function Base.permute!(s::Tableau, perm::AbstractVector)
917941
for r in 1:size(s,1)

src/pauli_frames.jl

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,8 @@ end
8282
function apply!(frame::PauliFrame, op::sMZ) # TODO sMY, and faster sMX
8383
op.bit == 0 && return frame
8484
i = op.qubit
85-
xzs = frame.frame.tab.xzs
86-
T = eltype(xzs)
87-
lowbit = T(1)
88-
ibig = _div(T,i-1)+1
89-
ismall = _mod(T,i-1)
90-
ismallm = lowbit<<(ismall)
85+
xzs = tab(frame.frame).xzs
86+
lowbit, ibig, ismall, ismallm = get_bitmask_idxs(xzs,i)
9187

9288
@inbounds @simd for f in eachindex(frame)
9389
should_flip = !iszero(xzs[ibig,f] & ismallm)
@@ -99,12 +95,8 @@ end
9995

10096
function apply!(frame::PauliFrame, op::sMRZ) # TODO sMRY, and faster sMRX
10197
i = op.qubit
102-
xzs = frame.frame.tab.xzs
103-
T = eltype(xzs)
104-
lowbit = T(1)
105-
ibig = _div(T,i-1)+1
106-
ismall = _mod(T,i-1)
107-
ismallm = lowbit<<(ismall)
98+
xzs = tab(frame.frame).xzs
99+
lowbit, ibig, ismall, ismallm = get_bitmask_idxs(xzs,i)
108100

109101
if op.bit != 0
110102
@inbounds @simd for f in eachindex(frame)
@@ -122,45 +114,39 @@ end
122114

123115
function applynoise!(frame::PauliFrame,noise::UnbiasedUncorrelatedNoise,i::Int)
124116
p = noise.p
125-
T = eltype(frame.frame.tab.xzs)
117+
xzs = tab(frame.frame).xzs
126118

127-
lowbit = T(1)
128-
ibig = _div(T,i-1)+1
129-
ismall = _mod(T,i-1)
130-
ismallm = lowbit<<(ismall)
119+
lowbit, ibig, ismall, ismallm = get_bitmask_idxs(xzs,i)
131120
p = p/3
132121

133122
@inbounds @simd for f in eachindex(frame)
134123
r = rand()
135124
if r < p # X error
136-
frame.frame.tab.xzs[ibig,f] ⊻= ismallm
125+
xzs[ibig,f] ⊻= ismallm
137126
elseif r < 2p # Z error
138-
frame.frame.tab.xzs[end÷2+ibig,f] ⊻= ismallm
127+
xzs[end÷2+ibig,f] ⊻= ismallm
139128
elseif r < 3p # Y error
140-
frame.frame.tab.xzs[ibig,f] ⊻= ismallm
141-
frame.frame.tab.xzs[end÷2+ibig,f] ⊻= ismallm
129+
xzs[ibig,f] ⊻= ismallm
130+
xzs[end÷2+ibig,f] ⊻= ismallm
142131
end
143132
end
144133
return frame
145134
end
146135

147136
function applynoise!(frame::PauliFrame,noise::PauliNoise,i::Int)
148-
T = eltype(frame.frame.tab.xzs)
137+
xzs = tab(frame.frame).xzs
149138

150-
lowbit = T(1)
151-
ibig = _div(T,i-1)+1
152-
ismall = _mod(T,i-1)
153-
ismallm = lowbit<<(ismall)
139+
lowbit, ibig, ismall, ismallm = get_bitmask_idxs(xzs,i)
154140

155141
@inbounds @simd for f in eachindex(frame)
156142
r = rand()
157143
if r < noise.px # X error
158-
frame.frame.tab.xzs[ibig,f] ⊻= ismallm
144+
xzs[ibig,f] ⊻= ismallm
159145
elseif r < noise.px+noise.pz # Z error
160-
frame.frame.tab.xzs[end÷2+ibig,f] ⊻= ismallm
146+
xzs[end÷2+ibig,f] ⊻= ismallm
161147
elseif r < noise.px+noise.pz+noise.py # Y error
162-
frame.frame.tab.xzs[ibig,f] ⊻= ismallm
163-
frame.frame.tab.xzs[end÷2+ibig,f] ⊻= ismallm
148+
xzs[ibig,f] ⊻= ismallm
149+
xzs[end÷2+ibig,f] ⊻= ismallm
164150
end
165151
end
166152
return frame

src/pauli_operator.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,23 @@ macro P_str(a)
8787
quote _P_str($a) end
8888
end
8989

90-
Base.getindex(p::PauliOperator{Tₚ,Tᵥ}, i::Int) where {Tₚ, Tᵥₑ<:Unsigned, Tᵥ<:AbstractVector{Tᵥₑ}} = ((p.xz[_div(Tᵥₑ, i-1)+1] & Tᵥₑ(0x1)<<_mod(Tᵥₑ,i-1))!=0x0)::Bool, ((p.xz[end÷2+_div(Tᵥₑ,i-1)+1] & Tᵥₑ(0x1)<<_mod(Tᵥₑ,i-1))!=0x0)::Bool
90+
function Base.getindex(p::PauliOperator{Tₚ,Tᵥ}, i::Int) where {Tₚ, Tᵥₑ<:Unsigned, Tᵥ<:AbstractVector{Tᵥₑ}}
91+
_, ibig, _, ismallm = get_bitmask_idxs(p.xz,i)
92+
((p.xz[ibig] & ismallm) != 0x0)::Bool, ((p.xz[end÷2+ibig] & ismallm) != 0x0)::Bool
93+
end
9194
Base.getindex(p::PauliOperator{Tₚ,Tᵥ}, r) where {Tₚ, Tᵥₑ<:Unsigned, Tᵥ<:AbstractVector{Tᵥₑ}} = PauliOperator(p.phase[], xbit(p)[r], zbit(p)[r])
9295

9396
function Base.setindex!(p::PauliOperator{Tₚ,Tᵥ}, (x,z)::Tuple{Bool,Bool}, i) where {Tₚ, Tᵥₑ, Tᵥ<:AbstractVector{Tᵥₑ}}
97+
_, ibig, _, ismallm = get_bitmask_idxs(p.xz,i)
9498
if x
95-
p.xz[_div(Tᵥₑ,i-1)+1] |= Tᵥₑ(0x1)<<_mod(Tᵥₑ,i-1)
99+
p.xz[ibig] |= ismallm
96100
else
97-
p.xz[_div(Tᵥₑ,i-1)+1] &= ~(Tᵥₑ(0x1)<<_mod(Tᵥₑ,i-1))
101+
p.xz[ibig] &= ~(ismallm)
98102
end
99103
if z
100-
p.xz[end÷2+_div(Tᵥₑ,i-1)+1] |= Tᵥₑ(0x1)<<_mod(Tᵥₑ,i-1)
104+
p.xz[end÷2+ibig] |= ismallm
101105
else
102-
p.xz[end÷2+_div(Tᵥₑ,i-1)+1] &= ~(Tᵥₑ(0x1)<<_mod(Tᵥₑ,i-1))
106+
p.xz[end÷2+ibig] &= ~(ismallm)
103107
end
104108
p
105109
end

src/project_trace_reset.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,15 @@ function _generate!(pauli::PauliOperator{Tₚ,Tᵥ}, stabilizer::Stabilizer{Tabl
4242
xzs = tab(stabilizer).xzs
4343
xs = @view xzs[1:end÷2,:]
4444
zs = @view xzs[end÷2+1:end,:]
45-
lowbit = Tₘₑ(0x1)
4645
zerobit = Tₘₑ(0x0)
4746
px,pz = xview(pauli), zview(pauli)
4847
used_indices = Int[]
4948
used = 0
5049
# remove Xs
5150
while (i=unsafe_bitfindnext_(px,1); i !== nothing) # TODO awkward notation due to https://github.com/JuliaLang/julia/issues/45499
52-
jbig = _div(Tₘₑ,i-1)+1
53-
jsmall = lowbit<<_mod(Tₘₑ,i-1)
54-
candidate = findfirst(e->e&jsmall!=zerobit, # TODO some form of reinterpret might be faster than equality check
55-
xs[jbig,used+1:end])
51+
_, ibig, _, ismallm = get_bitmask_idxs(xzs,i)
52+
candidate = findfirst(e->e&ismallm!=zerobit, # TODO some form of reinterpret might be faster than equality check
53+
xs[ibig,used+1:end])
5654
if isnothing(candidate)
5755
return nothing
5856
else
@@ -63,10 +61,9 @@ function _generate!(pauli::PauliOperator{Tₚ,Tᵥ}, stabilizer::Stabilizer{Tabl
6361
end
6462
# remove Zs
6563
while (i=unsafe_bitfindnext_(pz,1); i !== nothing) # TODO awkward notation due to https://github.com/JuliaLang/julia/issues/45499
66-
jbig = _div(Tₘₑ,i-1)+1
67-
jsmall = lowbit<<_mod(Tₘₑ,i-1)
68-
candidate = findfirst(e->e&jsmall!=zerobit, # TODO some form of reinterpret might be faster than equality check
69-
zs[jbig,used+1:end])
64+
_, ibig, _, ismallm = get_bitmask_idxs(xzs,i)
65+
candidate = findfirst(e->e&ismallm!=zerobit, # TODO some form of reinterpret might be faster than equality check
66+
zs[ibig,used+1:end])
7067
if isnothing(candidate)
7168
return nothing
7269
else

src/symbolic_cliffords.jl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -415,12 +415,9 @@ LinearAlgebra.inv(op::sInvSQRTZZ) = sSQRTZZ(op.q1, op.q2)
415415
"""Apply a Pauli Z to the `i`-th qubit of state `s`. You should use `apply!(stab,sZ(i))` instead of this."""
416416
function apply_single_z!(stab::AbstractStabilizer, i)
417417
s = tab(stab)
418-
Tₘₑ = eltype(s.xzs)
419-
bigi = _div(Tₘₑ,i-1)+1
420-
smalli = _mod(Tₘₑ,i-1)
421-
mask = Tₘₑ(0x1)<<smalli
418+
_, ibig, _, ismallm = get_bitmask_idxs(s.xzs,i)
422419
@inbounds @simd for row in 1:size(s.xzs,2)
423-
if !iszero(s.xzs[bigi,row] & mask)
420+
if !iszero(s.xzs[ibig,row] & ismallm)
424421
s.phases[row] = (s.phases[row]+0x2)&0x3
425422
end
426423
end
@@ -430,12 +427,9 @@ end
430427
"""Apply a Pauli X to the `i`-th qubit of state `s`. You should use `apply!(stab,sX(i))` instead of this."""
431428
function apply_single_x!(stab::AbstractStabilizer, i)
432429
s = tab(stab)
433-
Tₘₑ = eltype(s.xzs)
434-
bigi = _div(Tₘₑ,i-1)+1
435-
smalli = _mod(Tₘₑ,i-1)
436-
mask = Tₘₑ(0x1)<<smalli
430+
_, ibig, _, ismallm = get_bitmask_idxs(s.xzs,i)
437431
@inbounds @simd for row in 1:size(s.xzs,2)
438-
if !iszero(s.xzs[end÷2+bigi,row] & mask)
432+
if !iszero(s.xzs[end÷2+ibig,row] & ismallm)
439433
s.phases[row] = (s.phases[row]+0x2)&0x3
440434
end
441435
end
@@ -445,12 +439,9 @@ end
445439
"""Apply a Pauli Y to the `i`-th qubit of state `s`. You should use `apply!(stab,sY(i))` instead of this."""
446440
function apply_single_y!(stab::AbstractStabilizer, i)
447441
s = tab(stab)
448-
Tₘₑ = eltype(s.xzs)
449-
bigi = _div(Tₘₑ,i-1)+1
450-
smalli = _mod(Tₘₑ,i-1)
451-
mask = Tₘₑ(0x1)<<smalli
442+
_, ibig, _, ismallm = get_bitmask_idxs(s.xzs,i)
452443
@inbounds @simd for row in 1:size(s.xzs,2)
453-
if !iszero((s.xzs[bigi,row] & mask) (s.xzs[end÷2+bigi,row] & mask))
444+
if !iszero((s.xzs[ibig,row] & ismallm) (s.xzs[end÷2+ibig,row] & ismallm))
454445
s.phases[row] = (s.phases[row]+0x2)&0x3
455446
end
456447
end

0 commit comments

Comments
 (0)