-
-
Notifications
You must be signed in to change notification settings - Fork 772
/
Copy pathtest_simple.py
256 lines (211 loc) · 6.65 KB
/
test_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
from enum import Enum
from typing import Literal
import anthropic
import pytest
from pydantic import BaseModel, field_validator
import instructor
from instructor.retry import InstructorRetryException
# Module-level synchronous client shared by the non-async tests below.
# The raw Anthropic SDK client is wrapped by instructor in ANTHROPIC_TOOLS mode.
client = instructor.from_anthropic(
    anthropic.Anthropic(),
    mode=instructor.Mode.ANTHROPIC_TOOLS,
)
def test_simple():
    """Extract a ``User`` whose ``name`` must pass an all-uppercase validator.

    ``max_retries=2`` lets the validator's error message be used on retry, so
    the final ``name`` is expected to be the uppercased "JOHN".
    """

    class User(BaseModel):
        name: str
        age: int

        @field_validator("name")
        def name_is_uppercase(cls, v: str):
            # Raise ValueError instead of using `assert`: assert statements are
            # removed when Python runs with -O, which would silently disable
            # this validation (and with it the retry behaviour under test).
            if not v.isupper():
                raise ValueError(
                    f"{v} is not an uppercased string. Note that all characters in {v} must be uppercase (EG. TIM SARAH ADAM)."
                )
            return v

    resp = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=4096,
        max_retries=2,
        system="Make sure to follow the instructions carefully and return a response object that matches the json schema requested. Age is an integer.",
        messages=[
            {
                "role": "user",
                "content": "Extract John is 18 years old.",
            },
        ],
        response_model=User,
    )  # type: ignore

    assert isinstance(resp, User)
    assert resp.name == "JOHN"  # due to validation
    assert resp.age == 18
def test_nested_type():
    """Extract a ``User`` that embeds a nested ``Address`` model."""

    class Address(BaseModel):
        house_number: int
        street_name: str

    class User(BaseModel):
        name: str
        age: int
        address: Address

    prompt = "Extract John is 18 years old and lives at 123 First Avenue."
    resp = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=4096,
        max_retries=0,
        messages=[{"role": "user", "content": prompt}],
        response_model=User,
    )  # type: ignore

    assert isinstance(resp, User)
    assert (resp.name, resp.age) == ("John", 18)

    # The nested object must itself be parsed into the Address model.
    addr = resp.address
    assert isinstance(addr, Address)
    assert (addr.house_number, addr.street_name) == (123, "First Avenue")
def test_list_str():
    """A ``list[str]`` field comes back as a list whose items are all strings."""

    class User(BaseModel):
        name: str
        age: int
        family: list[str]

    system_prompt = "Make sure to follow the instructions carefully and return a response object that matches the json schema requested. Family members here is just asking for a list of names"
    user_prompt = "Create a user for a model with a name, age, and family members."

    resp = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        max_retries=0,
        system=system_prompt,
        messages=[{"role": "user", "content": user_prompt}],
        response_model=User,
    )

    assert isinstance(resp, User)
    assert isinstance(resp.family, list)
    assert all(isinstance(member, str) for member in resp.family)
@pytest.mark.skip(reason="Just use Literal!")
def test_enum():
    """Parse a ``str``-backed Enum field (skipped; ``Literal`` is preferred)."""

    class Role(str, Enum):
        ADMIN = "admin"
        USER = "user"

    class User(BaseModel):
        name: str
        role: Role

    user_message = {
        "role": "user",
        "content": "Create a user for a model with a name and role of admin.",
    }
    resp = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        max_retries=1,
        messages=[user_message],
        response_model=User,
    )

    assert isinstance(resp, User)
    assert resp.role == Role.ADMIN
def test_literal():
    """A ``Literal["admin", "user"]`` field is extracted as ``"admin"`` when asked for an admin."""

    class User(BaseModel):
        name: str
        role: Literal["admin", "user"]

    conversation = [
        {"role": "user", "content": "Create a admin user for a model with a name and role."},
    ]
    resp = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=4096,
        max_retries=2,
        messages=conversation,
        response_model=User,
    )  # type: ignore

    assert isinstance(resp, User)
    assert resp.role == "admin"
def test_nested_list():
    """A ``list`` of nested models is parsed into ``Properties`` instances."""

    class Properties(BaseModel):
        key: str
        value: str

    class User(BaseModel):
        name: str
        age: int
        properties: list[Properties]

    resp = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        max_retries=0,
        messages=[
            {
                "role": "user",
                "content": "Create a user for a model with a name, age, and properties.",
            }
        ],
        response_model=User,
    )

    assert isinstance(resp, User)
    # Loop variable is `prop`, not `property`, to avoid shadowing the builtin.
    for prop in resp.properties:
        assert isinstance(prop, Properties)
def test_system_messages_allcaps():
    """Follow a system instruction supplied as a message rather than via ``system=``.

    The system instruction is passed as a ``"role": "system"`` entry inside
    ``messages``. The test checks that its all-caps requirement shows up in
    the extracted ``name``.
    """

    class User(BaseModel):
        name: str
        age: int

    system_msg = {
        "role": "system",
        "content": "Please make sure to follow the instructions carefully and return a valid response object. All strings must be fully capitalised in all caps. (Eg. THIS IS AN UPPERCASE STRING) and age is an integer.",
    }
    user_msg = {
        "role": "user",
        "content": "Create a user for a model with a name and age.",
    }

    resp = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        max_retries=0,
        messages=[system_msg, user_msg],
        response_model=User,
    )

    assert isinstance(resp, User)
    assert resp.name.isupper()
def test_retry_error():
    """Raise ``InstructorRetryException`` when the validator can never pass.

    The ``name`` validator always raises, so every attempt fails. The
    exception is expected to report non-zero input and output token usage
    via ``total_usage``.
    """

    class User(BaseModel):
        name: str

        @field_validator("name")
        def validate_name(cls, _):
            raise ValueError("Never succeed")

    # pytest.raises fails the test if no exception is raised. The previous
    # try/except silently passed in that case, so the test could never fail
    # on a missing retry error.
    with pytest.raises(InstructorRetryException) as exc_info:
        client.messages.create(
            model="claude-3-haiku-20240307",
            max_tokens=1024,
            max_retries=2,
            messages=[
                {
                    "role": "user",
                    "content": "Extract John is 18 years old",
                },
            ],
            response_model=User,
        )

    usage = exc_info.value.total_usage
    assert usage.input_tokens > 0 and usage.output_tokens > 0
@pytest.mark.asyncio
async def test_async_retry_error():
    """Async counterpart of ``test_retry_error``.

    Uses an instructor-wrapped ``AsyncAnthropic`` client. The always-failing
    validator must lead to ``InstructorRetryException``, with non-zero token
    usage reported via ``total_usage``.
    """
    # Distinct name so the module-level synchronous ``client`` is not shadowed.
    async_client = instructor.from_anthropic(anthropic.AsyncAnthropic())

    class User(BaseModel):
        name: str

        @field_validator("name")
        def validate_name(cls, _):
            raise ValueError("Never succeed")

    # pytest.raises fails the test if no exception is raised. A bare
    # try/except would let the test pass vacuously.
    with pytest.raises(InstructorRetryException) as exc_info:
        await async_client.messages.create(
            model="claude-3-haiku-20240307",
            max_tokens=1024,
            max_retries=2,
            messages=[
                {
                    "role": "user",
                    "content": "Extract John is 18 years old",
                },
            ],
            response_model=User,
        )

    usage = exc_info.value.total_usage
    assert usage.input_tokens > 0 and usage.output_tokens > 0