Skip to content

Commit b9ffe27

Browse files
authored
Restore support for ndim>2 vector arrays in vec_transform (#105)
* support ndim 3+ in vec_transform * refactor again * bump version
1 parent 3d8ee29 commit b9ffe27

File tree

3 files changed

+52
-42
lines changed

3 files changed

+52
-42
lines changed

pylinalg/vector.py

+21-41
Original file line numberDiff line numberDiff line change
@@ -106,50 +106,30 @@ def vec_transform(
106106
"""
107107

108108
matrix = np.asarray(matrix)
109-
vectors = np.asarray(vectors)
110-
111-
# this code has been micro-optimized for the 1D and 2D cases
112-
vectors_ndim = vectors.ndim
113-
if vectors_ndim > 2:
114-
raise ValueError("vectors must be a 1D or 2D array")
115-
116-
# determine if we are working with a batch of vectors
117-
batch = vectors_ndim != 1
118-
119-
# we don't need to work in homogeneous vector space
120-
# if matrix is purely affine and vectors is a single vector
121-
homogeneous = projection or batch
122-
123-
if homogeneous:
124-
vectors = vec_homogeneous(vectors, w=w)
125-
matmul_matrix = matrix
126-
else:
127-
# if we are not working in homogeneous space, it's
128-
# more efficient to matmul the 3x3 (scale + rotation)
129-
# part of the matrix with the vectors and then add
130-
# the translation part after
131-
matmul_matrix = matrix[:-1, :-1]
132-
133-
if batch:
109+
vectors = vec_homogeneous(vectors, w=w)
110+
111+
# yes, the ndim > 2 version can also handle ndim=1
112+
# and ndim=2, but it's slower
113+
if vectors.ndim == 1:
114+
vectors = matrix @ vectors
115+
if projection:
116+
vectors = vectors[:-1] / vectors[-1]
117+
else:
118+
vectors = vectors[:-1]
119+
elif vectors.ndim == 2:
134120
# transposing the vectors array performs better
135121
# than transposing the matrix
136-
vectors = (matmul_matrix @ vectors.T).T
122+
vectors = (matrix @ vectors.T).T
123+
if projection:
124+
vectors = vectors[:, :-1] / vectors[:, -1, None]
125+
else:
126+
vectors = vectors[:, :-1]
137127
else:
138-
vectors = matmul_matrix @ vectors
139-
if not homogeneous:
140-
# as alluded to before, we add the translation
141-
# part of the matrix after the matmul
142-
# if we are not working in homogeneous space
143-
vectors = vectors + matrix[:-1, -1]
144-
145-
if projection:
146-
# if we are projecting, we divide by the last
147-
# element of the vectors array
148-
vectors = vectors[..., :-1] / vectors[..., -1, None]
149-
elif homogeneous:
150-
# if we are NOT projecting but we are working in
151-
# homogeneous space, just drop the last element
152-
vectors = vectors[..., :-1]
128+
vectors = matrix @ vectors[..., None]
129+
if projection:
130+
vectors = vectors[..., :-1, 0] / vectors[..., -1, :]
131+
else:
132+
vectors = vectors[..., :-1, 0]
153133

154134
if out is None:
155135
out = vectors

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
[project]
44
name = "pylinalg"
5-
version = "0.6.6"
5+
version = "0.6.7"
66
description = "Linear algebra utilities for Python"
77
readme = "README.md"
88
license = { file = "LICENSE" }

tests/test_vectors.py

+30
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,36 @@ def test_vec_transform_projection_flag():
120120
npt.assert_array_equal(result, expected_out)
121121

122122

123+
def test_vec_transform_ndim():
124+
vectors_2d = np.array(
125+
[
126+
[1, 0, 0],
127+
[1, 2, 3],
128+
[1, 1, 1],
129+
[1, 1, -1],
130+
[0, 0, 0],
131+
[7, 8, -9],
132+
],
133+
dtype="f8",
134+
)
135+
translation = np.array([-1, 2, 2], dtype="f8")
136+
137+
vectors_3d = vectors_2d.reshape((3, 2, 3))
138+
vectors_4d = vectors_2d.reshape((6, 1, 1, 3))
139+
140+
expected_3d = vectors_3d + translation[None, None, :]
141+
expected_4d = vectors_4d + translation[None, None, None, :]
142+
143+
matrix = la.mat_from_translation(translation)
144+
145+
for projection in [True, False]:
146+
result = la.vec_transform(vectors_3d, matrix, projection=projection)
147+
npt.assert_array_equal(result, expected_3d)
148+
149+
result = la.vec_transform(vectors_4d, matrix, projection=projection)
150+
npt.assert_array_equal(result, expected_4d)
151+
152+
123153
@given(ct.test_spherical, none())
124154
@example((1, 0, np.pi / 2), (0, 0, 1))
125155
@example((1, np.pi / 2, np.pi / 2), (1, 0, 0))

0 commit comments

Comments
 (0)