Skip to content

Commit 65f918c

Browse files
authored
feat!: add a default global ID encoder, add a id_field on the query type (#1227)
BREAKING CHANGE: the global ID function takes only a str with the GID, the query type uses the id_field arg to determine which kwarg to pass to the decoder
1 parent a538334 commit 65f918c

File tree

9 files changed

+53
-24
lines changed

9 files changed

+53
-24
lines changed

ariadne/contrib/relay/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
RelayQueryType,
99
)
1010
from ariadne.contrib.relay.types import ConnectionResolver, GlobalIDTuple
11+
from ariadne.contrib.relay.utils import decode_global_id, encode_global_id
1112

1213
__all__ = [
1314
"ConnectionArguments",
@@ -17,4 +18,6 @@
1718
"RelayQueryType",
1819
"ConnectionResolver",
1920
"GlobalIDTuple",
21+
"decode_global_id",
22+
"encode_global_id",
2023
]

ariadne/contrib/relay/connection.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
from typing import Sequence
2-
3-
from typing_extensions import Any
1+
from typing import Sequence, Any
42

53
from ariadne.contrib.relay.arguments import ConnectionArgumentsUnion
64

@@ -12,14 +10,19 @@ def __init__(
1210
total: int,
1311
has_next_page: bool,
1412
has_previous_page: bool,
13+
id_field: str = "id",
1514
) -> None:
1615
self.edges = edges
1716
self.total = total
1817
self.has_next_page = has_next_page
1918
self.has_previous_page = has_previous_page
19+
self.id_field = id_field
20+
21+
def get_cursor(self, obj):
22+
return obj[self.id_field]
2023

21-
def get_cursor(self, node):
22-
return node["id"]
24+
def get_node(self, obj):
25+
return obj
2326

