Skip to content

Commit b84ca42

Browse files
alex75malmans2
andauthored
postgres-specific code removed (#123)
* postgres-specific code removed * add unit tests * remove debug print * qa * fix test * better handling of wdir * qa --------- Co-authored-by: Mattia Almansi <[email protected]>
1 parent aaa38c2 commit b84ca42

File tree

4 files changed

+91
-36
lines changed

4 files changed

+91
-36
lines changed

cacholote/database.py

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _decode_kwargs(**kwargs: Any) -> dict[str, Any]:
111111
def _cached_sessionmaker(
112112
url: str, **kwargs: Any
113113
) -> sa.orm.sessionmaker[sa.orm.Session]:
114-
engine = sa.create_engine(url, **_decode_kwargs(**kwargs))
114+
engine = init_database(url, **_decode_kwargs(**kwargs))
115115
Base.metadata.create_all(engine)
116116
return sa.orm.sessionmaker(engine)
117117

@@ -120,44 +120,57 @@ def cached_sessionmaker(url: str, **kwargs: Any) -> sa.orm.sessionmaker[sa.orm.S
120120
return _cached_sessionmaker(url, **_encode_kwargs(**kwargs))
121121

122122

123-
def init_database(connection_string: str, force: bool = False) -> sa.engine.Engine:
123+
def init_database(
124+
connection_string: str, force: bool = False, **kwargs: Any
125+
) -> sa.engine.Engine:
124126
"""
125127
Make sure the db located at URI `connection_string` exists updated and return the engine object.
126128
127-
:param connection_string: something like 'postgresql://user:password@netloc:port/dbname'
128-
:param force: if True, drop the database structure and build again from scratch
129+
Parameters
130+
----------
131+
connection_string: str
132+
Something like 'postgresql://user:password@netloc:port/dbname'
133+
force: bool
134+
if True, drop the database structure and build again from scratch
135+
kwargs: Any
136+
Keyword arguments for create_engine
137+
138+
Returns
139+
-------
140+
engine: Engine
129141
"""
130-
engine = sa.create_engine(connection_string)
142+
engine = sa.create_engine(connection_string, **kwargs)
131143
migration_directory = os.path.abspath(os.path.join(__file__, ".."))
132-
os.chdir(migration_directory)
133-
alembic_config_path = os.path.join(migration_directory, "alembic.ini")
134-
alembic_cfg = alembic.config.Config(alembic_config_path)
135-
for option in ["drivername", "username", "password", "host", "port", "database"]:
136-
value = getattr(engine.url, option)
137-
if value is None:
138-
value = ""
139-
alembic_cfg.set_main_option(option, str(value))
140-
if not sqlalchemy_utils.database_exists(engine.url):
141-
sqlalchemy_utils.create_database(engine.url)
142-
# cleanup and create the schema
143-
Base.metadata.drop_all(engine)
144-
Base.metadata.create_all(engine)
145-
alembic.command.stamp(alembic_cfg, "head")
146-
else:
147-
# check the structure is empty or incomplete
148-
query = sa.text(
149-
"SELECT table_name FROM information_schema.tables WHERE table_schema='public'"
150-
)
151-
conn = engine.connect()
152-
if "cache_entries" not in conn.execute(query).scalars().all():
144+
with utils.change_working_dir(migration_directory):
145+
alembic_config_path = os.path.join(migration_directory, "alembic.ini")
146+
alembic_cfg = alembic.config.Config(alembic_config_path)
147+
for option in [
148+
"drivername",
149+
"username",
150+
"password",
151+
"host",
152+
"port",
153+
"database",
154+
]:
155+
value = getattr(engine.url, option)
156+
if value is None:
157+
value = ""
158+
alembic_cfg.set_main_option(option, str(value))
159+
if not sqlalchemy_utils.database_exists(engine.url):
160+
sqlalchemy_utils.create_database(engine.url)
161+
# cleanup and create the schema
162+
Base.metadata.drop_all(engine)
163+
Base.metadata.create_all(engine)
164+
alembic.command.stamp(alembic_cfg, "head")
165+
elif "cache_entries" not in sa.inspect(engine).get_table_names():
166+
# db structure is empty or incomplete
153167
force = True
154-
conn.close()
155-
if force:
156-
# cleanup and create the schema
157-
Base.metadata.drop_all(engine)
158-
Base.metadata.create_all(engine)
159-
alembic.command.stamp(alembic_cfg, "head")
160-
else:
161-
# update db structure
162-
alembic.command.upgrade(alembic_cfg, "head")
168+
if force:
169+
# cleanup and create the schema
170+
Base.metadata.drop_all(engine)
171+
Base.metadata.create_all(engine)
172+
alembic.command.stamp(alembic_cfg, "head")
173+
else:
174+
# update db structure
175+
alembic.command.upgrade(alembic_cfg, "head")
163176
return engine

cacholote/utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,17 @@
1616
# limitations under the License.import hashlib
1717
from __future__ import annotations
1818

19+
import contextlib
1920
import dataclasses
2021
import datetime
2122
import functools
2223
import hashlib
2324
import io
25+
import os
2426
import time
2527
import warnings
2628
from types import TracebackType
27-
from typing import Any
29+
from typing import Any, Iterator
2830

2931
import fsspec
3032

@@ -129,3 +131,13 @@ def __exit__(
129131
def utcnow() -> datetime.datetime:
130132
"""See https://discuss.python.org/t/deprecating-utcnow-and-utcfromtimestamp/26221."""
131133
return datetime.datetime.now(tz=datetime.timezone.utc)
134+
135+
136+
@contextlib.contextmanager
137+
def change_working_dir(working_dir: str) -> Iterator[str]:
138+
old_dir = os.getcwd()
139+
os.chdir(working_dir)
140+
try:
141+
yield os.getcwd()
142+
finally:
143+
os.chdir(old_dir)

tests/test_02_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
import pathlib
45

56
import fsspec
@@ -29,3 +30,10 @@ def test_copy_buffered_file(tmp_path: pathlib.Path) -> None:
2930
with open(src, "rb") as f_src, open(dst, "wb") as f_dst:
3031
utils.copy_buffered_file(f_src, f_dst)
3132
assert open(src, "rb").read() == open(dst, "rb").read() == b"test"
33+
34+
35+
def test_change_working_dir(tmp_path: pathlib.Path) -> None:
36+
old_cwd = os.getcwd()
37+
with utils.change_working_dir(str(tmp_path)) as actual:
38+
assert actual == os.getcwd() == str(tmp_path.resolve())
39+
assert os.getcwd() == old_cwd

tests/test_70_alembic.py

Lines changed: 22 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)