Skip to content

Commit 4fd2721

Browse files
author
FaydSpeare
committed
rebase
1 parent 28d9ffd commit 4fd2721

File tree

4 files changed

+195
-10
lines changed

4 files changed

+195
-10
lines changed

ensembledata/api/_async_client.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import sys
44
from typing import TYPE_CHECKING, Any, Mapping, Sequence
55

6+
from ._http import AsyncHttpClient
7+
68
if sys.version_info < (3, 8):
79
from typing_extensions import Literal
810
else:
@@ -949,7 +951,7 @@ async def post_comments(
949951
class EDAsyncClient:
950952
def __init__(self, token: str, *, timeout: float = 600, max_network_retries: int = 3):
951953
self.requester = AsyncRequester(
952-
token, timeout=timeout, max_network_retries=max_network_retries
954+
token, timeout=timeout, max_network_retries=max_network_retries, http_client=http_client
953955
)
954956
self.customer = CustomerEndpoints(self.requester)
955957
self.tiktok = TiktokEndpoints(self.requester)
@@ -960,3 +962,6 @@ def __init__(self, token: str, *, timeout: float = 600, max_network_retries: int
960962

961963
async def request(self, uri: str, params: Mapping[str, Any] | None = None) -> EDResponse:
962964
return await self.requester.get(uri, params=params or {})
965+
966+
async def close(self):
967+
await self.requester.http_client.close()

ensembledata/api/_client.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import sys
44
from typing import TYPE_CHECKING, Any, Mapping, Sequence
55

6+
from ._http import HttpClient
7+
68
if sys.version_info < (3, 8):
79
from typing_extensions import Literal
810
else:

ensembledata/api/_http.py

+160
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from typing import Any, Mapping
5+
6+
if sys.version_info < (3, 8):
7+
from typing_extensions import Protocol
8+
else:
9+
from typing import Protocol
10+
11+
try:
12+
import requests
13+
except ImportError:
14+
requests = None
15+
16+
try:
17+
import httpx
18+
except ImportError:
19+
httpx = None
20+
21+
try:
22+
import aiohttp
23+
except ImportError:
24+
aiohttp = None
25+
26+
import json
27+
import urllib.error
28+
import urllib.parse
29+
import urllib.request
30+
31+
32+
class HttpClient(Protocol):
33+
def get(self, url: str, params: Mapping[str, Any]) -> tuple[int, Any, Mapping[str, Any]]: ...
34+
35+
def close(self): ...
36+
37+
38+
class AsyncHttpClient(Protocol):
39+
async def get(
40+
self, url: str, params: Mapping[str, Any]
41+
) -> tuple[int, Any, Mapping[str, Any]]: ...
42+
43+
async def close(self): ...
44+
45+
46+
DEFAULT_TIMEOUT = 600
47+
48+
49+
def default_async_client(timeout: int = DEFAULT_TIMEOUT):
50+
if httpx:
51+
print("HttpxAsyncClient")
52+
return HttpxAsyncClient(timeout=timeout)
53+
if aiohttp:
54+
print("AioHttpClient")
55+
return AioHttpClient(timeout=timeout)
56+
57+
raise ImportError(
58+
"No async HTTP client available. To use async please make sure to install"
59+
" either httpx or aiohttp. Alternatively, you can implement your own version of the"
60+
" AsyncHttpClient protocol."
61+
)
62+
63+
64+
def default_sync_client(timeout: int = DEFAULT_TIMEOUT):
65+
if httpx:
66+
print("HttpxClient")
67+
return HttpxClient(timeout=timeout)
68+
if requests:
69+
print("RequestsClient")
70+
return RequestsClient(timeout=timeout)
71+
72+
print("UrllibClient")
73+
return UrllibClient()
74+
75+
76+
class UrllibClient(HttpClient):
77+
def __init__(self, timeout: int = DEFAULT_TIMEOUT):
78+
self._timeout = timeout
79+
80+
def get(self, url: str, params: Mapping[str, Any]):
81+
url += "?" + urllib.parse.urlencode(params, doseq=True, safe="/")
82+
req = urllib.request.Request(url, method="GET", headers={})
83+
84+
try:
85+
with urllib.request.urlopen(req, timeout=self._timeout) as res:
86+
return res.status, json.loads(res.read().decode()), dict(res.headers)
87+
except urllib.error.HTTPError as e:
88+
return e.code, json.loads(e.read().decode()), dict(e.headers)
89+
90+
def close(self):
91+
pass
92+
93+
94+
class RequestsClient(HttpClient):
95+
def __init__(self, timeout: int = DEFAULT_TIMEOUT):
96+
if not requests:
97+
raise ImportError("requests is not installed")
98+
self._session = requests.Session()
99+
self._timeout = timeout
100+
101+
def get(self, url: str, params: Mapping[str, Any]):
102+
assert requests is not None
103+
res = self._session.get(url, params=params, timeout=self._timeout)
104+
return res.status_code, res.json(), res.headers
105+
106+
def close(self):
107+
self._session.close()
108+
109+
110+
class HttpxClient(HttpClient):
111+
def __init__(self, timeout: int = DEFAULT_TIMEOUT):
112+
if not httpx:
113+
raise ImportError("httpx is not installed")
114+
self._client = httpx.Client(timeout=timeout)
115+
116+
def get(self, url: str, params: Mapping[str, Any]):
117+
res = self._client.get(url, params=params)
118+
return res.status_code, res.json(), res.headers
119+
120+
def close(self):
121+
self._client.close()
122+
123+
124+
class HttpxAsyncClient(AsyncHttpClient):
125+
def __init__(self, timeout: int = DEFAULT_TIMEOUT):
126+
if not httpx:
127+
raise ImportError("httpx is not installed")
128+
self._client = httpx.AsyncClient(timeout=timeout)
129+
130+
async def get(self, url: str, params: Mapping[str, Any]):
131+
print("get httpx")
132+
res = await self._client.get(url, params=params)
133+
return res.status_code, res.json(), res.headers
134+
135+
async def close(self):
136+
await self._client.aclose()
137+
138+
139+
class AioHttpClient(AsyncHttpClient):
140+
def __init__(self, timeout: int = DEFAULT_TIMEOUT):
141+
if not aiohttp:
142+
raise ImportError("aiohttp is not installed")
143+
self._client = aiohttp.ClientSession(read_timeout=timeout)
144+
145+
async def get(self, url: str, params: Mapping[str, Any]):
146+
print("get aiohttp")
147+
async with self._client.get(url, params=params) as response:
148+
return response.status, await response.json(), response.headers
149+
150+
async def close(self):
151+
await self._client.close()
152+
153+
154+
if __name__ == "__main__":
155+
client = UrllibClient()
156+
status, data, headers = client.get(
157+
"https://ensembledata.com/apis/customer/get-used-units",
158+
{"token": "mbWgvamioO7dQUJ9", "date": "2024-01-01"},
159+
)
160+
print(status, data, headers)

ensembledata/api/_requester.py

+27-9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ._response import EDResponse
1010
from ._version import version
1111
from .errors import EDError
12+
from ._http import HttpClient, AsyncHttpClient, default_sync_client, default_async_client
1213

1314
BASE_URL = "https://ensembledata.com/apis"
1415
USER_AGENT = f"ensembledata-python/{version}"
@@ -18,9 +19,10 @@ class EDErrorCode(IntEnum):
1819
TOKEN_NOT_FOUND = 491
1920

2021

21-
def _handle_response(res: httpx.Response, *, return_top_level_data: bool) -> EDResponse:
22-
units_charged = res.headers.get("units_charged", 0)
23-
payload = res.json()
22+
def _handle_response(
23+
status_code: int, payload: Any, headers: Mapping[str, Any], *, return_top_level_data: bool
24+
) -> EDResponse:
25+
units_charged = headers.get("units_charged", 0)
2426
assert isinstance(payload, dict)
2527

2628
# In most cases there will only be a "data" field in the response if the status code is 2xx,
@@ -31,11 +33,11 @@ def _handle_response(res: httpx.Response, *, return_top_level_data: bool) -> EDR
3133
# There are a couple of endpoints that don't use a single top level "data" field, but
3234
# rather have multiple top level fields, for example "nextCursor".
3335
if return_top_level_data:
34-
return EDResponse(res.status_code, payload, units_charged)
36+
return EDResponse(status_code, payload, units_charged)
3537

36-
return EDResponse(res.status_code, payload.get("data"), units_charged)
38+
return EDResponse(status_code, payload.get("data"), units_charged)
3739

38-
raise EDError(res.status_code, payload.get("detail"), units_charged)
40+
raise EDError(status_code, payload.get("detail"), units_charged)
3941

4042

4143
def _check_token(token: str) -> None:
@@ -48,11 +50,19 @@ def _check_token(token: str) -> None:
4850

4951

5052
class Requester:
51-
def __init__(self, token: str, *, timeout: float, max_network_retries: int):
53+
def __init__(
54+
self,
55+
token: str,
56+
*,
57+
timeout: float,
58+
max_network_retries: int,
59+
http_client: HttpClient | None = None,
60+
):
5261
_check_token(token)
5362
self.token = token
5463
self.timeout = timeout
5564
self.max_network_retries = max_network_retries
65+
self.http_client = http_client or default_sync_client(timeout=timeout)
5666

5767
def get(
5868
self,
@@ -64,7 +74,7 @@ def get(
6474
) -> EDResponse:
6575
for attempt in range(self.max_network_retries):
6676
try:
67-
res = httpx.get(
77+
status_code, payload, headers = self.http_client.get(
6878
f"{BASE_URL}{url}",
6979
params={"token": self.token, **params},
7080
timeout=(self.timeout if timeout is None else timeout),
@@ -80,11 +90,19 @@ def get(
8090

8191

8292
class AsyncRequester:
83-
def __init__(self, token: str, *, timeout: float, max_network_retries: int):
93+
def __init__(
94+
self,
95+
token: str,
96+
*,
97+
timeout: int,
98+
max_network_retries: int,
99+
http_client: AsyncHttpClient | None = None,
100+
):
84101
_check_token(token)
85102
self.token = token
86103
self.timeout = timeout
87104
self.max_network_retries = max_network_retries
105+
self.http_client = http_client or default_async_client(timeout=timeout)
88106

89107
async def get(
90108
self,

0 commit comments

Comments
 (0)