Skip to content

Commit 8acf74b

Browse files
committed
add lazy diag
1 parent aa88b3f commit 8acf74b

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

autoray/lazy/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
tensordot,
1111
einsum,
1212
trace,
13+
diag,
1314
matmul,
1415
kron,
1516
clip,
@@ -69,6 +70,7 @@
6970
"einsum",
7071
"conj",
7172
"trace",
73+
"diag",
7274
"matmul",
7375
"kron",
7476
"clip",

autoray/lazy/core.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
get_dtype_name,
1313
register_function,
1414
astype,
15-
complex_add_re_im,
1615
)
1716

1817

@@ -1119,6 +1118,24 @@ def trace(a):
11191118
)
11201119

11211120

1121+
@lazy_cache("diag")
1122+
def diag(a):
1123+
a = ensure_lazy(a)
1124+
1125+
if a.ndim == 1:
1126+
new_shape = (a.size, a.size)
1127+
elif a.ndim == 2:
1128+
new_shape = (min(a.shape),)
1129+
else:
1130+
raise ValueError("Input must be 1- or 2-d.")
1131+
1132+
return a.to(
1133+
fn=get_lib_fn(a.backend, "diag"),
1134+
args=(a,),
1135+
shape=new_shape,
1136+
)
1137+
1138+
11221139
@lazy_cache("matmul")
11231140
def matmul(x1, x2):
11241141
backend = find_common_backend(x1, x2)

tests/test_lazy.py

+14
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,20 @@ def test_indexing():
379379
assert_allclose(a[key], b[key].compute())
380380

381381

382+
@pytest.mark.parametrize('shape', [
383+
(3,),
384+
(2, 2),
385+
(3, 4),
386+
(4, 3),
387+
])
388+
def test_diag(shape):
389+
a = do('random.uniform', size=shape, like='numpy')
390+
b = lazy.array(a)
391+
ad = do('diag', a)
392+
bd = do('diag', b)
393+
assert_allclose(ad, bd.compute())
394+
395+
382396
def test_einsum():
383397
a = do('random.uniform', size=(2, 3, 4, 5), like='numpy')
384398
b = do('random.uniform', size=(4, 5), like='numpy')

0 commit comments

Comments
 (0)