Skip to content

Commit 0daeb72

Browse files
committed
used _match_subject_stack as suggested
1 parent fd2ddee commit 0daeb72

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

src/latexify/codegen/function_codegen.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(
2525
use_math_symbols: bool = False,
2626
use_signature: bool = True,
2727
use_set_symbols: bool = False,
28+
_match_subject_stack: list[str] = [],
2829
) -> None:
2930
"""Initializer.
3031
@@ -34,6 +35,7 @@ def __init__(
3435
use_signature: Whether to add the function signature before the expression
3536
or not.
3637
use_set_symbols: Whether to use set symbols or not.
38+
_match_subject_stack: a stack of subject names that are used in match
3739
"""
3840
self._expression_codegen = expression_codegen.ExpressionCodegen(
3941
use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols
@@ -42,6 +44,7 @@ def __init__(
4244
use_math_symbols=use_math_symbols
4345
)
4446
self._use_signature = use_signature
47+
self._match_subject_stack = _match_subject_stack
4548

4649
def generic_visit(self, node: ast.AST) -> str:
4750
raise exceptions.LatexifyNotSupportedError(
@@ -141,6 +144,9 @@ def visit_If(self, node: ast.If) -> str:
141144

142145
def visit_Match(self, node: ast.Match) -> str:
143146
"""Visit a Match node."""
147+
subject_latex = self._expression_codegen.visit(node.subject)
148+
self._match_subject_stack.append(subject_latex)
149+
144150
if not (
145151
len(node.cases) >= 2
146152
and isinstance(node.cases[-1].pattern, ast.MatchAs)
@@ -162,8 +168,6 @@ def visit_Match(self, node: ast.Match) -> str:
162168
if i < len(node.cases) - 1:
163169
body_latex = self.visit(case.body[0])
164170
cond_latex = self.visit(case.pattern)
165-
# if case.guard is not None:
166-
# cond_latex = self._expression_codegen.visit(case.guard)
167171

168172
case_latexes.append(body_latex + r", & \mathrm{if} \ " + cond_latex)
169173
else:
@@ -177,21 +181,14 @@ def visit_Match(self, node: ast.Match) -> str:
177181
+ r" \end{array} \right."
178182
)
179183

180-
latex_final = latex.replace("subject_name", subject_latex)
181-
return latex_final
184+
self._match_subject_stack.pop()
185+
return latex
182186

183187
def visit_MatchValue(self, node: ast.MatchValue) -> str:
184188
"""Visit a MatchValue node."""
185189
latex = self._expression_codegen.visit(node.value)
186-
return "subject_name = " + latex
190+
return self._match_subject_stack[-1] + " = " + latex
187191

188192
def visit_MatchOr(self, node: ast.MatchOr) -> str:
189193
"""Visit a MatchOr node."""
190-
# case_latexes = []
191-
# for i, pattern in enumerate(node.patterns):
192-
# if i == 0:
193-
# case_latexes.append(self.visit(pattern))
194-
# else:
195-
# case_latexes.append(r" \lor " + self.visit(pattern))
196-
# return "".join(case_latexes)
197194
return r" \lor ".join(self.visit(p) for p in node.patterns)

0 commit comments

Comments
 (0)