diff --git a/cads_broker/database.py b/cads_broker/database.py index d8713c10..97d12aa3 100644 --- a/cads_broker/database.py +++ b/cads_broker/database.py @@ -877,8 +877,7 @@ def init_database(connection_string: str, force: bool = False) -> sa.engine.Engi sqlalchemy_utils.create_database(engine.url) # cleanup and create the schema BaseModel.metadata.drop_all(engine) - cacholote.database.Base.metadata.drop_all(engine) - cacholote.database.Base.metadata.create_all(engine) + cacholote.init_database(connection_string, force) BaseModel.metadata.create_all(engine) alembic.command.stamp(alembic_cfg, "head") else: @@ -893,11 +892,11 @@ def init_database(connection_string: str, force: bool = False) -> sa.engine.Engi if force: # cleanup and create the schema BaseModel.metadata.drop_all(engine) - cacholote.database.Base.metadata.drop_all(engine) - cacholote.database.Base.metadata.create_all(engine) + cacholote.init_database(connection_string, force) BaseModel.metadata.create_all(engine) alembic.command.stamp(alembic_cfg, "head") else: # update db structure + cacholote.init_database(connection_string, force) alembic.command.upgrade(alembic_cfg, "head") return engine diff --git a/tests/test_02_database.py b/tests/test_02_database.py index 888e0bdc..f2882203 100644 --- a/tests/test_02_database.py +++ b/tests/test_02_database.py @@ -982,7 +982,7 @@ def test_init_database(postgresql: Connection[str]) -> None: db.init_database(connection_string, force=True) expected_tables_complete = ( set(db.BaseModel.metadata.tables) - .union({"alembic_version"}) + .union({"alembic_version_cacholote"}) .union(set(cacholote.database.Base.metadata.tables)) ) assert set(conn.execute(query).scalars()) == expected_tables_complete # type: ignore @@ -1032,7 +1032,7 @@ def test_init_database_with_password(postgresql2: Connection[str]) -> None: db.init_database(connection_string, force=True) expected_tables_complete = ( set(db.BaseModel.metadata.tables) - .union({"alembic_version"}) + .union({"alembic_version_cacholote"}) .union(set(cacholote.database.Base.metadata.tables)) ) assert set(conn.execute(query).scalars()) == expected_tables_complete # type: ignore diff --git a/tests/test_90_entry_points.py b/tests/test_90_entry_points.py index 31811f5a..63e96a99 100644 --- a/tests/test_90_entry_points.py +++ b/tests/test_90_entry_points.py @@ -88,7 +88,9 @@ def test_init_db(postgresql: Connection[str], mocker) -> None: ) assert set(conn.execute(query).scalars()) == set( database.BaseModel.metadata.tables - ).union({"alembic_version"}).union(set(cacholote.database.Base.metadata.tables)) + ).union({"alembic_version_cacholote"}).union( + set(cacholote.database.Base.metadata.tables) + ) conn.close()