feat(ibis): Add Oracle connector #1067

Merged: 14 commits, Mar 3, 2025
13 changes: 13 additions & 0 deletions ibis-server/app/model/__init__.py
@@ -40,6 +40,10 @@ class QueryMySqlDTO(QueryDTO):
connection_info: ConnectionUrl | MySqlConnectionInfo = connection_info_field


class QueryOracleDTO(QueryDTO):
connection_info: ConnectionUrl | OracleConnectionInfo = connection_info_field


class QueryPostgresDTO(QueryDTO):
connection_info: ConnectionUrl | PostgresConnectionInfo = connection_info_field

@@ -131,6 +135,14 @@ class PostgresConnectionInfo(BaseModel):
password: SecretStr | None = None


class OracleConnectionInfo(BaseModel):
host: SecretStr = Field(examples=["localhost"])
port: SecretStr = Field(examples=[1521])
database: SecretStr
user: SecretStr
password: SecretStr | None = None


class SnowflakeConnectionInfo(BaseModel):
user: SecretStr
password: SecretStr
@@ -201,6 +213,7 @@ class GcsFileConnectionInfo(BaseModel):
| ConnectionUrl
| MSSqlConnectionInfo
| MySqlConnectionInfo
| OracleConnectionInfo
| PostgresConnectionInfo
| SnowflakeConnectionInfo
| TrinoConnectionInfo
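Together with `QueryOracleDTO`, the new `OracleConnectionInfo` model lets an Oracle request carry either a connection URL or discrete fields. A sketch of the field form (placeholder values, mirroring the test fixture later in this PR):

```python
# Illustrative payload only; keys follow OracleConnectionInfo above
connection_info = {
    "host": "localhost",
    "port": "1521",  # validated into SecretStr, cast to int when connecting
    "database": "FREEPDB1",
    "user": "SYSTEM",
    "password": "Oracle123",
}
```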
14 changes: 14 additions & 0 deletions ibis-server/app/model/data_source.py
@@ -17,6 +17,7 @@
ConnectionInfo,
MSSqlConnectionInfo,
MySqlConnectionInfo,
OracleConnectionInfo,
PostgresConnectionInfo,
QueryBigQueryDTO,
QueryCannerDTO,
@@ -27,6 +28,7 @@
QueryMinioFileDTO,
QueryMSSqlDTO,
QueryMySqlDTO,
QueryOracleDTO,
QueryPostgresDTO,
QueryS3FileDTO,
QuerySnowflakeDTO,
@@ -43,6 +45,7 @@ class DataSource(StrEnum):
clickhouse = auto()
mssql = auto()
mysql = auto()
oracle = auto()
postgres = auto()
snowflake = auto()
trino = auto()
@@ -70,6 +73,7 @@ class DataSourceExtension(Enum):
clickhouse = QueryClickHouseDTO
mssql = QueryMSSqlDTO
mysql = QueryMySqlDTO
oracle = QueryOracleDTO
postgres = QueryPostgresDTO
snowflake = QuerySnowflakeDTO
trino = QueryTrinoDTO
@@ -176,6 +180,16 @@ def get_postgres_connection(info: PostgresConnectionInfo) -> BaseBackend:
password=(info.password and info.password.get_secret_value()),
)

@staticmethod
def get_oracle_connection(info: OracleConnectionInfo) -> BaseBackend:
return ibis.oracle.connect(
host=info.host.get_secret_value(),
port=int(info.port.get_secret_value()),
database=info.database.get_secret_value(),
user=info.user.get_secret_value(),
password=(info.password and info.password.get_secret_value()),
)

@staticmethod
def get_snowflake_connection(info: SnowflakeConnectionInfo) -> BaseBackend:
return ibis.snowflake.connect(
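As a quick illustration of the new `get_oracle_connection` helper (a sketch only; inside the server it is presumably reached through the `DataSource`/`DataSourceExtension` dispatch rather than called directly, and the values are placeholders):

```python
# Sketch only: requires a reachable Oracle instance (e.g. the test container below)
info = OracleConnectionInfo(
    host="localhost",
    port="1521",
    database="FREEPDB1",
    user="SYSTEM",
    password="Oracle123",
)
con = DataSourceExtension.get_oracle_connection(info)
print(con.list_tables())  # standard ibis BaseBackend API
```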
6 changes: 6 additions & 0 deletions ibis-server/app/util.py
@@ -13,6 +13,9 @@ def base64_to_dict(base64_str: str) -> dict:

def to_json(df: pd.DataFrame) -> dict:
for column in df.columns:
if df[column].dtype == object:
# Convert Oracle LOB objects to string
df[column] = df[column].apply(lambda x: str(x) if hasattr(x, "read") else x)
if is_datetime64_any_dtype(df[column].dtype):
df[column] = _to_datetime_and_format(df[column])
return _to_json_obj(df)
@@ -44,6 +47,9 @@ def default(obj):
return _date_offset_to_str(obj)
if isinstance(obj, datetime.timedelta):
return str(obj)
# Add handling for any remaining LOB objects
if hasattr(obj, "read"): # Check if object is LOB-like
return str(obj)
raise TypeError

json_obj = orjson.loads(
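The LOB handling in `to_json` and `default` above is duck-typed: any value exposing a `read()` method is stringified before JSON serialization. A minimal illustration with a stand-in object (`FakeLob` is hypothetical, not the real `oracledb` LOB class):

```python
import pandas as pd


class FakeLob:
    """Stands in for an oracledb LOB: it only needs a read() method to match the check."""

    def read(self) -> str:
        return "abc"

    def __str__(self) -> str:
        return "abc"


df = pd.DataFrame({"blob_column": [FakeLob()], "plain": ["x"]})
# Same transformation to_json() applies to object columns
df["blob_column"] = df["blob_column"].apply(
    lambda x: str(x) if hasattr(x, "read") else x
)
print(df["blob_column"].iloc[0])  # "abc", now a plain string
```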
9 changes: 9 additions & 0 deletions ibis-server/docs/how-to-add-data-source.md
@@ -32,6 +32,15 @@ class PostgresConnectionInfo(BaseModel):
We use Pydantic's [BaseModel](https://docs.pydantic.dev/latest/api/base_model/) for our class definitions.
Pydantic provides a convenient field type, [Secret Types](https://docs.pydantic.dev/2.0/usage/types/secrets/), that protects sensitive information.
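For example, a connection-info class for a new source might look like this (illustrative sketch; the `OracleConnectionInfo` added in this PR has exactly this shape):

```python
from pydantic import BaseModel, Field, SecretStr


class OracleConnectionInfo(BaseModel):
    host: SecretStr = Field(examples=["localhost"])
    port: SecretStr = Field(examples=[1521])
    database: SecretStr
    user: SecretStr
    password: SecretStr | None = None
```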

Add your `XxxConnectionInfo` class to the `ConnectionInfo` union:
```python
ConnectionInfo = (
...
| PostgresConnectionInfo
...
)
```

Return to the `DataSourceExtension` enum class to implement the `get_{data_source}_connection` function.
This function should be specific to your new data source. For example, if you've added a PostgreSQL data source, you might implement a `get_postgres_connection` function.
```python
3 changes: 3 additions & 0 deletions ibis-server/justfile
@@ -30,6 +30,9 @@ dev:
test MARKER:
poetry run pytest -m '{{ MARKER }}'

test-verbose MARKER:
poetry run pytest -s -v -m '{{ MARKER }}'

image-name := "ghcr.io/canner/wren-engine-ibis:latest"

docker-build:
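With the `test` and `test-verbose` recipes, the new suite can be selected through the pytest marker registered in `pyproject.toml` below, e.g. `just test oracle`, or `just test-verbose oracle` for per-test output.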
3 changes: 3 additions & 0 deletions ibis-server/pyproject.toml
@@ -15,6 +15,7 @@ ibis-framework = { version = "9.5.0", extras = [
"clickhouse",
"mssql",
"mysql",
"oracle",
"postgres",
"snowflake",
"trino",
@@ -33,6 +34,7 @@ gql = { extras = ["aiohttp"], version = "3.5.0" }
anyio = "4.8.0"
duckdb = "1.1.3"
opendal = ">=0.45"
oracledb = "2.5.1"

[tool.poetry.group.dev.dependencies]
pytest = "8.3.4"
@@ -61,6 +63,7 @@ markers = [
"functions: mark a test as a functions test",
"mssql: mark a test as a mssql test",
"mysql: mark a test as a mysql test",
"oracle: mark a test as a oracle test",
"postgres: mark a test as a postgres test",
"snowflake: mark a test as a snowflake test",
"trino: mark a test as a trino test",
142 changes: 142 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_oracle.py
@@ -0,0 +1,142 @@
import base64

import orjson
import pandas as pd
import pytest
import sqlalchemy
from sqlalchemy import text
from testcontainers.oracle import OracleDbContainer

from tests.conftest import file_path

pytestmark = pytest.mark.oracle

base_url = "/v2/connector/oracle"
oracle_password = "Oracle123"
oracle_user = "SYSTEM"
oracle_database = "FREEPDB1"

manifest = {
"catalog": "my_catalog",
"schema": "my_schema",
"models": [
{
"name": "Orders",
"tableReference": {
"schema": "SYSTEM",
"table": "ORDERS",
},
"columns": [
{"name": "orderkey", "expression": "O_ORDERKEY", "type": "number"},
{"name": "custkey", "expression": "O_CUSTKEY", "type": "number"},
{
"name": "orderstatus",
"expression": "O_ORDERSTATUS",
"type": "varchar2",
},
{
"name": "totalprice",
"expression": "O_TOTALPRICE",
"type": "number",
},
{"name": "orderdate", "expression": "O_ORDERDATE", "type": "date"},
{
"name": "order_cust_key",
"expression": "O_ORDERKEY || '_' || O_CUSTKEY",
"type": "varchar2",
},
{
"name": "timestamp",
"expression": "TO_TIMESTAMP('2024-01-01 23:59:59', 'YYYY-MM-DD HH24:MI:SS')",
"type": "timestamp",
},
{
"name": "test_null_time",
"expression": "CAST(NULL AS TIMESTAMP)",
"type": "timestamp",
},
{
"name": "blob_column",
"expression": "UTL_RAW.CAST_TO_RAW('abc')",
"type": "blob",
},
],
"primaryKey": "orderkey",
}
],
}


@pytest.fixture(scope="module")
def manifest_str():
return base64.b64encode(orjson.dumps(manifest)).decode("utf-8")


@pytest.fixture(scope="module")
def oracle(request) -> OracleDbContainer:
oracle = OracleDbContainer(
"gvenzl/oracle-free:23.6-slim-faststart", oracle_password=f"{oracle_password}"
).start()
engine = sqlalchemy.create_engine(oracle.get_connection_url())
with engine.begin() as conn:
pd.read_parquet(file_path("resource/tpch/data/orders.parquet")).to_sql(
"orders", engine, index=False
)
pd.read_parquet(file_path("resource/tpch/data/customer.parquet")).to_sql(
"customer", engine, index=False
)
Review comment on lines +92 to +97:
🛠️ Refactor suggestion

Add error handling for data loading.

The data loading from parquet files lacks error handling. Consider the following (a sketch follows the list):

  1. Validating file existence
  2. Adding error handling for file read/write operations
  3. Verifying data integrity after loading
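
A minimal sketch of what that could look like (hypothetical `load_parquet` helper, not part of this PR):

```python
from pathlib import Path

import pandas as pd
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError


def load_parquet(engine, parquet_path: str, table: str) -> None:
    path = Path(parquet_path)
    if not path.exists():
        raise FileNotFoundError(f"Test fixture missing: {path}")
    df = pd.read_parquet(path)
    try:
        df.to_sql(table, engine, index=False)
    except SQLAlchemyError as exc:
        raise RuntimeError(f"Failed to load {path} into {table}") from exc
    # Basic integrity check: the row count in Oracle matches the parquet file
    with engine.connect() as conn:
        count = conn.execute(text(f"SELECT COUNT(*) FROM {table}")).scalar()
    assert count == len(df), f"Expected {len(df)} rows in {table}, found {count}"
```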

# Add table and column comments
conn.execute(text("COMMENT ON TABLE orders IS 'This is a table comment'"))
conn.execute(text("COMMENT ON COLUMN orders.o_comment IS 'This is a comment'"))

return oracle


async def test_query(client, manifest_str, oracle: OracleDbContainer):
connection_info = _to_connection_info(oracle)
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": 'SELECT * FROM "Orders" LIMIT 1',
},
)
assert response.status_code == 200


async def test_query_with_connection_url(
client, manifest_str, oracle: OracleDbContainer
):
connection_url = _to_connection_url(oracle)
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": {"connectionUrl": connection_url},
"manifestStr": manifest_str,
"sql": 'SELECT * FROM "Orders" LIMIT 1',
},
)
assert response.status_code == 200
result = response.json()
assert len(result["columns"]) == len(manifest["models"][0]["columns"])
assert len(result["data"]) == 1
assert result["data"][0][0] == 1
assert result["dtypes"] is not None


def _to_connection_info(oracle: OracleDbContainer):
# We can't use oracle.user, oracle.password, oracle.dbname here
# since these values are None at this point
return {
"host": oracle.get_container_host_ip(),
"port": oracle.get_exposed_port(oracle.port),
"user": f"{oracle_user}",
"password": f"{oracle_password}",
"database": f"{oracle_database}",
}


def _to_connection_url(oracle: OracleDbContainer):
info = _to_connection_info(oracle)
return f"oracle://{info['user']}:{info['password']}@{info['host']}:{info['port']}/{info['database']}"