@@ -25,6 +25,7 @@ def __init__(
25
25
use_math_symbols : bool = False ,
26
26
use_signature : bool = True ,
27
27
use_set_symbols : bool = False ,
28
+ _match_subject_stack : list [str ] = [],
28
29
) -> None :
29
30
"""Initializer.
30
31
@@ -34,6 +35,7 @@ def __init__(
34
35
use_signature: Whether to add the function signature before the expression
35
36
or not.
36
37
use_set_symbols: Whether to use set symbols or not.
38
+ _match_subject_stack: a stack of subject names that are used in match
37
39
"""
38
40
self ._expression_codegen = expression_codegen .ExpressionCodegen (
39
41
use_math_symbols = use_math_symbols , use_set_symbols = use_set_symbols
@@ -42,6 +44,7 @@ def __init__(
42
44
use_math_symbols = use_math_symbols
43
45
)
44
46
self ._use_signature = use_signature
47
+ self ._match_subject_stack = _match_subject_stack
45
48
46
49
def generic_visit (self , node : ast .AST ) -> str :
47
50
raise exceptions .LatexifyNotSupportedError (
@@ -141,6 +144,9 @@ def visit_If(self, node: ast.If) -> str:
141
144
142
145
def visit_Match (self , node : ast .Match ) -> str :
143
146
"""Visit a Match node."""
147
+ subject_latex = self ._expression_codegen .visit (node .subject )
148
+ self ._match_subject_stack .append (subject_latex )
149
+
144
150
if not (
145
151
len (node .cases ) >= 2
146
152
and isinstance (node .cases [- 1 ].pattern , ast .MatchAs )
@@ -162,8 +168,6 @@ def visit_Match(self, node: ast.Match) -> str:
162
168
if i < len (node .cases ) - 1 :
163
169
body_latex = self .visit (case .body [0 ])
164
170
cond_latex = self .visit (case .pattern )
165
- # if case.guard is not None:
166
- # cond_latex = self._expression_codegen.visit(case.guard)
167
171
168
172
case_latexes .append (body_latex + r", & \mathrm{if} \ " + cond_latex )
169
173
else :
@@ -177,21 +181,14 @@ def visit_Match(self, node: ast.Match) -> str:
177
181
+ r" \end{array} \right."
178
182
)
179
183
180
- latex_final = latex . replace ( "subject_name" , subject_latex )
181
- return latex_final
184
+ self . _match_subject_stack . pop ( )
185
+ return latex
182
186
183
187
def visit_MatchValue (self , node : ast .MatchValue ) -> str :
184
188
"""Visit a MatchValue node."""
185
189
latex = self ._expression_codegen .visit (node .value )
186
- return "subject_name = " + latex
190
+ return self . _match_subject_stack [ - 1 ] + " = " + latex
187
191
188
192
def visit_MatchOr (self , node : ast .MatchOr ) -> str :
189
193
"""Visit a MatchOr node."""
190
- # case_latexes = []
191
- # for i, pattern in enumerate(node.patterns):
192
- # if i == 0:
193
- # case_latexes.append(self.visit(pattern))
194
- # else:
195
- # case_latexes.append(r" \lor " + self.visit(pattern))
196
- # return "".join(case_latexes)
197
194
return r" \lor " .join (self .visit (p ) for p in node .patterns )
0 commit comments