Skip to content

Commit 9ed373a

Browse files
committed
fix
1 parent 992034f commit 9ed373a

File tree

3 files changed

+67
-22
lines changed

3 files changed

+67
-22
lines changed

src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -203,27 +203,36 @@ class AutoDiffBroadcastInDimRev
203203
SmallVector<int64_t> bcastDims(op.getBroadcastDimensions().begin(),
204204
op.getBroadcastDimensions().end());
205205

206-
SmallVector<int64_t> newDims;
207-
SmallVector<int64_t> reduceShape;
206+
SmallVector<int64_t> reducedDims;
207+
SmallVector<int64_t> iterShape;
208208
for (auto en : llvm::enumerate(outTy.getShape())) {
209-
if (llvm::is_contained(bcastDims, en.index())) {
210-
if (en.value() != 1) {
211-
newDims.push_back(en.index());
209+
ssize_t bcastIdx = -1;
210+
for (auto en2 : llvm::enumerate(bcastDims)) {
211+
if (en2.value() == en.index()) {
212+
bcastIdx = en2.index();
213+
break;
214+
}
215+
}
216+
if (bcastIdx != -1) {
217+
if (en.value() != inTy.getShape()[bcastIdx]) {
218+
reducedDims.push_back(en.index());
219+
assert(inTy.getShape()[bcastIdx] == 1);
220+
} else {
221+
iterShape.push_back(inTy.getShape()[bcastIdx]);
212222
}
213223
continue;
214224
}
215-
reduceShape.push_back(en.value());
216-
newDims.push_back(en.index());
225+
reducedDims.push_back(en.index());
217226
}
218227

219-
auto reduceTy = RankedTensorType::get(reduceShape, inTy.getElementType());
228+
auto reduceTy = RankedTensorType::get(iterShape, inTy.getElementType());
220229

221230
Value zero = gutils->getShadowType(reduceTy)
222231
.cast<AutoDiffTypeInterface>()
223232
.createNullValue(builder, op.getLoc());
224233

225234
auto red = builder.create<ReduceOp>(op.getLoc(), TypeRange(zero.getType()),
226-
inDiffe, zero, newDims);
235+
inDiffe, zero, reducedDims);
227236
red.getBody().push_back(new Block());
228237
Block &body = red.getBody().front();
229238
OpBuilder bodyBuilder(orig->getContext());

test/bench_vs_xla.py

Lines changed: 41 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -291,7 +291,9 @@ class Slicing(EnzymeJaxTest):
291291
def setUp(self):
292292
dim = 3
293293
self.ins = [jnp.array(range(dim), dtype=jnp.float32).reshape(1, dim, 1)]
294-
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32).reshape(1, dim, 1)]
294+
self.dins = [
295+
jnp.array([i * i for i in range(dim)], dtype=jnp.float32).reshape(1, dim, 1)
296+
]
295297
self.douts = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
296298

297299
def nomlir(x):
@@ -311,16 +313,24 @@ def setUp(self):
311313
dim = 12
312314
self.ins = [jnp.array(range(dim), dtype=jnp.float32)]
313315
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
314-
self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32).reshape((2, dim))]
316+
self.douts = [
317+
jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32).reshape(
318+
(2, dim)
319+
)
320+
]
315321

316322
def nomlir(x):
317-
return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
323+
return [
324+
(name, a)
325+
for (name, a) in x
326+
if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
327+
]
318328

319329
self.revfilter = nomlir
320330

321331
def f(x):
322332
toconv2 = jnp.ones((dim, dim))
323-
k = jnp.einsum('jk,k->j', toconv2, x)
333+
k = jnp.einsum("jk,k->j", toconv2, x)
324334
kcl = jnp.zeros((1, dim))
325335
h = jnp.reshape(k, (1, dim))
326336
kcl = jnp.append(kcl, h, axis=0)
@@ -329,15 +339,24 @@ def f(x):
329339
self.fn = f
330340
self.name = "activitymismatch"
331341

342+
332343
class GenDot(EnzymeJaxTest):
333344
def setUp(self):
334345
dim = 12
335346
self.ins = [jnp.array(range(dim), dtype=jnp.float32)]
336347
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
337-
self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32).reshape((2, dim))]
348+
self.douts = [
349+
jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32).reshape(
350+
(2, dim)
351+
)
352+
]
338353

339354
def nomlir(x):
340-
return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
355+
return [
356+
(name, a)
357+
for (name, a) in x
358+
if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
359+
]
341360

342361
self.revfilter = nomlir
343362

@@ -349,7 +368,7 @@ def f(x):
349368
k = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv2, k_tmp), (dim,))
350369

351370
kcl = jnp.zeros((1, dim))
352-
371+
353372
h = jnp.reshape(k, (1, dim))
354373
kcl = jnp.append(kcl, h, axis=0)
355374
return kcl
@@ -361,12 +380,22 @@ def f(x):
361380
class Concat(EnzymeJaxTest):
362381
def setUp(self):
363382
dim = 12
364-
self.ins = [jnp.array(range(dim), dtype=jnp.float32), 10*jnp.array(range(dim), dtype=jnp.float32)]
365-
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32), jnp.array([i * i *i / 3. for i in range(dim)], dtype=jnp.float32)]
366-
self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32)]
383+
self.ins = [
384+
jnp.array(range(dim), dtype=jnp.float32),
385+
10 * jnp.array(range(dim), dtype=jnp.float32),
386+
]
387+
self.dins = [
388+
jnp.array([i * i for i in range(dim)], dtype=jnp.float32),
389+
jnp.array([i * i * i / 3.0 for i in range(dim)], dtype=jnp.float32),
390+
]
391+
self.douts = [jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32)]
367392

368393
def nomlir(x):
369-
return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
394+
return [
395+
(name, a)
396+
for (name, a) in x
397+
if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
398+
]
370399

371400
self.revfilter = nomlir
372401

@@ -376,5 +405,6 @@ def f(x, y):
376405
self.fn = f
377406
self.name = "Concat"
378407

408+
379409
if __name__ == "__main__":
380410
absltest.main()

test/llama.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -349,6 +349,7 @@ def sfn(x, weights, key_cache, value_cache):
349349
# mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")
350350

351351
if True:
352+
352353
@jax.jit
353354
def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
354355
return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc))
@@ -411,8 +412,13 @@ def erev(x, weights, kc, vc, dx, dkc, dvc):
411412
jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)
412413
print("Jax rev", jres)
413414

414-
jrev2 = enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=enzyme_jax.JaXPipeline("inline{default-pipeline=canonicalize max-iterations=4},"
415-
+ "canonicalize,cse,enzyme-hlo-opt,cse"))(jrev)
415+
jrev2 = enzyme_jax.enzyme_jax_ir(
416+
argv=argv,
417+
pipeline_options=enzyme_jax.JaXPipeline(
418+
"inline{default-pipeline=canonicalize max-iterations=4},"
419+
+ "canonicalize,cse,enzyme-hlo-opt,cse"
420+
),
421+
)(jrev)
416422

417423
jres2 = jrev2(x, weights, key_cache, value_cache, dx, dkc, dvc)
418424
print("Jax2 rev", jres2)

0 commit comments

Comments
 (0)