Skip to content

Commit 40a881d

Browse files
feat(models): add to_dict & to_json helper methods (#1305)
1 parent 69cdfc3 commit 40a881d

File tree

7 files changed

+155
-14
lines changed

7 files changed

+155
-14
lines changed

README.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,10 @@ We recommend that you always instantiate a client (e.g., with `client = OpenAI()
200200

201201
## Using types
202202

203-
Nested request parameters are [TypedDicts](https://docs.python.org/3/library/typing.html#typing.TypedDict). Responses are [Pydantic models](https://docs.pydantic.dev), which provide helper methods for things like:
203+
Nested request parameters are [TypedDicts](https://docs.python.org/3/library/typing.html#typing.TypedDict). Responses are [Pydantic models](https://docs.pydantic.dev) which also provide helper methods for things like:
204204

205-
- Serializing back into JSON, `model.model_dump_json(indent=2, exclude_unset=True)`
206-
- Converting to a dictionary, `model.model_dump(exclude_unset=True)`
205+
- Serializing back into JSON, `model.to_json()`
206+
- Converting to a dictionary, `model.to_dict()`
207207

208208
Typed requests and responses provide autocomplete and documentation within your editor. If you would like to see type errors in VS Code to help catch bugs earlier, set `python.analysis.typeCheckingMode` to `basic`.
209209

@@ -594,7 +594,7 @@ completion = client.chat.completions.create(
594594
},
595595
],
596596
)
597-
print(completion.model_dump_json(indent=2))
597+
print(completion.to_json())
598598
```
599599

600600
In addition to the options provided in the base `OpenAI` client, the following options are provided:

examples/azure.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
},
2121
],
2222
)
23-
print(completion.model_dump_json(indent=2))
23+
print(completion.to_json())
2424

2525

2626
deployment_client = AzureOpenAI(
@@ -40,4 +40,4 @@
4040
},
4141
],
4242
)
43-
print(completion.model_dump_json(indent=2))
43+
print(completion.to_json())

examples/azure_ad.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,4 @@
2727
},
2828
],
2929
)
30-
print(completion.model_dump_json(indent=2))
30+
print(completion.to_json())

examples/streaming.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ def sync_main() -> None:
2222

2323
# You can manually control iteration over the response
2424
first = next(response)
25-
print(f"got response data: {first.model_dump_json(indent=2)}")
25+
print(f"got response data: {first.to_json()}")
2626

2727
# Or you could automatically iterate through all of data.
2828
# Note that the for loop will not exit until *all* of the data has been processed.
2929
for data in response:
30-
print(data.model_dump_json())
30+
print(data.to_json())
3131

3232

3333
async def async_main() -> None:
@@ -43,12 +43,12 @@ async def async_main() -> None:
4343
# You can manually control iteration over the response.
4444
# In Python 3.10+ you can also use the `await anext(response)` builtin instead
4545
first = await response.__anext__()
46-
print(f"got response data: {first.model_dump_json(indent=2)}")
46+
print(f"got response data: {first.to_json()}")
4747

4848
# Or you could automatically iterate through all of data.
4949
# Note that the for loop will not exit until *all* of the data has been processed.
5050
async for data in response:
51-
print(data.model_dump_json())
51+
print(data.to_json())
5252

5353

5454
sync_main()

src/openai/_models.py

+73
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,79 @@ def model_fields_set(self) -> set[str]:
9090
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
9191
extra: Any = pydantic.Extra.allow # type: ignore
9292

93+
def to_dict(
94+
self,
95+
*,
96+
mode: Literal["json", "python"] = "python",
97+
use_api_names: bool = True,
98+
exclude_unset: bool = True,
99+
exclude_defaults: bool = False,
100+
exclude_none: bool = False,
101+
warnings: bool = True,
102+
) -> dict[str, object]:
103+
"""Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
104+
105+
By default, fields that were not set by the API will not be included,
106+
and keys will match the API response, *not* the property names from the model.
107+
108+
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
109+
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
110+
111+
Args:
112+
mode:
113+
If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`.
114+
If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`
115+
116+
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
117+
exclude_unset: Whether to exclude fields that have not been explicitly set.
118+
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
119+
exclude_none: Whether to exclude fields that have a value of `None` from the output.
120+
warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
121+
"""
122+
return self.model_dump(
123+
mode=mode,
124+
by_alias=use_api_names,
125+
exclude_unset=exclude_unset,
126+
exclude_defaults=exclude_defaults,
127+
exclude_none=exclude_none,
128+
warnings=warnings,
129+
)
130+
131+
def to_json(
132+
self,
133+
*,
134+
indent: int | None = 2,
135+
use_api_names: bool = True,
136+
exclude_unset: bool = True,
137+
exclude_defaults: bool = False,
138+
exclude_none: bool = False,
139+
warnings: bool = True,
140+
) -> str:
141+
"""Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).
142+
143+
By default, fields that were not set by the API will not be included,
144+
and keys will match the API response, *not* the property names from the model.
145+
146+
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
147+
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
148+
149+
Args:
150+
indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`
151+
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
152+
exclude_unset: Whether to exclude fields that have not been explicitly set.
153+
exclude_defaults: Whether to exclude fields that have the default value.
154+
exclude_none: Whether to exclude fields that have a value of `None`.
155+
warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
156+
"""
157+
return self.model_dump_json(
158+
indent=indent,
159+
by_alias=use_api_names,
160+
exclude_unset=exclude_unset,
161+
exclude_defaults=exclude_defaults,
162+
exclude_none=exclude_none,
163+
warnings=warnings,
164+
)
165+
93166
@override
94167
def __str__(self) -> str:
95168
# mypy complains about an invalid self arg

