Skip to content

refactor: Remove LocalClient and use HTTPClient for local Tesseracts as well #27

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 21 additions & 74 deletions tesseract_core/sdk/tesseract.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ class Tesseract:
instances spawned when instantiating the class).
"""

_client: HTTPClient | LocalClient
image: str
volumes: list[str] | None
gpus: list[str] | None
project_id: str

_client: HTTPClient
project_id: str | None = None
container_id: str | None = None

def __init__(self, url: str) -> None:
self._client = HTTPClient(url)
Expand All @@ -52,57 +54,45 @@ def from_image(
return obj

def __enter__(self):
if not self.client_type == "local":
raise ValueError(
"Use Tesseract.from_image(...) to create a context-managed Tesseract instance."
)
self._serve(volumes=self.volumes, gpus=self.gpus)
self._client = LocalClient(self.tesseract_container_id)
url = self._serve(volumes=self.volumes, gpus=self.gpus)
self._client = HTTPClient(url)
return self

def __exit__(self, exc_type, exc_value, traceback):
engine.teardown(self.project_id)
self.project_id = None
self.container_id = None

def _serve(
self,
port: str = "",
volumes: list[str] | None = None,
gpus: list[str] | None = None,
) -> None:
if hasattr(self, "tesseract_container_id"):
self.tesseract_container_id: str
) -> str:
if self.container_id:
raise RuntimeError(
"Client already attached to the Tesseract "
f"container {self.tesseract_container_id}"
"Client already attached to the Tesseract container {self.container_id}"
)
project_id = engine.serve([self.image], port=port, volumes=volumes, gpus=gpus)

command = ["docker", "compose", "-p", project_id, "ps", "--format", "json"]
result = subprocess.run(command, capture_output=True, text=True)

containers = json.loads(result.stdout)
# This relies on the fact that result.stdout from docker compose ps
# contains multiple json dicts, one for each container, separated by newlines,
# but json.loads will only parse the first one.
# The first_container dict contains useful info like container id, ports, etc.
first_container = json.loads(result.stdout)

if containers:
first_container_id = containers["ID"]
if first_container:
first_container_id = first_container["ID"]
first_container_port = first_container["Publishers"][0]["PublishedPort"]
else:
raise RuntimeError("No containers found.")

self.tesseract_container_id = first_container_id
self.project_id = project_id

@cached_property
def client_type(self) -> str:
"""Get the type of client being used ('http' or 'local')."""
return (
"http"
if hasattr(self, "_client") and isinstance(self._client, HTTPClient)
else "local"
)

@property
def url(self) -> str | None:
"""Get the URL if using HTTP client, None for local client."""
return getattr(self._client, "url", None)
self.container_id = first_container_id
return f"http://localhost:{first_container_port}"

@cached_property
def openapi_schema(self) -> dict:
Expand Down Expand Up @@ -351,46 +341,3 @@ def _run_tesseract(self, endpoint: str, payload: dict | None = None) -> dict:
endpoint = "openapi.json"

return self._request(endpoint, method, payload)


class LocalClient:
"""A client that connects to a local Tesseract."""

def __init__(self, container_id: str) -> None:
self.container_id = container_id

def _run_tesseract(
self,
endpoint: str,
payload: dict | None = None,
) -> dict:
command = endpoint.replace("_", "-")
if payload:
encoded_payload = _tree_map(
_encode_array, payload, is_leaf=lambda x: hasattr(x, "shape")
)
else:
encoded_payload = None

args = []
if encoded_payload:
args.append(json.dumps(encoded_payload))

out, err = engine.exec_tesseract(self.container_id, command, args)

if err:
raise RuntimeError(err)

data = json.loads(out)

if command in [
"apply",
"jacobian",
"jacobian-vector-product",
"vector-jacobian-product",
]:
data = _tree_map(
_decode_array, data, is_leaf=lambda x: type(x) is dict and "shape" in x
)

return data
37 changes: 11 additions & 26 deletions tests/sdk_tests/test_tesseract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from tesseract_core import Tesseract
from tesseract_core.sdk.tesseract import (
HTTPClient,
LocalClient,
_decode_array,
_encode_array,
_tree_map,
Expand All @@ -17,7 +16,9 @@ def mock_serving(mocker):
serve_mock.return_value = "proj-id-123"

subprocess_run_mock = mocker.patch("subprocess.run")
subprocess_run_mock.return_value.stdout = '{"ID": "abc1234"}'
subprocess_run_mock.return_value.stdout = (
'{"ID": "abc1234", "Publishers":[{"PublishedPort": 54321}]}'
)

teardown_mock = mocker.patch("tesseract_core.sdk.engine.teardown")
return {
Expand All @@ -30,15 +31,12 @@ def mock_serving(mocker):
@pytest.fixture
def mock_clients(mocker):
mocker.patch("tesseract_core.sdk.tesseract.HTTPClient._run_tesseract")
mocker.patch("tesseract_core.sdk.tesseract.LocalClient._run_tesseract")


def test_Tesseract_init():
# Instantiate with a url
t = Tesseract(url="localhost")

assert t.client_type == "http"

# The attributes for local Tesseracts should not be set
assert not hasattr(t, "image")
assert not hasattr(t, "gpus")
Expand All @@ -55,21 +53,18 @@ def test_Tesseract_from_image():

# Let's also check that stuff we don't expect there is not there
assert not hasattr(t, "url")
assert t.client_type == "local"


def test_Tesseract_schema_methods(mocker, mock_serving):
mocked_run = mocker.patch("tesseract_core.sdk.engine.exec_tesseract")
mocked_run.return_value = '{"#defs": {"some": "stuff"}}', None
mocked_run = mocker.patch("tesseract_core.sdk.tesseract.HTTPClient._run_tesseract")
mocked_run.return_value = {"#defs": {"some": "stuff"}}

with Tesseract.from_image("sometesseract:0.2.3") as t:
input_schema = t.input_schema
output_schema = t.output_schema
openapi_schema = t.openapi_schema

assert (
input_schema == output_schema == openapi_schema == {"#defs": {"some": "stuff"}}
)
assert input_schema == output_schema == openapi_schema == mocked_run.return_value


def test_serve_lifecycle(mock_serving, mock_clients):
Expand All @@ -84,21 +79,11 @@ def test_serve_lifecycle(mock_serving, mock_clients):

mock_serving["teardown_mock"].assert_called_with("proj-id-123")


def test_LocalClient_exec(mocker):
mocked_exec = mocker.patch("tesseract_core.sdk.engine.exec_tesseract")
mocked_exec.return_value = '{"result": [4,4,4]}', None

client = LocalClient(container_id="1234567")

out = client._run_tesseract("apply", {"inputs": {"a": 1}})

assert out == {"result": [4, 4, 4]}
mocked_exec.assert_called_with(
"1234567",
"apply",
['{"inputs": {"a": 1}}'],
)
# check that the same Tesseract obj cannot be used to instantiate two containers
with pytest.raises(RuntimeError):
with t:
with t:
pass


def test_HTTPClient_run_tesseract(mocker):
Expand Down
Loading