Skip to content

Commit ccc7281

Browse files
authored
Merge pull request #74 from Tanzania-AI-Community/feature/db-seed-procedure
New Database Seed Script
2 parents 7214a14 + 4becf4e commit ccc7281

File tree

6 files changed

+91
-46
lines changed

6 files changed

+91
-46
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ restart: stop run
1414

1515
generate-local-data:
1616
@echo 'Generating local data ...'
17-
@docker-compose -f docker/dev/docker-compose.yml --env-file .env run --rm app bash -c "PYTHONPATH=/app uv run python scripts/database/seed.py --sample-data --vector-data chunks_BAAI.json"
17+
@docker-compose -f docker/dev/docker-compose.yml --env-file .env run --rm app bash -c "PYTHONPATH=/app uv run python scripts/database/seed.py --create --sample-data --vector-data chunks_BAAI.json"
1818

1919
setup-env: build generate-local-data run

app/config.py

+5
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ class Settings(BaseSettings):
4141
# Database settings
4242
database_url: SecretStr
4343

44+
@property
45+
def sync_database_url(self) -> str:
46+
"""Get synchronous database URL for migrations"""
47+
return self.database_url.get_secret_value().replace("+asyncpg", "")
48+
4449
# Business environment
4550
business_env: bool = False # Default if not found in .env
4651

docs/en/GETTING_STARTED.md

