@@ -557,14 +557,38 @@ EvalExprNode binarize(Product const& prod) {
557
557
auto left = fold_left_to_node (factors | move, make_prod);
558
558
auto right = binarize (Constant{prod.scalar ()});
559
559
560
- auto h = left->hash_value ();
561
- hash::combine (h, right->hash_value ());
562
- auto result = EvalExpr{EvalOp::Prod, //
563
- left->result_type (), //
564
- left->expr (), //
565
- left->canon_indices (), //
566
- 1 , //
567
- h};
560
+ auto result = [&]() -> EvalExpr {
561
+ assert (!factors.empty ());
562
+ EvalExpr res = *left;
563
+
564
+ if (factors.size () == 1 ) {
565
+ // Special case for when the product is just a scalar times a leaf
566
+ // In this case, we need to make sure to use a different label for
567
+ // the result (otherwise, we'd have the semantic of this expression
568
+ // being meant to overwrite the contained leaf with the scaled version)
569
+ auto imed = make_imed (*left, *right, EvalOp::Prod);
570
+
571
+ if (imed.is <Constant>()) {
572
+ res = EvalExpr (imed.as <Constant>());
573
+ } else if (imed.is <Variable>()) {
574
+ res = EvalExpr (imed.as <Variable>());
575
+ } else if (imed.is <Tensor>()) {
576
+ res = EvalExpr (imed.as <Tensor>());
577
+ } else {
578
+ throw std::runtime_error (
579
+ " Encountered unexpected intermediate type during binarization" );
580
+ }
581
+ }
582
+
583
+ auto h = res.hash_value ();
584
+ hash::combine (h, right->hash_value ());
585
+ return EvalExpr{EvalOp::Prod, //
586
+ res.result_type (), //
587
+ res.expr (), //
588
+ res.canon_indices (), //
589
+ 1 , //
590
+ h};
591
+ }();
568
592
return EvalExprNode{std::move (result), std::move (left), std::move (right)};
569
593
}
570
594
}
0 commit comments