diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index 2d6355d..e0fdcc5 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -541,10 +541,17 @@ def visit_Match(self, node: ast.Match) -> str: latex = r"\left\{ \begin{array}{ll} " subject_latex = self.visit(node.subject) for i, match_case in enumerate(node.cases): + if len(match_case.body) != 1: + raise exceptions.LatexifySyntaxError( + "Multiple statements are not supported in Match nodes." + ) true_latex = self.visit(match_case.body[0]) cond_latex = self.visit(match_case.pattern) - if i < len(node.cases) - 1: # no wildcard + if i < len(node.cases)-1: # no wildcard + if (match_case.guard): + cond_latex = self.visit(match_case.guard) + subject_latex = "" # getting variable from cond_latex if not cond_latex: raise exceptions.LatexifySyntaxError( "Match subtrees must contain only one wildcard at the end." @@ -563,8 +570,9 @@ def visit_Match(self, node: ast.Match) -> str: def visit_MatchValue(self, node: ast.MatchValue) -> str: """Visit a MatchValue node""" latex = self.visit(node.value) - return r"subject_name = " + latex + return "subject_name = " + latex + def visit_MatchAs(self, node: ast.MatchAs) -> str: """Visit a MatchAs node""" """If MatchAs is a wildcard, return 'otherwise' case, else throw error""" diff --git a/src/latexify/codegen/function_codegen_test.py b/src/latexify/codegen/function_codegen_test.py index 2660ad8..064b7be 100644 --- a/src/latexify/codegen/function_codegen_test.py +++ b/src/latexify/codegen/function_codegen_test.py @@ -831,6 +831,26 @@ def test_multiple_matchvalue_no_wildcards() -> None: FunctionCodegen().visit(tree) +def test_multiple_matchvalue_no_wildcards() -> None: + tree = ast.parse( + textwrap.dedent( + """ + match x: + case 0: + return 1 + case 1: + return 2 + """ + ) + ).body[0] + + with pytest.raises( + exceptions.LatexifySyntaxError, + match=r"Match subtrees must contain only one wildcard at the end.", + ): + FunctionCodegen().visit(tree) + + @test_utils.require_at_least(10) def test_multiple_matchas_wildcards() -> None: tree = ast.parse( @@ -896,6 +916,88 @@ def test_matchas_nonempty_end() -> None: FunctionCodegen().visit(tree) +def test_matchvalue_mutliple_statements() -> None: + tree = ast.parse( + textwrap.dedent( + """ + match x: + case 0: + x = 5 + return 1 + case 1: + return 2 + """ + ) + ).body[0] + + with pytest.raises( + exceptions.LatexifySyntaxError, + match=r"Multiple statements are not supported in Match nodes.", + ): + FunctionCodegen().visit(tree) + +def test_matchcase_with_guard() -> None: + tree = ast.parse( + textwrap.dedent( + """ + match x: + case x if x>0: + return 1 + case _: + return 2 + """ + ) + ).body[0] + + assert FunctionCodegen().visit(tree) == r"\left\{ \begin{array}{ll} {1}, & \mathrm{if} \ {x > {0}} \\ {2}, & \mathrm{otherwise}\end{array} \right." + +def test_matchcase_with_and_guard() -> None: + tree = ast.parse( + textwrap.dedent( + """ + match x: + case x if x>0 and x<=10: + return 1 + case _: + return 2 + """ + ) + ).body[0] + + assert FunctionCodegen().visit(tree) == r"\left\{ \begin{array}{ll} {1}, & \mathrm{if} \ {{x > {0}} \land {x \le {10}}} \\ {2}, & \mathrm{otherwise}\end{array} \right." + +def test_matchcase_with_or_guard() -> None: + tree = ast.parse( + textwrap.dedent( + """ + match x: + case x if x>0 or x<=10: + return 1 + case _: + return 2 + """ + ) + ).body[0] + + assert FunctionCodegen().visit(tree) == r"\left\{ \begin{array}{ll} {1}, & \mathrm{if} \ {{x > {0}} \lor {x \le {10}}} \\ {2}, & \mathrm{otherwise}\end{array} \right." + +def test_matchcase_with_multiple_guards() -> None: + tree = ast.parse( + textwrap.dedent( + """ + match x: + case x if 0 < x <= 10: + return 1 + case _: + return 2 + """ + ) + ).body[0] + + + assert FunctionCodegen().visit(tree) == r"\left\{ \begin{array}{ll} {1}, & \mathrm{if} \ {{0} < x \le {10}} \\ {2}, & \mathrm{otherwise}\end{array} \right." + + @test_utils.require_at_least(10) def test_matchor() -> None: tree = ast.parse(