-4
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,6 @@ Next up, let's build all Docker images and local data, needed for further steps
137137
make setup-env
138138
```
139139

140-
> [!Note]
141-
>
142-
> You can try out [TablePlus](https://tableplus.com/) to visualize your databases.
143-
144140
## 🖥️ Set up the FastAPI application
145141

146142
Run the following command to run the project.

docs/en/MIGRATIONS.md

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Migrations
2+
3+
You might run into some database troubles that require you to do database migrations. In the folder `migrations/versions/` you find the list of past database migrations. We're using [Alembic](https://alembic.sqlalchemy.org/en/latest/). They're docs aren't great so here's a beginner [article](https://medium.com/@kasperjuunge/how-to-get-started-with-alembic-and-sqlmodel-288700002543) on it.
4+
5+
By default, our Docker images use the alembic versioning system to initialize the database. If you wan't to rebuild the database to your needs, you can run new migrations and rebuild the Docker containers.
6+
7+
If you're not using Docker to run Twiga, then you can initialize the database and inject seed data with the command:
8+
9+
```
10+
uv run python -m scripts.database.seed --create --sample-data --vector-data chunks_BAAI.json
11+
```
12+
13+
This will remove all tables in the database if they exist, create new ones, install pgvector and inject sample data and vector data so that the database is ready to accept new users.

migrations/env.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
config = context.config
1616

1717
# Load PostgreSQL migrations URL from environment variables
18-
db_url = settings.database_url.get_secret_value()
18+
db_url = settings.sync_database_url
1919
print(f"Connecting to: {db_url}")
2020
config.set_main_option("sqlalchemy.url", db_url)
2121

scripts/database/seed.py

+71-40
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,70 @@
11
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
2-
import asyncio
32
import json
43
from pathlib import Path
5-
from sqlalchemy import text
4+
from sqlmodel import text
65
from sqlmodel import SQLModel, select
76
import logging
87
from typing import List, Dict, Any
9-
import argparse
10-
118
import yaml
9+
import argparse
10+
import asyncio
11+
import sys
1212

1313
# Import all your models
1414
import app.database.models as models
1515
from app.database.enums import ChunkType
1616
from app.database.utils import get_database_url
1717

18+
1819
# Set up logging
1920
logging.basicConfig(level=logging.INFO)
2021
logger = logging.getLogger(__name__)
2122

2223

23-
async def init_db(drop_all: bool = False):
24+
async def reset_db():
25+
"""Reset database using async operations"""
2426
try:
25-
"""Initialize the database with tables."""
2627
engine = create_async_engine(get_database_url())
27-
2828
async with engine.begin() as conn:
29-
if drop_all:
30-
logger.info("Dropping all existing tables...")
31-
await conn.run_sync(SQLModel.metadata.drop_all)
32-
33-
logger.info("Creating tables and pgvector extension...")
29+
logger.info("Dropping all existing tables...")
30+
await conn.run_sync(SQLModel.metadata.drop_all)
31+
32+
# Drop alembic_version table
33+
logger.info("Dropping alembic_version table...")
34+
await conn.execute(text("DROP TABLE IF EXISTS alembic_version CASCADE"))
35+
36+
logger.info("Dropping existing enum types...")
37+
enum_query = """
38+
SELECT t.typname FROM pg_type t
39+
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
40+
WHERE t.typtype = 'e' AND n.nspname = 'public'
41+
"""
42+
result = await conn.execute(text(enum_query))
43+
enum_types = result.scalars().all()
44+
45+
for enum_name in enum_types:
46+
await conn.execute(text(f"DROP TYPE IF EXISTS {enum_name} CASCADE"))
47+
48+
logger.info("Creating pgvector extension...")
3449
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
35-
await conn.run_sync(SQLModel.metadata.create_all)
3650
except Exception as e:
3751
logger.error(f"Error initializing database: {str(e)}")
3852
raise
3953
finally:
4054
await engine.dispose()
4155

4256

57+
def run_migrations():
58+
"""Run alembic migrations"""
59+
from alembic.config import Config
60+
from alembic import command
61+
62+
logger.info("Running migrations...")
63+
alembic_cfg = Config("alembic.ini")
64+
command.upgrade(alembic_cfg, "head")
65+
logger.info("Migrations complete")
66+
67+
4368
async def inject_sample_data():
4469
try:
4570
"""Load sample data into the database."""
@@ -194,39 +219,45 @@ async def inject_vector_data(file: str):
194219
await engine.dispose()
195220

196221

197-
if __name__ == "__main__":
198-
try:
199-
# Parse command line arguments
200-
parser = argparse.ArgumentParser(
201-
description="Initialize the database for Twiga development."
202-
)
203-
parser.add_argument(
204-
"--drop",
205-
action="store_true",
206-
help="Drop the existing tables before creation",
207-
)
208-
parser.add_argument(
209-
"--sample-data", action="store_true", help="Add sample data to the database"
210-
)
211-
parser.add_argument(
212-
"--vector-data",
213-
type=str,
214-
metavar="FILENAME",
215-
help="Populate the vector database using the specified chunks file (chunks_OPENAI.json or chunks_BAAI.json)",
216-
)
222+
async def main():
223+
"""Initialize and setup the Twiga development database."""
224+
# Set up argument parser
225+
parser = argparse.ArgumentParser(
226+
description="Initialize and setup the Twiga development database."
227+
)
228+
parser.add_argument(
229+
"--create", action="store_true", help="Reset and run alembic migrations"
230+
)
231+
parser.add_argument("--sample-data", action="store_true", help="Add sample data")
232+
parser.add_argument(
233+
"--vector-data",
234+
type=str,
235+
help="Vector database chunks file (chunks_OPENAI.json or chunks_BAAI.json)",
236+
)
237+
238+
# Parse arguments
239+
args = parser.parse_args()
217240

218-
args = parser.parse_args()
241+
try:
219242

220-
asyncio.run(init_db(drop_all=args.drop))
243+
if args.create:
244+
logger.info("Starting database setup...")
245+
await reset_db()
246+
run_migrations()
221247

222248
if args.sample_data:
223-
logger.info("Starting data injection...")
224-
asyncio.run(inject_sample_data())
249+
logger.info("Starting sample data injection...")
250+
await inject_sample_data()
225251

226252
if args.vector_data:
227253
logger.info("Starting vector data injection...")
228-
asyncio.run(inject_vector_data(args.vector_data))
254+
await inject_vector_data(args.vector_data)
229255

230-
logger.info("Complete.")
256+
logger.info("Database setup complete")
231257
except Exception as e:
232-
logger.error(f"Setup failed: {e}")
258+
print(f"Setup failed: {e}")
259+
sys.exit(1)
260+
261+
262+
if __name__ == "__main__":
263+
asyncio.run(main())

0 commit comments

Comments
 (0)