2427
def get_page_info(
2528
self, connection_arguments: ConnectionArgumentsUnion
@@ -32,4 +35,7 @@ def get_page_info(
3235
}
3336

3437
def get_edges(self):
35-
return [{"node": node, "cursor": self.get_cursor(node)} for node in self.edges]
38+
return [
39+
{"node": self.get_node(obj), "cursor": self.get_cursor(obj)}
40+
for obj in self.edges
41+
]

ariadne/contrib/relay/objects.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from base64 import b64decode
21
from inspect import iscoroutinefunction
32
from typing import Optional, Tuple, cast
43

@@ -14,17 +13,13 @@
1413
from ariadne.contrib.relay.types import (
1514
ConnectionResolver,
1615
GlobalIDDecoder,
17-
GlobalIDTuple,
1816
)
17+
from ariadne.contrib.relay.utils import decode_global_id
1918
from ariadne.types import Resolver
2019
from ariadne.utils import type_get_extension
2120
from ariadne.utils import type_set_extension
2221

2322

24-
def decode_global_id(kwargs) -> GlobalIDTuple:
25-
return GlobalIDTuple(*b64decode(kwargs["id"]).decode().split(":"))
26-
27-
2823
class RelayObjectType(ObjectType):
2924
_node_resolver: Optional[Resolver] = None
3025

@@ -93,6 +88,7 @@ def bind_to_schema(self, schema: GraphQLSchema) -> None:
9388

9489

9590
class RelayNodeInterfaceType(InterfaceType):
91+
9692
def __init__(
9793
self,
9894
type_resolver: Optional[Resolver] = None,
@@ -105,13 +101,15 @@ def __init__(
105101
self,
106102
node: Optional[RelayNodeInterfaceType] = None,
107103
global_id_decoder: GlobalIDDecoder = decode_global_id,
104+
id_field: str = "id",
108105
) -> None:
109106
super().__init__("Query")
110107
if node is None:
111108
node = RelayNodeInterfaceType()
112109
self.node = node
113110
self.set_field("node", self.resolve_node)
114111
self.global_id_decoder = global_id_decoder
112+
self.id_field = id_field
115113

116114
@property
117115
def bindables(self) -> Tuple["RelayQueryType", "RelayNodeInterfaceType"]:
@@ -127,7 +125,7 @@ def get_node_resolver(self, type_name, schema: GraphQLSchema) -> Resolver:
127125
return resolver
128126

129127
def resolve_node(self, obj, info, *args, **kwargs):
130-
type_name, _ = self.global_id_decoder(kwargs)
128+
type_name, _ = self.global_id_decoder(kwargs[self.id_field])
131129

132130
resolver = self.get_node_resolver(type_name, info.schema)
133131

ariadne/contrib/relay/types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from collections import namedtuple
2-
from typing import Any, Callable, Dict
2+
from typing import Callable
33

44
from typing_extensions import TypeVar
55

66
from ariadne.contrib.relay.connection import RelayConnection
77

88
ConnectionResolver = TypeVar("ConnectionResolver", bound=Callable[..., RelayConnection])
99
GlobalIDTuple = namedtuple("GlobalIDTuple", ["type", "id"])
10-
GlobalIDDecoder = Callable[[Dict[str, Any]], GlobalIDTuple]
10+
GlobalIDDecoder = Callable[[str], GlobalIDTuple]
11+
GlobalIDEncoder = Callable[[str, str], str]

ariadne/contrib/relay/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from base64 import b64decode, b64encode
2+
3+
from ariadne.contrib.relay.types import (
4+
GlobalIDTuple,
5+
)
6+
7+
8+
def decode_global_id(gid: str) -> GlobalIDTuple:
9+
return GlobalIDTuple(*b64decode(gid).decode().split(":"))
10+
11+
12+
def encode_global_id(type_name: str, _id: str) -> str:
13+
return b64encode(f"{type_name}:{_id}".encode()).decode()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "ariadne"
7-
version = "0.25.1"
7+
version = "0.25.2"
88
description = "Ariadne is a Python library for implementing GraphQL servers."
99
authors = [{ name = "Mirumee Software", email = "[email protected]" }]
1010
readme = "README.md"

tests/relay/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def relay_type_defs():
5858

5959
@pytest.fixture
6060
def global_id_decoder():
61-
return lambda kwargs: GlobalIDTuple(*b64decode(kwargs["bid"]).decode().split(":"))
61+
return lambda gid: GlobalIDTuple(*b64decode(gid).decode().split(":"))
6262

6363

6464
@pytest.fixture
@@ -71,6 +71,7 @@ def relay_query(factions, relay_node_interface, global_id_decoder):
7171
query = RelayQueryType(
7272
node=relay_node_interface,
7373
global_id_decoder=global_id_decoder,
74+
id_field="bid",
7475
)
7576
query.set_field("rebels", lambda *_: factions[0])
7677
query.set_field("empire", lambda *_: factions[1])

tests/relay/test_objects.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
23
from graphql import extend_schema
34
from graphql import graphql_sync
45
from graphql import parse
@@ -15,9 +16,6 @@
1516
RelayQueryType,
1617
decode_global_id,
1718
)
18-
from ariadne.contrib.relay.types import (
19-
GlobalIDTuple,
20-
)
2119

2220

2321
@pytest.fixture
@@ -31,10 +29,6 @@ def friends_connection():
3129
)
3230

3331

34-
def test_decode_global_id():
35-
assert decode_global_id({"id": "VXNlcjox"}) == GlobalIDTuple("User", "1")
36-
37-
3832
def test_default_id_decoder():
3933
query = RelayQueryType()
4034
assert query.global_id_decoder is decode_global_id

tests/relay/test_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from ariadne.contrib.relay import (
2+
GlobalIDTuple,
3+
decode_global_id,
4+
encode_global_id,
5+
)
6+
7+
8+
def test_decode_global_id():
9+
assert decode_global_id("VXNlcjox") == GlobalIDTuple("User", "1")
10+
11+
12+
def test_encode_global_id():
13+
assert encode_global_id("User", "1") == "VXNlcjox"

0 commit comments

Comments
 (0)