Skip to content

Adds parse_expr utility #144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/latexify/analyzers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest

from latexify import analyzers, exceptions, test_utils
from latexify import analyzers, ast_utils, exceptions, test_utils


@test_utils.require_at_least(8)
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_analyze_range(
stop_int: int | None,
step_int: int | None,
) -> None:
node = ast.parse(code).body[0].value
node = ast_utils.parse_expr(code)
assert isinstance(node, ast.Call)

info = analyzers.analyze_range(node)
Expand Down Expand Up @@ -143,7 +143,7 @@ def check_int(observed: int | None, expected: int | None) -> None:
],
)
def test_analyze_range_invalid(code: str) -> None:
node = ast.parse(code).body[0].value
node = ast_utils.parse_expr(code)
assert isinstance(node, ast.Call)

with pytest.raises(
Expand Down
12 changes: 12 additions & 0 deletions src/latexify/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@
from typing import Any


def parse_expr(code: str) -> ast.expr:
"""Parses given Python expression.

Args:
code: Python expression to parse.

Returns:
ast.expr corresponding to `code`.
"""
return ast.parse(code, mode="eval").body


def make_name(id: str) -> ast.Name:
"""Generates a new Name node.

Expand Down
11 changes: 11 additions & 0 deletions src/latexify/ast_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@
from latexify import ast_utils, test_utils


def test_parse_expr() -> None:
test_utils.assert_ast_equal(
ast_utils.parse_expr("a + b"),
ast.BinOp(
left=ast_utils.make_name("a"),
op=ast.Add(),
right=ast_utils.make_name("b"),
),
)


def test_make_name() -> None:
test_utils.assert_ast_equal(
ast_utils.make_name("foo"), ast.Name(id="foo", ctx=ast.Load())
Expand Down
34 changes: 17 additions & 17 deletions src/latexify/codegen/function_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest

from latexify import exceptions, test_utils
from latexify import ast_utils, exceptions, test_utils
from latexify.codegen import FunctionCodegen, function_codegen


Expand Down Expand Up @@ -116,7 +116,7 @@ def f(x):
],
)
def test_visit_listcomp(code: str, latex: str) -> None:
node = ast.parse(code).body[0].value
node = ast_utils.parse_expr(code)
assert isinstance(node, ast.ListComp)
assert FunctionCodegen().visit(node) == latex

Expand Down Expand Up @@ -163,7 +163,7 @@ def test_visit_listcomp(code: str, latex: str) -> None:
],
)
def test_visit_setcomp(code: str, latex: str) -> None:
node = ast.parse(code).body[0].value
node = ast_utils.parse_expr(code)
assert isinstance(node, ast.SetComp)
assert FunctionCodegen().visit(node) == latex

Expand Down Expand Up @@ -222,7 +222,7 @@ def test_visit_setcomp(code: str, latex: str) -> None:
)
def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]:
node = ast.parse(src_fn + src_suffix).body[0].value
node = ast_utils.parse_expr(src_fn + src_suffix)
assert isinstance(node, ast.Call)
assert FunctionCodegen().visit(node) == dest_fn + dest_suffix

Expand Down Expand Up @@ -273,7 +273,7 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
],
)
def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> None:
node = ast.parse(code).body[0].value
node = ast_utils.parse_expr(code)
assert isinstance(node, ast.Call)
assert FunctionCodegen().visit(node) == latex

Expand All @@ -299,7 +299,7 @@ def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> No
)
def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None:
for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]:
node = ast.parse(src_fn + src_suffix).body[0].value
node = ast_utils.parse_expr(src_fn + src_suffix)
assert isinstance(node, ast.Call)
assert FunctionCodegen().visit(node) == dest_fn + dest_suffix

Expand Down Expand Up @@ -331,7 +331,7 @@ def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None:
],
)
def test_if_then_else(code: str, latex: str) -> None:
node = ast.parse(code).body[0].value
node = ast_utils.parse_expr(code)
assert isinstance(node, ast.IfExp)
assert FunctionCodegen().visit(node) == latex

Expand Down Expand Up @@ -507,7 +507,7 @@ def test_if_then_else(code: str, latex: str) -> None:
],
)
def test_visit_binop(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.BinOp)
assert function_codegen.FunctionCodegen().visit(tree) == latex

Expand Down Expand Up @@ -546,7 +546,7 @@ def test_visit_binop(code: str, latex: str) -> None:
],
)
def test_visit_unaryop(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.UnaryOp)
assert function_codegen.FunctionCodegen().visit(tree) == latex

Expand Down Expand Up @@ -600,7 +600,7 @@ def test_visit_unaryop(code: str, latex: str) -> None:
],
)
def test_visit_compare(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Compare)
assert function_codegen.FunctionCodegen().visit(tree) == latex

Expand Down Expand Up @@ -646,7 +646,7 @@ def test_visit_compare(code: str, latex: str) -> None:
],
)
def test_visit_boolop(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.BoolOp)
assert function_codegen.FunctionCodegen().visit(tree) == latex

Expand All @@ -671,7 +671,7 @@ def test_visit_boolop(code: str, latex: str) -> None:
],
)
def test_visit_constant_lagacy(code: str, cls: type[ast.expr], latex: str) -> None:
tree = ast.parse(code).body[0].value
tree = ast_utils.parse_expr(code)
assert isinstance(tree, cls)
assert function_codegen.FunctionCodegen().visit(tree) == latex

Expand All @@ -696,7 +696,7 @@ def test_visit_constant_lagacy(code: str, cls: type[ast.expr], latex: str) -> No
],
)
def test_visit_constant(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Constant)


Expand All @@ -711,7 +711,7 @@ def test_visit_constant(code: str, latex: str) -> None:
],
)
def test_visit_subscript(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Subscript)
assert function_codegen.FunctionCodegen().visit(tree) == latex

Expand All @@ -726,7 +726,7 @@ def test_visit_subscript(code: str, latex: str) -> None:
],
)
def test_use_set_symbols_binop(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.BinOp)
assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex

Expand All @@ -741,7 +741,7 @@ def test_use_set_symbols_binop(code: str, latex: str) -> None:
],
)
def test_use_set_symbols_compare(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Compare)
assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex

Expand Down Expand Up @@ -784,6 +784,6 @@ def test_use_set_symbols_compare(code: str, latex: str) -> None:
],
)
def test_numpy_array(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Call)
assert function_codegen.FunctionCodegen().visit(tree) == latex