Skip to content
This repository was archived by the owner on Apr 15, 2025. It is now read-only.

feat(engine): json protocol #1035

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
24 changes: 12 additions & 12 deletions databases/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

import prisma
# import prisma
from prisma import Prisma
from prisma.errors import FieldNotFoundError, ForeignKeyViolationError

Expand Down Expand Up @@ -39,19 +39,19 @@ async def test_field_not_found_error(client: Prisma) -> None:
)


@pytest.mark.asyncio
@pytest.mark.prisma
async def test_field_not_found_error_selection() -> None:
"""The FieldNotFoundError is raised when an unknown field is passed to selections."""
# @pytest.mark.asyncio
# @pytest.mark.prisma
# async def test_field_not_found_error_selection() -> None:
# """The FieldNotFoundError is raised when an unknown field is passed to selections."""

class CustomPost(prisma.bases.BasePost):
foo_field: str
# class CustomPost(prisma.bases.BasePost):
# foo_field: str

with pytest.raises(
FieldNotFoundError,
match=r'Field \'foo_field\' not found in enclosing type \'Post\'',
):
await CustomPost.prisma().find_first()
# with pytest.raises(
# FieldNotFoundError,
# match=r'Field \'foo_field\' not found in enclosing type \'Post\'',
# ):
# await CustomPost.prisma().find_first()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

test_field_not_found_error_selection skipped

JSON protocol does not explicitly select all fields, so it won't raise a FieldNotFoundError here.

However we could:

  • explicitly select all fields (cons: this could be hard to impl and bring overhead to serializer, prisma-engine and maybe even db); or...
  • leave it to pydantic as it will raise validation errors if data doesn't match the model.



@pytest.mark.asyncio
Expand Down
53 changes: 39 additions & 14 deletions src/prisma/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,24 @@

from pydantic import BaseModel

from ._types import Datasource, HttpConfig, PrismaMethod, MetricsFormat, TransactionId, DatasourceOverride
from ._types import (
Datasource,
HttpConfig,
PrismaMethod,
MetricsFormat,
TransactionId,
DatasourceOverride,
)
from .engine import (
SyncQueryEngine,
AsyncQueryEngine,
BaseAbstractEngine,
SyncAbstractEngine,
AsyncAbstractEngine,
json as json_proto,
)
from .errors import ClientNotConnectedError, ClientNotRegisteredError
from ._compat import model_parse, removeprefix
from ._builder import QueryBuilder
from ._metrics import Metrics
from ._registry import get_client
from .generator.models import EngineType
Expand Down Expand Up @@ -286,15 +293,15 @@ def _prepare_connect_args(
log.debug('datasources: %s', datasources)
return timeout, datasources

def _make_query_builder(
def _serialize(
self,
*,
method: PrismaMethod,
arguments: dict[str, Any],
model: type[BaseModel] | None,
root_selection: list[str] | None,
) -> QueryBuilder:
return QueryBuilder(
root_selection: json_proto.JsonSelectionSet | None = None,
) -> json_proto.JsonQuery:
return json_proto.serialize(
method=method,
model=model,
arguments=arguments,
Expand Down Expand Up @@ -415,12 +422,21 @@ def _execute(
method: PrismaMethod,
arguments: dict[str, Any],
model: type[BaseModel] | None = None,
root_selection: list[str] | None = None,
root_selection: json_proto.JsonSelectionSet | None = None,
) -> Any:
builder = self._make_query_builder(
method=method, model=model, arguments=arguments, root_selection=root_selection
return json_proto.deserialize(
self._engine.query(
json_proto.dumps(
self._serialize(
method=method,
arguments=arguments,
model=model,
root_selection=root_selection,
)
),
tx_id=self._tx_id,
)
)
return self._engine.query(builder.build(), tx_id=self._tx_id)


class AsyncBasePrisma(BasePrisma[AsyncAbstractEngine]):
Expand Down Expand Up @@ -535,9 +551,18 @@ async def _execute(
method: PrismaMethod,
arguments: dict[str, Any],
model: type[BaseModel] | None = None,
root_selection: list[str] | None = None,
root_selection: json_proto.JsonSelectionSet | None = None,
) -> Any:
builder = self._make_query_builder(
method=method, model=model, arguments=arguments, root_selection=root_selection
return json_proto.deserialize(
await self._engine.query(
json_proto.dumps(
self._serialize(
method=method,
arguments=arguments,
model=model,
root_selection=root_selection,
)
),
tx_id=self._tx_id,
)
)
return await self._engine.query(builder.build(), tx_id=self._tx_id)
2 changes: 1 addition & 1 deletion src/prisma/engine/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _spawn_process(
RUST_LOG='error',
RUST_LOG_FORMAT='json',
PRISMA_CLIENT_ENGINE_TYPE='binary',
PRISMA_ENGINE_PROTOCOL='graphql',
PRISMA_ENGINE_PROTOCOL='json',
)

if DEBUG:
Expand Down
3 changes: 3 additions & 0 deletions src/prisma/engine/json/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .types import *
from .serializer import dumps as dumps, serialize as serialize
from .deserializer import deserialize as deserialize
43 changes: 43 additions & 0 deletions src/prisma/engine/json/deserializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

import json
from typing import Any
from decimal import Decimal
from datetime import datetime
from typing_extensions import TypeGuard

from .types import JsonOutputTaggedValue
from ...fields import Base64


def deserialize(result: Any) -> Any:
if not result:
return result

if isinstance(result, list):
return list(map(deserialize, result))

if isinstance(result, dict):
if is_tagged_value(result):
return result['value'] # XXX: will pydantic cast this?

return {k: deserialize(v) for k, v in result.items()}

return result


def is_tagged_value(value: dict[Any, Any]) -> TypeGuard[JsonOutputTaggedValue]:
return isinstance(value.get('$type'), str)


def deserialize_tagged_value(tagged: JsonOutputTaggedValue) -> Any:
if tagged['$type'] == 'BigInt':
return int(tagged['value'])
elif tagged['$type'] == 'Bytes':
return Base64.fromb64(tagged['value'])
elif tagged['$type'] == 'DateTime':
return datetime.fromisoformat(tagged['value'])
elif tagged['$type'] == 'Decimal':
return Decimal(tagged['value'])
elif tagged['$type'] == 'Json':
return json.loads(tagged['value'])
Loading