Skip to content

Commit 0cba4c9

Browse files
authored
fix wrappign around pow (#195)
1 parent 1cfb8e7 commit 0cba4c9

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

src/latexify/codegen/expression_codegen.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -408,16 +408,22 @@ def visit_Call(self, node: ast.Call) -> str:
408408

409409
if rule.is_unary and len(node.args) == 1:
410410
# Unary function. Applies the same wrapping policy with the unary operators.
411+
precedence = expression_rules.get_precedence(node)
412+
arg = node.args[0]
411413
# NOTE(odashi):
412414
# Factorial "x!" is treated as a special case: it requires both inner/outer
413415
# parentheses for correct interpretation.
414-
precedence = expression_rules.get_precedence(node)
415-
arg = node.args[0]
416-
force_wrap = isinstance(arg, ast.Call) and (
416+
force_wrap_factorial = isinstance(arg, ast.Call) and (
417417
func_name == "factorial"
418418
or ast_utils.extract_function_name_or_none(arg) == "factorial"
419419
)
420-
arg_latex = self._wrap_operand(arg, precedence, force_wrap)
420+
# Note(odashi):
421+
# Wrapping is also required if the argument is pow.
422+
# https://github.com/google/latexify_py/issues/189
423+
force_wrap_pow = isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.Pow)
424+
arg_latex = self._wrap_operand(
425+
arg, precedence, force_wrap_factorial or force_wrap_pow
426+
)
421427
elements = [rule.left, arg_latex, rule.right]
422428
else:
423429
arg_latex = ", ".join(self.visit(arg) for arg in node.args)
@@ -490,7 +496,7 @@ def _wrap_operand(
490496
latex = self.visit(child)
491497
child_prec = expression_rules.get_precedence(child)
492498

493-
if child_prec < parent_prec or force_wrap and child_prec == parent_prec:
499+
if force_wrap or child_prec < parent_prec:
494500
return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)"
495501

496502
return latex

src/latexify/codegen/expression_codegen_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,25 @@ def test_visit_call(code: str, latex: str) -> None:
218218
assert expression_codegen.ExpressionCodegen().visit(node) == latex
219219

220220

221+
@pytest.mark.parametrize(
222+
"code,latex",
223+
[
224+
("log(x)**2", r"\mathopen{}\left( \log x \mathclose{}\right)^{2}"),
225+
("log(x**2)", r"\log \mathopen{}\left( x^{2} \mathclose{}\right)"),
226+
(
227+
"log(x**2)**3",
228+
r"\mathopen{}\left("
229+
r" \log \mathopen{}\left( x^{2} \mathclose{}\right)"
230+
r" \mathclose{}\right)^{3}",
231+
),
232+
],
233+
)
234+
def test_visit_call_with_pow(code: str, latex: str) -> None:
235+
node = ast_utils.parse_expr(code)
236+
assert isinstance(node, (ast.Call, ast.BinOp))
237+
assert expression_codegen.ExpressionCodegen().visit(node) == latex
238+
239+
221240
@pytest.mark.parametrize(
222241
"src_suffix,dest_suffix",
223242
[

0 commit comments

Comments
 (0)