Skip to content

Commit cfa173b

Browse files
authored
Merge pull request ValeevGroup#286 from Krzmbrzl/fix-eval-expr-root
Fix scalar * tensor binarizing to tree with tensor as root
2 parents e01d2c1 + df50dea commit cfa173b

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

SeQuant/core/eval_expr.cpp

+32-8
Original file line numberDiff line numberDiff line change
@@ -557,14 +557,38 @@ EvalExprNode binarize(Product const& prod) {
557557
auto left = fold_left_to_node(factors | move, make_prod);
558558
auto right = binarize(Constant{prod.scalar()});
559559

560-
auto h = left->hash_value();
561-
hash::combine(h, right->hash_value());
562-
auto result = EvalExpr{EvalOp::Prod, //
563-
left->result_type(), //
564-
left->expr(), //
565-
left->canon_indices(), //
566-
1, //
567-
h};
560+
auto result = [&]() -> EvalExpr {
561+
assert(!factors.empty());
562+
EvalExpr res = *left;
563+
564+
if (factors.size() == 1) {
565+
// Special case for when the product is just a scalar times a leaf
566+
// In this case, we need to make sure to use a different label for
567+
// the result (otherwise, we'd have the semantic of this expression
568+
// being meant to overwrite the contained leaf with the scaled version)
569+
auto imed = make_imed(*left, *right, EvalOp::Prod);
570+
571+
if (imed.is<Constant>()) {
572+
res = EvalExpr(imed.as<Constant>());
573+
} else if (imed.is<Variable>()) {
574+
res = EvalExpr(imed.as<Variable>());
575+
} else if (imed.is<Tensor>()) {
576+
res = EvalExpr(imed.as<Tensor>());
577+
} else {
578+
throw std::runtime_error(
579+
"Encountered unexpected intermediate type during binarization");
580+
}
581+
}
582+
583+
auto h = res.hash_value();
584+
hash::combine(h, right->hash_value());
585+
return EvalExpr{EvalOp::Prod, //
586+
res.result_type(), //
587+
res.expr(), //
588+
res.canon_indices(), //
589+
1, //
590+
h};
591+
}();
568592
return EvalExprNode{std::move(result), std::move(left), std::move(right)};
569593
}
570594
}

tests/unit/test_eval_expr.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,18 @@ TEST_CASE("eval_expr", "[EvalExpr]") {
127127
) == ResultType::Scalar);
128128
}
129129

130+
SECTION("result expr") {
131+
ExprPtr expr = parse_expr(L"2 var");
132+
ExprPtr res = binarize(expr)->expr();
133+
REQUIRE(res->is<Variable>());
134+
REQUIRE(*res != *expr);
135+
136+
expr = parse_expr(L"2 t{a1;i1}");
137+
res = binarize(expr)->expr();
138+
REQUIRE(res->is<Tensor>());
139+
REQUIRE(*res != *expr);
140+
}
141+
130142
SECTION("Sequant expression") {
131143
const auto& str_t1 = L"g_{a1,a2}^{a3,a4}";
132144
const auto& str_t2 = L"t_{a3,a4}^{i1,i2}";

0 commit comments

Comments
 (0)