Skip to content

Commit f44ccd8

Browse files
committed
fix
1 parent 992034f commit f44ccd8

File tree

3 files changed

+79
-62
lines changed

3 files changed

+79
-62
lines changed

src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -203,27 +203,36 @@ class AutoDiffBroadcastInDimRev
203203
SmallVector<int64_t> bcastDims(op.getBroadcastDimensions().begin(),
204204
op.getBroadcastDimensions().end());
205205

206-
SmallVector<int64_t> newDims;
207-
SmallVector<int64_t> reduceShape;
206+
SmallVector<int64_t> reducedDims;
207+
SmallVector<int64_t> iterShape;
208208
for (auto en : llvm::enumerate(outTy.getShape())) {
209-
if (llvm::is_contained(bcastDims, en.index())) {
210-
if (en.value() != 1) {
211-
newDims.push_back(en.index());
209+
ssize_t bcastIdx = -1;
210+
for (auto en2 : llvm::enumerate(bcastDims)) {
211+
if (en2.value() == en.index()) {
212+
bcastIdx = en2.index();
213+
break;
214+
}
215+
}
216+
if (bcastIdx != -1) {
217+
if (en.value() != inTy.getShape()[bcastIdx]) {
218+
reducedDims.push_back(en.index());
219+
assert(inTy.getShape()[bcastIdx] == 1);
220+
} else {
221+
iterShape.push_back(inTy.getShape()[bcastIdx]);
212222
}
213223
continue;
214224
}
215-
reduceShape.push_back(en.value());
216-
newDims.push_back(en.index());
225+
reducedDims.push_back(en.index());
217226
}
218227

219-
auto reduceTy = RankedTensorType::get(reduceShape, inTy.getElementType());
228+
auto reduceTy = RankedTensorType::get(iterShape, inTy.getElementType());
220229

221230
Value zero = gutils->getShadowType(reduceTy)
222231
.cast<AutoDiffTypeInterface>()
223232
.createNullValue(builder, op.getLoc());
224233

225234
auto red = builder.create<ReduceOp>(op.getLoc(), TypeRange(zero.getType()),
226-
inDiffe, zero, newDims);
235+
inDiffe, zero, reducedDims);
227236
red.getBody().push_back(new Block());
228237
Block &body = red.getBody().front();
229238
OpBuilder bodyBuilder(orig->getContext());

test/bench_vs_xla.py

Lines changed: 53 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,9 @@ def harness(self, name, in_fn, ins, dins, douts):
7070

7171
print(
7272
name + " JaX Primal: ",
73-
timeit.Timer(
74-
primalstr,
75-
globals={
76-
"fn": rfn_jax,
77-
}
78-
| primalins,
79-
).timeit(number)
73+
timeit.Timer(primalstr, globals={"fn": rfn_jax,} | primalins,).timeit(
74+
number
75+
)
8076
/ number,
8177
)
8278

@@ -97,13 +93,7 @@ def harness(self, name, in_fn, ins, dins, douts):
9793
fwdins = primalins | {("din" + str(i)): dins[0] for i in range(len(dins))}
9894
print(
9995
name + " JaX Fwd: ",
100-
timeit.Timer(
101-
fwdstr,
102-
globals={
103-
"fwd": fwd_jax,
104-
}
105-
| fwdins,
106-
).timeit(number)
96+
timeit.Timer(fwdstr, globals={"fwd": fwd_jax,} | fwdins,).timeit(number)
10797
/ number,
10898
)
10999

@@ -124,13 +114,7 @@ def harness(self, name, in_fn, ins, dins, douts):
124114

125115
print(
126116
name + " JaX Rev: ",
127-
timeit.Timer(
128-
revstr,
129-
globals={
130-
"rev": rev_jax,
131-
}
132-
| revins,
133-
).timeit(number)
117+
timeit.Timer(revstr, globals={"rev": rev_jax,} | revins,).timeit(number)
134118
/ number,
135119
)
136120

@@ -148,11 +132,7 @@ def harness(self, name, in_fn, ins, dins, douts):
148132
name,
149133
") Primal: ",
150134
timeit.Timer(
151-
primalstr,
152-
globals={
153-
"fn": rfn_enzyme,
154-
}
155-
| primalins,
135+
primalstr, globals={"fn": rfn_enzyme,} | primalins,
156136
).timeit(number)
157137
/ number,
158138
)
@@ -174,13 +154,9 @@ def harness(self, name, in_fn, ins, dins, douts):
174154
name + " EnzymeMLIR(",
175155
name,
176156
") Fwd: ",
177-
timeit.Timer(
178-
fwdstr,
179-
globals={
180-
"fwd": fwd_enzyme,
181-
}
182-
| fwdins,
183-
).timeit(number)
157+
timeit.Timer(fwdstr, globals={"fwd": fwd_enzyme,} | fwdins,).timeit(
158+
number
159+
)
184160
/ number,
185161
)
186162

