Skip to content

Commit 72ec555

Browse files
authored
Add ChatDatabricks() (#82)
* Close #73: Add ChatDatabricks() * Update changelog * Set secrets as env vars * Use tenacity to retry flaky tests * Tweak docstring
1 parent 72be129 commit 72ec555

File tree

12 files changed

+260
-52
lines changed

12 files changed

+260
-52
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ jobs:
2929
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
3030
GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
3131
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
32+
DATABRICKS_HOST: ${{ secrets.DATABRICKS_HOST }}
33+
DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }}
3234
# Free tier of Google is rate limited, so we only test on 3.12
3335
TEST_GOOGLE: ${{ matrix.config.test_google }}
3436
# Free tier of Azure is rate limited, so we only test on 3.12

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
### New features
1313

14+
* Added `ChatDatabricks()`, for chatting with Databricks' [foundation models](https://docs.databricks.com/aws/en/machine-learning/model-serving/score-foundation-models). (#82)
1415
* `.stream()` and `.stream_async()` gain a `content` argument. Set this to `"all"` to include `ContentToolRequest` and `ContentToolResponse` instances in the stream. (#75)
1516
* `ContentToolRequest` and `ContentToolResponse` are now exported to `chatlas` namespace. (#75)
1617
* `ContentToolRequest` and `ContentToolResponse` now have `.tagify()` methods, making it so they can render automatically in a Shiny chatbot. (#75)

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ It also supports the following enterprise cloud providers:
4747

4848
* AWS Bedrock: [`ChatBedrockAnthropic()`](https://posit-dev.github.io/chatlas/reference/ChatBedrockAnthropic.html).
4949
* Azure OpenAI: [`ChatAzureOpenAI()`](https://posit-dev.github.io/chatlas/reference/ChatAzureOpenAI.html).
50+
* Databricks: [`ChatDatabricks()`](https://posit-dev.github.io/chatlas/reference/ChatDatabricks.html).
5051
* Snowflake Cortex: [`ChatSnowflake()`](https://posit-dev.github.io/chatlas/reference/ChatSnowflake.html).
5152
* Vertex AI: [`ChatVertex()`](https://posit-dev.github.io/chatlas/reference/ChatVertex.html).
5253

chatlas/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ._content import ContentToolRequest, ContentToolResult
66
from ._content_image import content_image_file, content_image_plot, content_image_url
77
from ._content_pdf import content_pdf_file, content_pdf_url
8+
from ._databricks import ChatDatabricks
89
from ._github import ChatGithub
910
from ._google import ChatGoogle, ChatVertex
1011
from ._groq import ChatGroq
@@ -27,6 +28,7 @@
2728
"ChatAnthropic",
2829
"ChatAuto",
2930
"ChatBedrockAnthropic",
31+
"ChatDatabricks",
3032
"ChatGithub",
3133
"ChatGoogle",
3234
"ChatGroq",

chatlas/_auto.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from ._anthropic import ChatAnthropic, ChatBedrockAnthropic
88
from ._chat import Chat
9+
from ._databricks import ChatDatabricks
910
from ._github import ChatGithub
1011
from ._google import ChatGoogle, ChatVertex
1112
from ._groq import ChatGroq
@@ -18,6 +19,7 @@
1819
AutoProviders = Literal[
1920
"anthropic",
2021
"bedrock-anthropic",
22+
"databricks",
2123
"github",
2224
"google",
2325
"groq",
@@ -32,6 +34,7 @@
3234
_provider_chat_model_map: dict[AutoProviders, Callable[..., Chat]] = {
3335
"anthropic": ChatAnthropic,
3436
"bedrock-anthropic": ChatBedrockAnthropic,
37+
"databricks": ChatDatabricks,
3538
"github": ChatGithub,
3639
"google": ChatGoogle,
3740
"groq": ChatGroq,

chatlas/_databricks.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Optional
4+
5+
from ._chat import Chat
6+
from ._logging import log_model_default
7+
from ._openai import OpenAIProvider
8+
from ._turn import Turn, normalize_turns
9+
10+
if TYPE_CHECKING:
11+
from databricks.sdk import WorkspaceClient
12+
13+
from ._openai import ChatCompletion
14+
from .types.openai import SubmitInputArgs
15+
16+
17+
def ChatDatabricks(
18+
*,
19+
system_prompt: Optional[str] = None,
20+
model: Optional[str] = None,
21+
turns: Optional[list[Turn]] = None,
22+
workspace_client: Optional["WorkspaceClient"] = None,
23+
) -> Chat["SubmitInputArgs", ChatCompletion]:
24+
"""
25+
Chat with a model hosted on Databricks.
26+
27+
Databricks provides out-of-the-box access to a number of [foundation
28+
models](https://docs.databricks.com/en/machine-learning/model-serving/score-foundation-models.html)
29+
and can also serve as a gateway for external models hosted by a third party.
30+
31+
Prerequisites
32+
--------------
33+
34+
::: {.callout-note}
35+
## Python requirements
36+
37+
`ChatDatabricks` requires the `databricks-sdk` package: `pip install
38+
"chatlas[databricks]"`.
39+
:::
40+
41+
::: {.callout-note}
42+
## Authentication
43+
44+
`chatlas` delegates to the `databricks-sdk` package for authentication with
45+
Databricks. As such, you can use any of the authentication methods discussed
46+
here:
47+
48+
https://docs.databricks.com/aws/en/dev-tools/sdk-python#authentication
49+
50+
    Note that the Python-specific article points to this language-agnostic "unified"
51+
approach to authentication:
52+
53+
https://docs.databricks.com/aws/en/dev-tools/auth/unified-auth
54+
55+
There, you'll find all the options listed, but a simple approach that
56+
generally works well is to set the following environment variables:
57+
58+
* `DATABRICKS_HOST`: The Databricks host URL for either the Databricks
59+
workspace endpoint or the Databricks accounts endpoint.
60+
* `DATABRICKS_TOKEN`: The Databricks personal access token.
61+
:::
62+
63+
Parameters
64+
----------
65+
system_prompt
66+
A system prompt to set the behavior of the assistant.
67+
model
68+
The model to use for the chat. The default, None, will pick a reasonable
69+
default, and warn you about it. We strongly recommend explicitly
70+
choosing a model for all but the most casual use.
71+
turns
72+
A list of turns to start the chat with (i.e., continuing a previous
73+
conversation). If not provided, the conversation begins from scratch. Do
74+
not provide non-`None` values for both `turns` and `system_prompt`. Each
75+
message in the list should be a dictionary with at least `role` (usually
76+
`system`, `user`, or `assistant`, but `tool` is also possible). Normally
77+
there is also a `content` field, which is a string.
78+
workspace_client
79+
A `databricks.sdk.WorkspaceClient()` to use for the connection. If not
80+
provided, a new client will be created.
81+
82+
Returns
83+
-------
84+
Chat
85+
A chat object that retains the state of the conversation.
86+
"""
87+
if model is None:
88+
model = log_model_default("databricks-dbrx-instruct")
89+
90+
return Chat(
91+
provider=DatabricksProvider(
92+
model=model,
93+
workspace_client=workspace_client,
94+
),
95+
turns=normalize_turns(
96+
turns or [],
97+
system_prompt,
98+
),
99+
)
100+
101+
102+
class DatabricksProvider(OpenAIProvider):
103+
def __init__(
104+
self,
105+
*,
106+
model: str,
107+
workspace_client: Optional["WorkspaceClient"] = None,
108+
):
109+
try:
110+
from databricks.sdk import WorkspaceClient
111+
except ImportError:
112+
raise ImportError(
113+
"`ChatDatabricks()` requires the `databricks-sdk` package. "
114+
"Install it with `pip install databricks-sdk[openai]`."
115+
)
116+
117+
try:
118+
import httpx
119+
from openai import AsyncOpenAI
120+
except ImportError:
121+
raise ImportError(
122+
"`ChatDatabricks()` requires the `openai` package. "
123+
"Install it with `pip install openai`."
124+
)
125+
126+
self._model = model
127+
self._seed = None
128+
129+
if workspace_client is None:
130+
workspace_client = WorkspaceClient()
131+
132+
client = workspace_client.serving_endpoints.get_open_ai_client()
133+
134+
self._client = client
135+
136+
        # The databricks sdk does not currently expose an async client, but we can
137+
# effectively mirror what .get_open_ai_client() does internally.
138+
        # Note also there is an open PR to add async support that does essentially
139+
# the same thing:
140+
# https://github.com/databricks/databricks-sdk-py/pull/851
141+
self._async_client = AsyncOpenAI(
142+
base_url=client.base_url,
143+
api_key="no-token", # A placeholder to pass validations, this will not be used
144+
http_client=httpx.AsyncClient(auth=client._client.auth),
145+
)

chatlas/_openai.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,8 @@ def _chat_perform_args(
325325
del kwargs_full["tools"]
326326

327327
if stream and "stream_options" not in kwargs_full:
328-
kwargs_full["stream_options"] = {"include_usage": True}
328+
if self.__class__.__name__ != "DatabricksProvider":
329+
kwargs_full["stream_options"] = {"include_usage": True}
329330

330331
return kwargs_full
331332

docs/_quarto.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ quartodoc:
8080
- ChatAuto
8181
- ChatAzureOpenAI
8282
- ChatBedrockAnthropic
83+
- ChatDatabricks
8384
- ChatGithub
8485
- ChatGoogle
8586
- ChatGroq

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ dev = [
5050
"google-genai>=1.2.0",
5151
"numpy>1.24.4",
5252
"tiktoken",
53+
"databricks-sdk",
5354
"snowflake-ml-python",
5455
# torch (a dependency of snowflake-ml-python) is not yet compatible with Python >3.11
5556
"torch;python_version<='3.11'",
@@ -73,6 +74,7 @@ docs = [
7374
# Provider extras ----
7475
anthropic = ["anthropic"]
7576
bedrock-anthropic = ["anthropic[bedrock]"]
77+
databricks = ["databricks-sdk[openai]"]
7678
github = ["openai"]
7779
google = ["google-genai"]
7880
groq = ["openai"]

tests/conftest.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import tempfile
22
from pathlib import Path
3-
from typing import Awaitable, Callable
3+
from typing import Callable
44

55
import pytest
66
from chatlas import (
@@ -14,6 +14,7 @@
1414
)
1515
from PIL import Image
1616
from pydantic import BaseModel
17+
from tenacity import retry, wait_exponential
1718

1819
ChatFun = Callable[..., Chat]
1920

@@ -34,36 +35,18 @@ class ArticleSummary(BaseModel):
3435
"""
3536

3637

37-
def retryassert(assert_func: Callable[..., None], retries=1):
38-
for _ in range(retries):
39-
try:
40-
return assert_func()
41-
except Exception:
42-
pass
43-
return assert_func()
44-
45-
46-
async def retryassert_async(assert_func: Callable[..., Awaitable[None]], retries=1):
47-
for _ in range(retries):
48-
try:
49-
return await assert_func()
50-
except Exception:
51-
pass
52-
return await assert_func()
53-
54-
5538
def assert_turns_system(chat_fun: ChatFun):
5639
system_prompt = "Return very minimal output, AND ONLY USE UPPERCASE."
5740

5841
chat = chat_fun(system_prompt=system_prompt)
5942
response = chat.chat("What is the name of Winnie the Pooh's human friend?")
6043
response_text = str(response)
6144
assert len(chat.get_turns()) == 2
62-
assert "CHRISTOPHER ROBIN" in response_text
45+
assert "CHRISTOPHER ROBIN" in response_text.upper()
6346

6447
chat = chat_fun(turns=[Turn("system", system_prompt)])
6548
response = chat.chat("What is the name of Winnie the Pooh's human friend?")
66-
assert "CHRISTOPHER ROBIN" in str(response)
49+
assert "CHRISTOPHER ROBIN" in str(response).upper()
6750
assert len(chat.get_turns()) == 2
6851

6952

@@ -267,3 +250,9 @@ def assert_pdf_local(chat_fun: ChatFun):
267250
"Two word answer only.",
268251
)
269252
assert "red delicious" in str(response).lower()
253+
254+
255+
retry_api_call = retry(
256+
wait=wait_exponential(min=1, max=60),
257+
reraise=True,
258+
)

tests/test_provider_anthropic.py

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
assert_tools_simple_stream_content,
1414
assert_turns_existing,
1515
assert_turns_system,
16-
retryassert,
17-
retryassert_async,
16+
retry_api_call,
1817
)
1918

2019

@@ -50,53 +49,40 @@ def test_anthropic_respects_turns_interface():
5049
assert_turns_existing(chat_fun)
5150

5251

52+
@retry_api_call
5353
def test_anthropic_tool_variations():
5454
chat_fun = ChatAnthropic
55-
56-
def run_simpleassert():
57-
assert_tools_simple(chat_fun)
58-
59-
retryassert(run_simpleassert, retries=5)
60-
55+
assert_tools_simple(chat_fun)
6156
assert_tools_simple_stream_content(chat_fun)
57+
assert_tools_sequential(chat_fun, total_calls=6)
6258

63-
def run_parallelassert():
64-
# For some reason, at the time of writing, Claude 3.7 doesn't
65-
# respond with multiple tools at once for this test (but it does)
66-
# answer the question correctly with sequential tools.
67-
def chat_fun2(**kwargs):
68-
return ChatAnthropic(model="claude-3-5-sonnet-latest", **kwargs)
69-
70-
assert_tools_parallel(chat_fun2)
7159

72-
retryassert(run_parallelassert, retries=5)
60+
@retry_api_call
61+
def test_anthropic_tool_variations_parallel():
62+
# For some reason, at the time of writing, Claude 3.7 doesn't
63+
# respond with multiple tools at once for this test (but it does)
64+
# answer the question correctly with sequential tools.
65+
def chat_fun(**kwargs):
66+
return ChatAnthropic(model="claude-3-5-sonnet-latest", **kwargs)
7367

74-
# Fails occassionally returning "" instead of Susan
75-
def run_sequentialassert():
76-
assert_tools_sequential(chat_fun, total_calls=6)
77-
78-
retryassert(run_sequentialassert, retries=5)
68+
assert_tools_parallel(chat_fun)
7969

8070

8171
@pytest.mark.asyncio
72+
@retry_api_call
8273
async def test_anthropic_tool_variations_async():
83-
async def run_asyncassert():
84-
await assert_tools_async(ChatAnthropic)
85-
86-
await retryassert_async(run_asyncassert, retries=5)
74+
await assert_tools_async(ChatAnthropic)
8775

8876

8977
def test_data_extraction():
9078
assert_data_extraction(ChatAnthropic)
9179

9280

81+
@retry_api_call
9382
def test_anthropic_images():
9483
chat_fun = ChatAnthropic
9584

96-
def run_inlineassert():
97-
assert_images_inline(chat_fun)
98-
99-
retryassert(run_inlineassert, retries=3)
85+
assert_images_inline(chat_fun)
10086
assert_images_remote_error(chat_fun)
10187

10288

0 commit comments

Comments
 (0)