Skip to content

Commit 2cbb9ab

Browse files
author
FaydSpeare
committed
client code generation
1 parent 6d2e6ec commit 2cbb9ab

File tree

9 files changed

+1084
-767
lines changed

9 files changed

+1084
-767
lines changed

README.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# EnsembleData Python API
22

33
[![pypi](https://img.shields.io/pypi/v/ensembledata?color=%2334D058&label=pypi%20package)](https://pypi.org/project/ensembledata/)
4-
![](https://img.shields.io/pypi/pyversions/ensembledata.svg?color=%2334D058)
4+
[![pypi](https://img.shields.io/pypi/pyversions/ensembledata.svg)](https://pypi.org/project/ensembledata/)
55

66
## Documentation
77

@@ -28,7 +28,7 @@ from ensembledata.api import EDClient
2828

2929

3030
client = EDClient("API-TOKEN")
31-
result = client.tiktok.user_info_from_username("daviddobrik")
31+
result = client.tiktok.user_info_from_username(username="daviddobrik")
3232

3333
print("Data: ", result.data)
3434
print("Units charged:", result.units_charged)
@@ -79,7 +79,7 @@ from ensembledata.api import EDClient, EDError, errors
7979

8080
client = EDClient("API-TOKEN")
8181
try:
82-
result = client.tiktok.user_info_from_username("daviddobrik")
82+
result = client.tiktok.user_info_from_username(username="daviddobrik")
8383
except EDError as e:
8484

8585
# Rate limit exceeded...
@@ -113,7 +113,7 @@ from ensembledata.api import EDAsyncClient
113113

114114
async def main():
115115
client = EDAsyncClient("API-TOKEN")
116-
result = await client.tiktok.user_info_from_username("daviddobrik")
116+
result = await client.tiktok.user_info_from_username(username="daviddobrik")
117117

118118
if __name__ == "__main__":
119119
asyncio.run(main())

codegen/__init__.py

Whitespace-only changes.

codegen/__main__.py

+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from typing import Any
2+
3+
import requests
4+
from jinja2 import Environment, PackageLoader, select_autoescape
5+
6+
7+
def openapi_type_to_python_type(openapi_type: str) -> str:
8+
if openapi_type == "integer":
9+
return "int"
10+
if openapi_type == "string":
11+
return "str"
12+
if openapi_type == "boolean":
13+
return "bool"
14+
raise Exception("Unknown type")
15+
16+
17+
def parse_endpoints_by_tag(openapi: Any):
18+
tags = dict()
19+
for path, endpoints in openapi["paths"].items():
20+
operation_id = endpoints["get"]["operationId"]
21+
function_name = "_".join(operation_id.split("_")[1:])
22+
params = []
23+
tag = endpoints["get"]["tags"][0]
24+
25+
for param in endpoints["get"]["parameters"]:
26+
name = param["name"]
27+
if name == "token":
28+
continue
29+
30+
type = None
31+
transform = None
32+
33+
if "type" in param["schema"]:
34+
type = openapi_type_to_python_type(param["schema"]["type"])
35+
36+
if "retype" in param["schema"]: # noqa: SIM102
37+
if param["schema"]["retype"] == "semicolon-separated-string-to-list":
38+
type = "Sequence[str]"
39+
transform = "join_with_semicolon"
40+
41+
# Handle enums
42+
elif "allOf" in param["schema"]:
43+
assert len(param["schema"]["allOf"]) == 1
44+
assert "$ref" in param["schema"]["allOf"][0]
45+
ref = param["schema"]["allOf"][0]["$ref"]
46+
ref_name = ref.split("/")[-1]
47+
type = f"Literal{openapi['components']['schemas'][ref_name]['enum']}"
48+
else:
49+
raise Exception("Unknown param type")
50+
51+
params.append(
52+
{
53+
"required": param["required"],
54+
"param": name,
55+
"name": param["schema"].get("rename") or name,
56+
"type": type,
57+
"transform": transform,
58+
}
59+
)
60+
61+
if tag not in tags:
62+
tags[tag] = []
63+
64+
tags[tag].append(
65+
{
66+
"function_name": function_name,
67+
"path": path,
68+
"params": params,
69+
"return_top_level_data": operation_id
70+
in ["tiktok_user_posts_from_username", "tiktok_user_posts_from_secuid"],
71+
}
72+
)
73+
74+
return tags
75+
76+
77+
def main():
78+
openapi = requests.get("https://ensembledata.com/apis/openapi.json").json()
79+
80+
env = Environment(
81+
loader=PackageLoader("codegen"),
82+
autoescape=select_autoescape(),
83+
lstrip_blocks=True,
84+
trim_blocks=True,
85+
)
86+
87+
tags = parse_endpoints_by_tag(openapi)
88+
client_template = env.get_template("client.jinja")
89+
90+
client_content = client_template.render(tags=tags, async_methods=False)
91+
async_client_content = client_template.render(tags=tags, async_methods=True)
92+
93+
with open("ensembledata/api/_client.py", "w") as file:
94+
file.write(client_content)
95+
96+
with open("ensembledata/api/_async_client.py", "w") as file:
97+
file.write(async_client_content)
98+
99+
100+
if __name__ == "__main__":
101+
main()

codegen/templates/client.jinja

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from typing import TYPE_CHECKING, Any, Mapping, Sequence
5+
6+
if sys.version_info < (3, 8):
7+
from typing_extensions import Literal
8+
else:
9+
from typing import Literal
10+
11+
from ._requester import {{ "Async" if async_methods else "" }}Requester
12+
13+
if TYPE_CHECKING:
14+
from ._response import EDResponse
15+
16+
class UseDefault:
17+
pass
18+
19+
USE_DEFAULT = UseDefault()
20+
21+
{% for key, endpoints in tags.items() %}
22+
class {{ key }}Endpoints:
23+
24+
def __init__(self, requester: {{ "Async" if async_methods else "" }}Requester):
25+
self.requester = requester
26+
27+
{% for x in endpoints %}
28+
{{ "async " if async_methods else "" }}def {{ x.function_name }}(
29+
self,
30+
*,
31+
{% for item in x.params %}
32+
{% if item.required %}
33+
{{ item.name }}: {{ item.type }},
34+
{% else %}
35+
{{ item.name }}: {{ item.type }} | UseDefault = USE_DEFAULT,
36+
{% endif %}
37+
{% endfor %}
38+
extra_params: Mapping[str, Any] | None = None,
39+
) -> EDResponse:
40+
params: dict[str, Any] = {
41+
{% for item in x.params %}
42+
{% if item.transform == "join_with_semicolon" %}
43+
"{{ item.param }}": ";".join({{ item.name }}),
44+
{% else %}
45+
"{{ item.param }}": {{ item.name }},
46+
{% endif %}
47+
{% endfor %}
48+
}
49+
if extra_params is not None:
50+
params = {**extra_params, **params}
51+
params = {k: v for k, v in params.items() if not (v is None or v is USE_DEFAULT)}
52+
return {{ "await " if async_methods else "" }}self.requester.get("{{ x.path }}", params=params{{ ", return_top_level_data=True" if x.return_top_level_data else "" }})
53+
54+
{% endfor %}
55+
56+
{% endfor %}
57+
58+
class ED{{ "Async" if async_methods else "" }}Client:
59+
def __init__(self, token: str, *, timeout: int = 600, max_network_retries: int = 3):
60+
self.requester = {{ "Async" if async_methods else "" }}Requester(
61+
token, timeout=timeout, max_network_retries=max_network_retries
62+
)
63+
{% for tag in tags %}
64+
self.{{ tag|lower }} = {{ tag }}Endpoints(self.requester)
65+
{% endfor %}
66+
67+
async def request(self, uri: str, params: Mapping[str, Any] | None = None) -> EDResponse:
68+
return {{ "await " if async_methods else "" }}self.requester.get(uri, params=params or {})

ensembledata/api/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from ._async_client import EDAsyncClient
22
from ._client import EDClient
33
from ._response import EDResponse
4-
from .errors import EDError
54
from ._version import version
5+
from .errors import EDError
66

77
__version__ = version
88
__all__ = ["EDClient", "EDAsyncClient", "EDResponse", "EDError"]

0 commit comments

Comments
 (0)