src/openai/lib/_validators.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -678,9 +678,11 @@ def write_out_file(df: pd.DataFrame, fname: str, any_remediations: bool, auto_ac
678678
df_train = df.sample(n=n_train, random_state=42)
679679
df_valid = df.drop(df_train.index)
680680
df_train[["prompt", "completion"]].to_json( # type: ignore
681-
fnames[0], lines=True, orient="records", force_ascii=False
681+
fnames[0], lines=True, orient="records", force_ascii=False, indent=None
682+
)
683+
df_valid[["prompt", "completion"]].to_json(
684+
fnames[1], lines=True, orient="records", force_ascii=False, indent=None
682685
)
683-
df_valid[["prompt", "completion"]].to_json(fnames[1], lines=True, orient="records", force_ascii=False)
684686

685687
n_classes, pos_class = get_classification_hyperparams(df)
686688
additional_params += " --compute_classification_metrics"
@@ -690,7 +692,9 @@ def write_out_file(df: pd.DataFrame, fname: str, any_remediations: bool, auto_ac
690692
additional_params += f" --classification_n_classes {n_classes}"
691693
else:
692694
assert len(fnames) == 1
693-
df[["prompt", "completion"]].to_json(fnames[0], lines=True, orient="records", force_ascii=False)
695+
df[["prompt", "completion"]].to_json(
696+
fnames[0], lines=True, orient="records", force_ascii=False, indent=None
697+
)
694698

695699
# Add -v VALID_FILE if we split the file into train / valid
696700
files_string = ("s" if split else "") + " to `" + ("` and `".join(fnames))

tests/test_models.py

+64
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,42 @@ class Model(BaseModel):
501501
assert "resource_id" in m.model_fields_set
502502

503503

504+
def test_to_dict() -> None:
505+
class Model(BaseModel):
506+
foo: Optional[str] = Field(alias="FOO", default=None)
507+
508+
m = Model(FOO="hello")
509+
assert m.to_dict() == {"FOO": "hello"}
510+
assert m.to_dict(use_api_names=False) == {"foo": "hello"}
511+
512+
m2 = Model()
513+
assert m2.to_dict() == {}
514+
assert m2.to_dict(exclude_unset=False) == {"FOO": None}
515+
assert m2.to_dict(exclude_unset=False, exclude_none=True) == {}
516+
assert m2.to_dict(exclude_unset=False, exclude_defaults=True) == {}
517+
518+
m3 = Model(FOO=None)
519+
assert m3.to_dict() == {"FOO": None}
520+
assert m3.to_dict(exclude_none=True) == {}
521+
assert m3.to_dict(exclude_defaults=True) == {}
522+
523+
if PYDANTIC_V2:
524+
525+
class Model2(BaseModel):
526+
created_at: datetime
527+
528+
time_str = "2024-03-21T11:39:01.275859"
529+
m4 = Model2.construct(created_at=time_str)
530+
assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)}
531+
assert m4.to_dict(mode="json") == {"created_at": time_str}
532+
else:
533+
with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"):
534+
m.to_dict(mode="json")
535+
536+
with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
537+
m.to_dict(warnings=False)
538+
539+
504540
def test_forwards_compat_model_dump_method() -> None:
505541
class Model(BaseModel):
506542
foo: Optional[str] = Field(alias="FOO", default=None)
@@ -532,6 +568,34 @@ class Model(BaseModel):
532568
m.model_dump(warnings=False)
533569

534570

571+
def test_to_json() -> None:
572+
class Model(BaseModel):
573+
foo: Optional[str] = Field(alias="FOO", default=None)
574+
575+
m = Model(FOO="hello")
576+
assert json.loads(m.to_json()) == {"FOO": "hello"}
577+
assert json.loads(m.to_json(use_api_names=False)) == {"foo": "hello"}
578+
579+
if PYDANTIC_V2:
580+
assert m.to_json(indent=None) == '{"FOO":"hello"}'
581+
else:
582+
assert m.to_json(indent=None) == '{"FOO": "hello"}'
583+
584+
m2 = Model()
585+
assert json.loads(m2.to_json()) == {}
586+
assert json.loads(m2.to_json(exclude_unset=False)) == {"FOO": None}
587+
assert json.loads(m2.to_json(exclude_unset=False, exclude_none=True)) == {}
588+
assert json.loads(m2.to_json(exclude_unset=False, exclude_defaults=True)) == {}
589+
590+
m3 = Model(FOO=None)
591+
assert json.loads(m3.to_json()) == {"FOO": None}
592+
assert json.loads(m3.to_json(exclude_none=True)) == {}
593+
594+
if not PYDANTIC_V2:
595+
with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
596+
m.to_json(warnings=False)
597+
598+
535599
def test_forwards_compat_model_dump_json_method() -> None:
536600
class Model(BaseModel):
537601
foo: Optional[str] = Field(alias="FOO", default=None)

0 commit comments

Comments
 (0)