Skip to content

Commit adfc254

Browse files
committed
review suggestions
1 parent 873f090 commit adfc254

File tree

6 files changed

+98
-19
lines changed

6 files changed

+98
-19
lines changed

docs/models/huggingface.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ from pydantic_ai import Agent
5757
from pydantic_ai.models.huggingface import HuggingFaceModel
5858
from pydantic_ai.providers.huggingface import HuggingFaceProvider
5959

60-
model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(api_key='hf_token', provider='nebius'))
60+
model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(api_key='hf_token', provider_name='nebius'))
6161
agent = Agent(model)
6262
...
6363
```

pydantic_ai_slim/pydantic_ai/providers/huggingface.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
import os
4+
from typing import overload
45

56
from httpx import AsyncClient
67

@@ -32,13 +33,26 @@ def base_url(self) -> str:
3233
def client(self) -> AsyncInferenceClient:
3334
return self._client
3435

36+
@overload
37+
def __init__(self, *, base_url: str, api_key: str | None = None) -> None: ...
38+
@overload
39+
def __init__(self, *, provider_name: str, api_key: str | None = None) -> None: ...
40+
@overload
41+
def __init__(self, *, hf_client: AsyncInferenceClient, api_key: str | None = None) -> None: ...
42+
@overload
43+
def __init__(self, *, hf_client: AsyncInferenceClient, base_url: str, api_key: str | None = None) -> None: ...
44+
@overload
45+
def __init__(self, *, hf_client: AsyncInferenceClient, provider_name: str, api_key: str | None = None) -> None: ...
46+
@overload
47+
def __init__(self, *, api_key: str | None = None) -> None: ...
48+
3549
def __init__(
3650
self,
3751
base_url: str | None = None,
3852
api_key: str | None = None,
3953
hf_client: AsyncInferenceClient | None = None,
4054
http_client: AsyncClient | None = None,
41-
provider: str | None = None,
55+
provider_name: str | None = None,
4256
) -> None:
4357
"""Create a new Hugging Face provider.
4458
@@ -50,9 +64,9 @@ def __init__(
5064
[`AsyncInferenceClient`](https://huggingface.co/docs/huggingface_hub/v0.29.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient)
5165
client to use. If not provided, a new instance will be created.
5266
http_client: (currently ignored) An existing `httpx.AsyncClient` to use for making HTTP requests.
53-
provider : Name of the provider to use for inference. available providers can be found in the [HF Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners).
67+
provider_name : Name of the provider to use for inference. available providers can be found in the [HF Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners).
5468
defaults to "auto", which will select the first available provider for the model, the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
55-
If `base_url` is passed, then `provider` is not used.
69+
If `base_url` is passed, then `provider_name` is not used.
5670
"""
5771
api_key = api_key or os.environ.get('HF_TOKEN')
5872

@@ -63,12 +77,12 @@ def __init__(
6377
)
6478

6579
if http_client is not None:
66-
raise ValueError('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead')
80+
raise ValueError('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead.')
6781

68-
if base_url is not None and provider is not None:
69-
raise ValueError('Cannot provide both `base_url` and `provider`')
82+
if base_url is not None and provider_name is not None:
83+
raise ValueError('Cannot provide both `base_url` and `provider_name`.')
7084

7185
if hf_client is None:
72-
self._client = AsyncInferenceClient(api_key=api_key, provider=provider, base_url=base_url) # type: ignore
86+
self._client = AsyncInferenceClient(api_key=api_key, provider=provider_name, base_url=base_url) # type: ignore
7387
else:
7488
self._client = hf_client

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def openrouter_api_key() -> str:
294294

295295
@pytest.fixture(scope='session')
296296
def huggingface_api_key() -> str:
297-
return os.getenv('HF_TOKEN', 'hf_token') or os.getenv('HUGGINGFACE_API_KEY', 'hf_token')
297+
return os.getenv('HF_TOKEN', 'hf_token')
298298

299299

300300
@pytest.fixture(scope='session')
@@ -428,7 +428,7 @@ def model(
428428

429429
return HuggingFaceModel(
430430
'Qwen/Qwen2.5-72B-Instruct',
431-
provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key),
431+
provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key),
432432
)
433433
else:
434434
raise ValueError(f'Unknown model: {request.param}')

tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,66 @@
11
interactions:
2+
- request:
3+
body: null
4+
headers:
5+
accept:
6+
- '*/*'
7+
accept-encoding:
8+
- gzip, deflate
9+
connection:
10+
- keep-alive
11+
method: GET
12+
uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping
13+
response:
14+
headers:
15+
access-control-allow-origin:
16+
- https://huggingface.co
17+
access-control-expose-headers:
18+
- X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash
19+
connection:
20+
- keep-alive
21+
content-length:
22+
- '701'
23+
content-type:
24+
- application/json; charset=utf-8
25+
cross-origin-opener-policy:
26+
- same-origin
27+
etag:
28+
- W/"2bd-diYmxjldwbIbFgWNRPBqJ3SEIak"
29+
referrer-policy:
30+
- strict-origin-when-cross-origin
31+
vary:
32+
- Origin
33+
parsed_body:
34+
_id: 66e81cefd1b1391042d0e47e
35+
id: Qwen/Qwen2.5-72B-Instruct
36+
inferenceProviderMapping:
37+
featherless-ai:
38+
providerId: Qwen/Qwen2.5-72B-Instruct
39+
status: live
40+
task: conversational
41+
fireworks-ai:
42+
providerId: accounts/fireworks/models/qwen2p5-72b-instruct
43+
status: live
44+
task: conversational
45+
hyperbolic:
46+
providerId: Qwen/Qwen2.5-72B-Instruct
47+
status: live
48+
task: conversational
49+
nebius:
50+
providerId: Qwen/Qwen2.5-72B-Instruct-fast
51+
status: live
52+
task: conversational
53+
novita:
54+
providerId: qwen/qwen-2.5-72b-instruct
55+
status: live
56+
task: conversational
57+
together:
58+
providerId: Qwen/Qwen2.5-72B-Instruct-Turbo
59+
status: live
60+
task: conversational
61+
status:
62+
code: 200
63+
message: OK
264
- request:
365
body: null
466
headers: {}
@@ -40,8 +102,8 @@ interactions:
40102
role: assistant
41103
tool_calls: []
42104
stop_reason: null
43-
created: 1749475551
44-
id: chatcmpl-6fa46f85f4f04beda9c936d5996b22a8
105+
created: 1751470757
106+
id: chatcmpl-b3936940372c481b8d886e596dc75524
45107
model: Qwen/Qwen2.5-72B-Instruct-fast
46108
object: chat.completion
47109
prompt_logprobs: null

tests/models/test_huggingface.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ async def test_simple_completion(allow_model_requests: None):
125125
c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore
126126
mock_client = MockHuggingFace.create_mock(c)
127127
model = HuggingFaceModel(
128-
'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x')
128+
'Qwen/Qwen2.5-72B-Instruct',
129+
provider=HuggingFaceProvider(provider_name='nebius', hf_client=mock_client, api_key='x'),
129130
)
130131
agent = Agent(model)
131132

@@ -148,7 +149,8 @@ async def test_request_simple_usage(allow_model_requests: None):
148149
c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore
149150
mock_client = MockHuggingFace.create_mock(c)
150151
model = HuggingFaceModel(
151-
'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x')
152+
'Qwen/Qwen2.5-72B-Instruct',
153+
provider=HuggingFaceProvider(provider_name='nebius', hf_client=mock_client, api_key='x'),
152154
)
153155
agent = Agent(model)
154156

@@ -181,7 +183,8 @@ async def test_request_structured_response(allow_model_requests: None):
181183

182184
mock_client = MockHuggingFace.create_mock(c)
183185
model = HuggingFaceModel(
184-
'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x')
186+
'Qwen/Qwen2.5-72B-Instruct',
187+
provider=HuggingFaceProvider(provider_name='nebius', hf_client=mock_client, api_key='x'),
185188
)
186189
agent = Agent(model, output_type=list[int])
187190

@@ -652,7 +655,7 @@ def test_model_status_error(allow_model_requests: None) -> None:
652655
@pytest.mark.vcr()
653656
async def test_request_simple_success_with_vcr(allow_model_requests: None, huggingface_api_key: str):
654657
m = HuggingFaceModel(
655-
'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key)
658+
'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key)
656659
)
657660
agent = Agent(m)
658661
result = await agent.run('hello')
@@ -664,7 +667,7 @@ async def test_request_simple_success_with_vcr(allow_model_requests: None, huggi
664667
@pytest.mark.vcr()
665668
async def test_hf_model_instructions(allow_model_requests: None, huggingface_api_key: str):
666669
m = HuggingFaceModel(
667-
'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key)
670+
'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key)
668671
)
669672

670673
def simple_instructions(ctx: RunContext):
@@ -684,7 +687,7 @@ def simple_instructions(ctx: RunContext):
684687
usage=Usage(requests=1, request_tokens=26, response_tokens=2, total_tokens=28),
685688
model_name='Qwen/Qwen2.5-72B-Instruct-fast',
686689
timestamp=IsDatetime(),
687-
vendor_id='chatcmpl-6fa46f85f4f04beda9c936d5996b22a8',
690+
vendor_id='chatcmpl-b3936940372c481b8d886e596dc75524',
688691
),
689692
]
690693
)

tests/providers/test_huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_huggingface_provider_pass_http_client() -> None:
4444
ValueError,
4545
match=re.escape('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead'),
4646
):
47-
HuggingFaceProvider(http_client=http_client, api_key='api-key')
47+
HuggingFaceProvider(http_client=http_client, api_key='api-key') # type: ignore
4848

4949

5050
def test_huggingface_provider_pass_hf_client() -> None:

0 commit comments

Comments
 (0)