|
1 | 1 | from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
2 |
| -import asyncio |
3 | 2 | import json
|
4 | 3 | from pathlib import Path
|
5 |
| -from sqlalchemy import text |
| 4 | +from sqlmodel import text |
6 | 5 | from sqlmodel import SQLModel, select
|
7 | 6 | import logging
|
8 | 7 | from typing import List, Dict, Any
|
9 |
| -import argparse |
10 |
| - |
11 | 8 | import yaml
|
| 9 | +import argparse |
| 10 | +import asyncio |
| 11 | +import sys |
12 | 12 |
|
13 | 13 | # Import all your models
|
14 | 14 | import app.database.models as models
|
15 | 15 | from app.database.enums import ChunkType
|
16 | 16 | from app.database.utils import get_database_url
|
17 | 17 |
|
| 18 | + |
18 | 19 | # Set up logging
|
19 | 20 | logging.basicConfig(level=logging.INFO)
|
20 | 21 | logger = logging.getLogger(__name__)
|
21 | 22 |
|
22 | 23 |
|
23 |
| -async def init_db(drop_all: bool = False): |
async def reset_db():
    """Reset the database to an empty state using async operations.

    Inside a single transaction this:
      1. drops every table registered on ``SQLModel.metadata``;
      2. drops the ``alembic_version`` table;
      3. drops every enum type found in the ``public`` schema;
      4. ensures the ``vector`` extension exists.

    No tables are created here.

    Raises:
        Exception: any failure is logged and re-raised unchanged.
    """
    # Bind the name before the try block so the finally clause can always
    # test it, even when URL lookup or engine creation itself fails.
    engine = None
    try:
        engine = create_async_engine(get_database_url())
        async with engine.begin() as conn:
            logger.info("Dropping all existing tables...")
            await conn.run_sync(SQLModel.metadata.drop_all)

            # Drop alembic_version table
            logger.info("Dropping alembic_version table...")
            await conn.execute(text("DROP TABLE IF EXISTS alembic_version CASCADE"))

            logger.info("Dropping existing enum types...")
            enum_query = """
                SELECT t.typname FROM pg_type t
                JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
                WHERE t.typtype = 'e' AND n.nspname = 'public'
            """
            result = await conn.execute(text(enum_query))
            enum_types = result.scalars().all()

            for enum_name in enum_types:
                # Identifiers cannot be bound as parameters, so quote the
                # name (doubling embedded quotes) to keep mixed-case or
                # unusual type names valid, and qualify it with the schema
                # the lookup above was restricted to.
                quoted_name = '"' + enum_name.replace('"', '""') + '"'
                await conn.execute(
                    text(f"DROP TYPE IF EXISTS public.{quoted_name} CASCADE")
                )

            logger.info("Creating pgvector extension...")
            await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
    except Exception as e:
        logger.error(f"Error initializing database: {str(e)}")
        raise
    finally:
        # Only dispose an engine that was actually created; otherwise an
        # UnboundLocalError here would mask the original exception.
        if engine is not None:
            await engine.dispose()
|
41 | 55 |
|
42 | 56 |
|
def run_migrations():
    """Upgrade the database schema to the latest alembic revision.

    Reads its configuration from ``alembic.ini`` (path as given, not
    resolved by this function) and upgrades to ``head``.
    """
    # Imported here rather than at module level, as in the original.
    from alembic import command
    from alembic.config import Config

    logger.info("Running migrations...")
    command.upgrade(Config("alembic.ini"), "head")
    logger.info("Migrations complete")
| 66 | + |
| 67 | + |
43 | 68 | async def inject_sample_data():
|
44 | 69 | try:
|
45 | 70 | """Load sample data into the database."""
|
@@ -194,39 +219,45 @@ async def inject_vector_data(file: str):
|
194 | 219 | await engine.dispose()
|
195 | 220 |
|
196 | 221 |
|
197 |
| -if __name__ == "__main__": |
198 |
| - try: |
199 |
| - # Parse command line arguments |
200 |
| - parser = argparse.ArgumentParser( |
201 |
| - description="Initialize the database for Twiga development." |
202 |
| - ) |
203 |
| - parser.add_argument( |
204 |
| - "--drop", |
205 |
| - action="store_true", |
206 |
| - help="Drop the existing tables before creation", |
207 |
| - ) |
208 |
| - parser.add_argument( |
209 |
| - "--sample-data", action="store_true", help="Add sample data to the database" |
210 |
| - ) |
211 |
| - parser.add_argument( |
212 |
| - "--vector-data", |
213 |
| - type=str, |
214 |
| - metavar="FILENAME", |
215 |
| - help="Populate the vector database using the specified chunks file (chunks_OPENAI.json or chunks_BAAI.json)", |
216 |
| - ) |
async def main():
    """Initialize and set up the Twiga development database.

    Command-line flags (all optional, handled in this order):
      --create             reset the database, then run alembic migrations
      --sample-data        add sample data
      --vector-data FILE   load vector data from the given chunks file

    If any step raises, the error is logged and the process exits with
    status 1.
    """
    # Set up argument parser
    parser = argparse.ArgumentParser(
        description="Initialize and setup the Twiga development database."
    )
    parser.add_argument(
        "--create", action="store_true", help="Reset and run alembic migrations"
    )
    parser.add_argument("--sample-data", action="store_true", help="Add sample data")
    parser.add_argument(
        "--vector-data",
        type=str,
        help="Vector database chunks file (chunks_OPENAI.json or chunks_BAAI.json)",
    )

    # Parse arguments
    args = parser.parse_args()

    try:
        if args.create:
            logger.info("Starting database setup...")
            await reset_db()
            # run_migrations() is synchronous. Run it in a worker thread so
            # it does not block this event loop and so that an alembic
            # env.py which starts its own event loop (asyncio.run) does not
            # fail with "cannot be called from a running event loop".
            # NOTE(review): env.py is not visible from here - confirm
            # whether it is async; the thread is harmless either way.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, run_migrations)

        if args.sample_data:
            logger.info("Starting sample data injection...")
            await inject_sample_data()

        if args.vector_data:
            logger.info("Starting vector data injection...")
            await inject_vector_data(args.vector_data)

        logger.info("Database setup complete")
    except Exception as e:
        # Report through the module logger, like every other message in
        # this script, instead of printing to stdout.
        logger.error(f"Setup failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())
0 commit comments