Skip to content

Commit 149d1c7

Browse files
authored
Systematically used localized expressions in the trace and add pdl__result in them (#760)
1 parent 883feec commit 149d1c7

File tree

8 files changed

+178
-75
lines changed

8 files changed

+178
-75
lines changed

pdl-live-react/src/pdl_ast.d.ts

+1
Original file line numberDiff line numberDiff line change
@@ -4055,6 +4055,7 @@ export interface ContributeValue {
40554055
}
40564056
export interface LocalizedExpression {
40574057
expr: Expr
4058+
pdl__result?: unknown
40584059
pdl__location?: PdlLocationType | null
40594060
}
40604061
export interface Expr {

src/pdl/pdl-schema.json

+10
Original file line numberDiff line numberDiff line change
@@ -7294,6 +7294,16 @@
72947294
"expr": {
72957295
"title": "Expr"
72967296
},
7297+
"pdl__result": {
7298+
"anyOf": [
7299+
{},
7300+
{
7301+
"type": "null"
7302+
}
7303+
],
7304+
"default": null,
7305+
"title": "Pdl Result"
7306+
},
72977307
"pdl__location": {
72987308
"anyOf": [
72997309
{

src/pdl/pdl_ast.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ class LocalizedExpression(BaseModel, Generic[LocalizedExpressionT]):
8484
arbitrary_types_allowed=True,
8585
model_title_generator=(lambda _: "LocalizedExpression"),
8686
)
87-
expr: LocalizedExpressionT
87+
expr: Any
88+
pdl__result: Optional[LocalizedExpressionT] = None
8889
pdl__location: Optional[PdlLocationType] = None
8990

9091

src/pdl/pdl_dumper.py

+39-22
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
DataBlock,
1919
EmptyBlock,
2020
ErrorBlock,
21+
ExpressionType,
2122
FunctionBlock,
2223
GetBlock,
2324
GraniteioModelBlock,
@@ -29,6 +30,7 @@
2930
LastOfBlock,
3031
LitellmModelBlock,
3132
LitellmParameters,
33+
LocalizedExpression,
3234
MatchBlock,
3335
MessageBlock,
3436
ObjectBlock,
@@ -112,29 +114,28 @@ def block_to_dict( # noqa: C901
112114
match block:
113115
case LitellmModelBlock():
114116
d["platform"] = str(block.platform)
115-
d["model"] = block.model
116-
if block.input is not None:
117-
d["input"] = block_to_dict(block.input, json_compatible)
117+
d["model"] = expr_to_dict(block.model, json_compatible)
118+
d["input"] = block_to_dict(block.input, json_compatible)
118119
if block.parameters is not None:
119120
if isinstance(block.parameters, LitellmParameters):
120121
d["parameters"] = block.parameters.model_dump(
121122
exclude_unset=True, exclude_defaults=True
122123
)
123124
else:
124-
d["parameters"] = block.parameters
125+
d["parameters"] = expr_to_dict(block.parameters, json_compatible)
125126
if block.modelResponse is not None:
126127
d["modelResponse"] = block.modelResponse
127128
if block.pdl__usage is not None:
128129
d["pdl__usage"] = usage_to_dict(block.pdl__usage)
129130
case GraniteioModelBlock():
130-
d["model"] = block.model
131+
d["model"] = expr_to_dict(block.model, json_compatible)
131132
d["platform"] = str(block.platform)
132-
d["backend"] = block.backend
133-
d["processor"] = block.processor
134-
if block.input is not None:
135-
d["input"] = block_to_dict(block.input, json_compatible)
133+
d["backend"] = expr_to_dict(block.backend, json_compatible)
134+
if block.processor is not None:
135+
d["processor"] = expr_to_dict(block.processor, json_compatible)
136+
d["input"] = block_to_dict(block.input, json_compatible)
136137
if block.parameters is not None:
137-
d["parameters"] = block.parameters
138+
d["parameters"] = expr_to_dict(block.parameters, json_compatible)
138139
if block.modelResponse is not None:
139140
d["modelResponse"] = block.modelResponse
140141
if block.pdl__usage is not None:
@@ -147,7 +148,7 @@ def block_to_dict( # noqa: C901
147148
case GetBlock():
148149
d["get"] = block.get
149150
case DataBlock():
150-
d["data"] = data_to_dict(block.data, json_compatible)
151+
d["data"] = expr_to_dict(block.data, json_compatible)
151152
if block.raw:
152153
d["raw"] = block.raw
153154
case TextBlock():
@@ -171,7 +172,7 @@ def block_to_dict( # noqa: C901
171172
case MessageBlock():
172173
d["content"] = block_to_dict(block.content, json_compatible)
173174
case ReadBlock():
174-
d["read"] = block.read
175+
d["read"] = expr_to_dict(block.read, json_compatible)
175176
d["message"] = block.message
176177
d["multiline"] = block.multiline
177178
case IncludeBlock():
@@ -183,18 +184,18 @@ def block_to_dict( # noqa: C901
183184
if block.pdl__trace:
184185
d["pdl__trace"] = block_to_dict(block.pdl__trace, json_compatible)
185186
case IfBlock():
186-
d["if"] = block.condition
187+
d["if"] = expr_to_dict(block.condition, json_compatible)
187188
d["then"] = block_to_dict(block.then, json_compatible)
188189
if block.else_ is not None:
189190
d["else"] = block_to_dict(block.else_, json_compatible)
190191
if block.if_result is not None:
191192
d["if_result"] = block.if_result
192193
case MatchBlock():
193-
d["match"] = block.match_
194+
d["match"] = expr_to_dict(block.match_, json_compatible)
194195
d["with"] = [
195196
{
196197
"case": pattern_to_dict(match_case.case),
197-
"if": match_case.if_,
198+
"if": expr_to_dict(match_case.if_, json_compatible),
198199
"then": block_to_dict(match_case.then, json_compatible),
199200
"pdl__case_result": match_case.pdl__case_result,
200201
"pdl__if_result": match_case.pdl__if_result,
@@ -203,11 +204,17 @@ def block_to_dict( # noqa: C901
203204
for match_case in block.with_
204205
]
205206
case RepeatBlock():
206-
d["for"] = block.for_
207-
d["while"] = block.while_
207+
if block.for_ is not None:
208+
d["for"] = expr_to_dict(block.for_, json_compatible)
209+
if block.while_ is not None:
210+
d["while"] = expr_to_dict(block.while_, json_compatible)
208211
d["repeat"] = block_to_dict(block.repeat, json_compatible)
209-
d["until"] = block.until
210-
d["max_iterations"] = block.max_iterations
212+
if block.until is not None:
213+
d["until"] = expr_to_dict(block.until, json_compatible)
214+
if block.max_iterations is not None:
215+
d["max_iterations"] = expr_to_dict(
216+
block.max_iterations, json_compatible
217+
)
211218
d["join"] = join_to_dict(block.join)
212219
if block.pdl__trace is not None:
213220
d["pdl__trace"] = [
@@ -219,8 +226,8 @@ def block_to_dict( # noqa: C901
219226
# if block.scope is not None:
220227
# d["scope"] = scope_to_dict(block.scope, json_compatible)
221228
case CallBlock():
222-
d["call"] = block.call
223-
d["args"] = data_to_dict(block.args, json_compatible)
229+
d["call"] = expr_to_dict(block.call, json_compatible)
230+
d["args"] = expr_to_dict(block.args, json_compatible)
224231
if block.pdl__trace is not None:
225232
d["pdl__trace"] = block_to_dict(
226233
block.pdl__trace, json_compatible
@@ -257,14 +264,24 @@ def block_to_dict( # noqa: C901
257264
return d
258265

259266

260-
def data_to_dict(data: Any, json_compatible):
267+
def data_to_dict(data: Any, json_compatible: bool):
261268
if json_compatible:
262269
d = as_json(data)
263270
else:
264271
d = data
265272
return d
266273

267274

275+
def expr_to_dict(expr: ExpressionType, json_compatible: bool):
276+
if isinstance(expr, LocalizedExpression):
277+
d = {"expr": data_to_dict(expr.expr, json_compatible)}
278+
if expr.pdl__result is not None:
279+
d["pdl__result"] = data_to_dict(expr.pdl__result, json_compatible)
280+
else:
281+
d = data_to_dict(expr, json_compatible)
282+
return d
283+
284+
268285
def timing_to_dict(timing: PdlTiming) -> dict:
269286
d: dict = {}
270287
if timing.start_nanos != 0:

src/pdl/pdl_granite_io.py

+24-16
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,27 @@
1313
)
1414
from .pdl_lazy import PdlConst, PdlLazy, lazy_apply
1515
from .pdl_llms import _LOOP
16+
from .pdl_utils import value_of_expr
1617

1718

1819
class GraniteioModel:
1920
@staticmethod
2021
def processor_of_block(block: GraniteioModelBlock):
22+
model = value_of_expr(block.model)
23+
backend = value_of_expr(block.backend)
24+
assert isinstance(model, str), f"The model should be a string: {model}"
2125
assert isinstance(
22-
block.model, str
23-
), f"The model should be a string: {block.model}"
24-
assert isinstance(
25-
block.backend, (dict, str)
26-
), f"The backend should be a string or a dictionnary: {block.backend}"
27-
match block.backend:
26+
backend, (dict, str)
27+
), f"The backend should be a string or a dictionnary: {backend}"
28+
match backend:
2829
case {"transformers": device}:
29-
assert isinstance(block.backend, dict)
30+
assert isinstance(backend, dict)
3031
from granite_io import make_backend
3132

3233
backend = make_backend(
3334
"transformers",
3435
{
35-
"model_name": block.model,
36+
"model_name": model,
3637
"device": device,
3738
},
3839
)
@@ -42,14 +43,15 @@ def processor_of_block(block: GraniteioModelBlock):
4243
backend = make_backend(
4344
backend_name,
4445
{
45-
"model_name": block.model,
46+
"model_name": model,
4647
},
4748
)
4849
case _:
49-
assert False, f"Unexpected backend: {block.backend}"
50-
processor_name = block.processor
51-
if processor_name is None:
52-
processor_name = block.model
50+
assert False, f"Unexpected backend: {backend}"
51+
if block.processor is None:
52+
processor_name = model
53+
else:
54+
processor_name = value_of_expr(block.processor)
5355
assert isinstance(
5456
processor_name, str
5557
), f"The processor should be a string: {processor_name}"
@@ -73,10 +75,14 @@ async def async_generate_text(
7375
block: GraniteioModelBlock,
7476
messages: ModelInput,
7577
) -> tuple[dict[str, Any], Any]:
78+
if block.parameters is None:
79+
parameters = None
80+
else:
81+
parameters = value_of_expr(block.parameters)
7682
try:
77-
assert block.parameters is None or isinstance(block.parameters, dict)
83+
assert parameters is None or isinstance(parameters, dict)
7884
io_processor = GraniteioModel.processor_of_block(block)
79-
inputs = GraniteioModel.build_message(messages, block.parameters)
85+
inputs = GraniteioModel.build_message(messages, parameters)
8086
result = io_processor.create_chat_completion(inputs) # pyright: ignore
8187
try: # TODO: update when new version of granite-io is released
8288
message = result.next_message.model_dump()
@@ -88,7 +94,9 @@ async def async_generate_text(
8894
raw_result,
8995
)
9096
except Exception as exc:
91-
message = f"Error during '{block.model}' model call: {repr(exc)}"
97+
message = (
98+
f"Error during '{value_of_expr(block.model)}' model call: {repr(exc)}"
99+
)
92100
loc = block.pdl__location
93101
raise PDLRuntimeError(
94102
message,

0 commit comments

Comments
 (0)