Skip to content

Commit 9b89941

Browse files
Added support for passing arguments to create_engine() (#391)
1 parent 636680d commit 9b89941

File tree

4 files changed

+108
-2
lines changed

4 files changed

+108
-2
lines changed

CHANGES.rst

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ Version history
55

66
- Type annotations for ARRAY column attributes now include the Python type of
77
the array elements
8+
- Added support for specifying engine arguments via ``--engine-arg``
9+
(PR by @LajosCseppento)
810

911
**3.0.0**
1012

README.rst

+3
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ Examples::
6767
sqlacodegen postgresql:///some_local_db
6868
sqlacodegen --generator tables mysql+pymysql://user:password@localhost/dbname
6969
sqlacodegen --generator dataclasses sqlite:///database.db
70+
# --engine-arg values are parsed with ast.literal_eval
71+
sqlacodegen oracle+oracledb://user:[email protected]:1521/XE --engine-arg thick_mode=True
72+
sqlacodegen oracle+oracledb://user:[email protected]:1521/XE --engine-arg thick_mode=True --engine-arg connect_args='{"user": "user", "dsn": "..."}'
7073

7174
To see the list of generic options::
7275

src/sqlacodegen/cli.py

+37-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

33
import argparse
4+
import ast
45
import sys
56
from contextlib import ExitStack
6-
from typing import TextIO
7+
from typing import Any, TextIO
78

89
from sqlalchemy.engine import create_engine
910
from sqlalchemy.schema import MetaData
@@ -29,6 +30,28 @@
2930
from importlib.metadata import entry_points, version
3031

3132

33+
def _parse_engine_arg(arg_str: str) -> tuple[str, Any]:
34+
if "=" not in arg_str:
35+
raise argparse.ArgumentTypeError("engine-arg must be in key=value format")
36+
37+
key, value = arg_str.split("=", 1)
38+
try:
39+
value = ast.literal_eval(value)
40+
except Exception:
41+
pass # Leave as string if literal_eval fails
42+
43+
return key, value
44+
45+
46+
def _parse_engine_args(arg_list: list[str]) -> dict[str, Any]:
47+
result = {}
48+
for arg in arg_list or []:
49+
key, value = _parse_engine_arg(arg)
50+
result[key] = value
51+
52+
return result
53+
54+
3255
def main() -> None:
3356
generators = {ep.name: ep for ep in entry_points(group="sqlacodegen.generators")}
3457
parser = argparse.ArgumentParser(
@@ -58,6 +81,17 @@ def main() -> None:
5881
action="store_true",
5982
help="ignore views (always true for sqlmodels generator)",
6083
)
84+
parser.add_argument(
85+
"--engine-arg",
86+
action="append",
87+
help=(
88+
"engine arguments in key=value format, e.g., "
89+
'--engine-arg=connect_args=\'{"user": "scott"}\' '
90+
"--engine-arg thick_mode=true or "
91+
'--engine-arg thick_mode=\'{"lib_dir": "/path"}\' '
92+
"(values are parsed with ast.literal_eval)"
93+
),
94+
)
6195
parser.add_argument("--outfile", help="file to write output to (default: stdout)")
6296
args = parser.parse_args()
6397

@@ -80,7 +114,8 @@ def main() -> None:
80114
print(f"Using pgvector {version('pgvector')}")
81115

82116
# Use reflection to fill in the metadata
83-
engine = create_engine(args.url)
117+
engine_args = _parse_engine_args(args.engine_arg)
118+
engine = create_engine(args.url, **engine_args)
84119
metadata = MetaData()
85120
tables = args.tables.split(",") if args.tables else None
86121
schemas = args.schemas.split(",") if args.schemas else [None]

tests/test_cli.py

+66
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,72 @@ class Foo(SQLModel, table=True):
150150
)
151151

152152

153+
def test_cli_engine_arg(db_path: Path, tmp_path: Path) -> None:
154+
output_path = tmp_path / "outfile"
155+
subprocess.run(
156+
[
157+
"sqlacodegen",
158+
f"sqlite:///{db_path}",
159+
"--generator",
160+
"tables",
161+
"--engine-arg",
162+
'connect_args={"timeout": 10}',
163+
"--outfile",
164+
str(output_path),
165+
],
166+
check=True,
167+
)
168+
169+
assert (
170+
output_path.read_text()
171+
== """\
172+
from sqlalchemy import Column, Integer, MetaData, Table, Text
173+
174+
metadata = MetaData()
175+
176+
177+
t_foo = Table(
178+
'foo', metadata,
179+
Column('id', Integer, primary_key=True),
180+
Column('name', Text, nullable=False)
181+
)
182+
"""
183+
)
184+
185+
186+
def test_cli_invalid_engine_arg(db_path: Path, tmp_path: Path) -> None:
187+
output_path = tmp_path / "outfile"
188+
189+
# Expect exception:
190+
# TypeError: 'this_arg_does_not_exist' is an invalid keyword argument for Connection()
191+
with pytest.raises(subprocess.CalledProcessError) as exc_info:
192+
subprocess.run(
193+
[
194+
"sqlacodegen",
195+
f"sqlite:///{db_path}",
196+
"--generator",
197+
"tables",
198+
"--engine-arg",
199+
'connect_args={"this_arg_does_not_exist": 10}',
200+
"--outfile",
201+
str(output_path),
202+
],
203+
check=True,
204+
capture_output=True,
205+
)
206+
207+
if sys.version_info < (3, 13):
208+
assert (
209+
"'this_arg_does_not_exist' is an invalid keyword argument"
210+
in exc_info.value.stderr.decode()
211+
)
212+
else:
213+
assert (
214+
"got an unexpected keyword argument 'this_arg_does_not_exist'"
215+
in exc_info.value.stderr.decode()
216+
)
217+
218+
153219
def test_main() -> None:
154220
expected_version = version("sqlacodegen")
155221
completed = subprocess.run(

0 commit comments

Comments
 (0)