@@ -198,13 +174,9 @@ def harness(self, name, in_fn, ins, dins, douts):
198174
name + " EnzymeMLIR(",
199175
name,
200176
") Rev: ",
201-
timeit.Timer(
202-
revstr,
203-
globals={
204-
"rev": rev_enzyme,
205-
}
206-
| revins,
207-
).timeit(number)
177+
timeit.Timer(revstr, globals={"rev": rev_enzyme,} | revins,).timeit(
178+
number
179+
)
208180
/ number,
209181
)
210182

@@ -291,7 +263,9 @@ class Slicing(EnzymeJaxTest):
291263
def setUp(self):
292264
dim = 3
293265
self.ins = [jnp.array(range(dim), dtype=jnp.float32).reshape(1, dim, 1)]
294-
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32).reshape(1, dim, 1)]
266+
self.dins = [
267+
jnp.array([i * i for i in range(dim)], dtype=jnp.float32).reshape(1, dim, 1)
268+
]
295269
self.douts = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
296270

297271
def nomlir(x):
@@ -311,16 +285,24 @@ def setUp(self):
311285
dim = 12
312286
self.ins = [jnp.array(range(dim), dtype=jnp.float32)]
313287
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
314-
self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32).reshape((2, dim))]
288+
self.douts = [
289+
jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32).reshape(
290+
(2, dim)
291+
)
292+
]
315293

316294
def nomlir(x):
317-
return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
295+
return [
296+
(name, a)
297+
for (name, a) in x
298+
if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
299+
]
318300

319301
self.revfilter = nomlir
320302

321303
def f(x):
322304
toconv2 = jnp.ones((dim, dim))
323-
k = jnp.einsum('jk,k->j', toconv2, x)
305+
k = jnp.einsum("jk,k->j", toconv2, x)
324306
kcl = jnp.zeros((1, dim))
325307
h = jnp.reshape(k, (1, dim))
326308
kcl = jnp.append(kcl, h, axis=0)
@@ -329,15 +311,24 @@ def f(x):
329311
self.fn = f
330312
self.name = "activitymismatch"
331313

314+
332315
class GenDot(EnzymeJaxTest):
333316
def setUp(self):
334317
dim = 12
335318
self.ins = [jnp.array(range(dim), dtype=jnp.float32)]
336319
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
337-
self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32).reshape((2, dim))]
320+
self.douts = [
321+
jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32).reshape(
322+
(2, dim)
323+
)
324+
]
338325

339326
def nomlir(x):
340-
return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
327+
return [
328+
(name, a)
329+
for (name, a) in x
330+
if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
331+
]
341332

342333
self.revfilter = nomlir
343334

@@ -349,7 +340,7 @@ def f(x):
349340
k = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv2, k_tmp), (dim,))
350341

351342
kcl = jnp.zeros((1, dim))
352-
343+
353344
h = jnp.reshape(k, (1, dim))
354345
kcl = jnp.append(kcl, h, axis=0)
355346
return kcl
@@ -361,12 +352,22 @@ def f(x):
361352
class Concat(EnzymeJaxTest):
362353
def setUp(self):
363354
dim = 12
364-
self.ins = [jnp.array(range(dim), dtype=jnp.float32), 10*jnp.array(range(dim), dtype=jnp.float32)]
365-
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32), jnp.array([i * i *i / 3. for i in range(dim)], dtype=jnp.float32)]
366-
self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32)]
355+
self.ins = [
356+
jnp.array(range(dim), dtype=jnp.float32),
357+
10 * jnp.array(range(dim), dtype=jnp.float32),
358+
]
359+
self.dins = [
360+
jnp.array([i * i for i in range(dim)], dtype=jnp.float32),
361+
jnp.array([i * i * i / 3.0 for i in range(dim)], dtype=jnp.float32),
362+
]
363+
self.douts = [jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32)]
367364

368365
def nomlir(x):
369-
return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
366+
return [
367+
(name, a)
368+
for (name, a) in x
369+
if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
370+
]
370371

371372
self.revfilter = nomlir
372373

@@ -376,5 +377,6 @@ def f(x, y):
376377
self.fn = f
377378
self.name = "Concat"
378379

380+
379381
if __name__ == "__main__":
380382
absltest.main()

test/llama.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ def sfn(x, weights, key_cache, value_cache):
349349
# mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")
350350

351351
if True:
352+
352353
@jax.jit
353354
def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
354355
return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc))
@@ -411,8 +412,13 @@ def erev(x, weights, kc, vc, dx, dkc, dvc):
411412
jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)
412413
print("Jax rev", jres)
413414

414-
jrev2 = enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=enzyme_jax.JaXPipeline("inline{default-pipeline=canonicalize max-iterations=4},"
415-
+ "canonicalize,cse,enzyme-hlo-opt,cse"))(jrev)
415+
jrev2 = enzyme_jax.enzyme_jax_ir(
416+
argv=argv,
417+
pipeline_options=enzyme_jax.JaXPipeline(
418+
"inline{default-pipeline=canonicalize max-iterations=4},"
419+
+ "canonicalize,cse,enzyme-hlo-opt,cse"
420+
),
421+
)(jrev)
416422

417423
jres2 = jrev2(x, weights, key_cache, value_cache, dx, dkc, dvc)
418424
print("Jax2 rev", jres2)

0 commit comments

Comments
 (0)