|
1 | 1 | from os import getenv
|
2 |
| -from typing import TypeVar, Dict |
| 2 | +from typing import TypeVar, Dict, Type, TypeAlias |
3 | 3 |
|
4 | 4 | from dotenv import load_dotenv
|
5 | 5 | from sqlalchemy.engine import URL
|
6 |
| -from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine, AsyncSession |
| 6 | +from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine, AsyncSession, async_sessionmaker |
7 | 7 | from sqlalchemy.future import select as sa_select
|
8 |
| -from sqlalchemy.orm import DeclarativeMeta, declarative_base, sessionmaker |
| 8 | +from sqlalchemy.orm import DeclarativeMeta, declarative_base, sessionmaker, DeclarativeBase |
9 | 9 | from sqlalchemy.pool import NullPool
|
10 | 10 | from sqlalchemy.sql import Executable
|
11 | 11 | from sqlalchemy.sql.expression import exists as sa_exists, delete as sa_delete, Delete
|
@@ -40,26 +40,30 @@ def delete(table) -> Delete:
|
40 | 40 | return sa_delete(table)
|
41 | 41 |
|
42 | 42 |
|
| 43 | +class Base(DeclarativeBase): |
| 44 | + pass |
| 45 | + |
| 46 | + |
43 | 47 | class DB:
|
44 | 48 | """An async SQLAlchemy ORM wrapper"""
|
45 | 49 |
|
46 |
| - Base: DeclarativeMeta |
47 | 50 | _engine: AsyncEngine
|
48 | 51 | _session: AsyncSession
|
49 | 52 |
|
50 |
| - def __init__(self, driver: str, options: Dict = {"pool_size": 20, "max_overflow": 20}, **kwargs): |
51 |
| - url: str = URL.create(drivername=driver, **kwargs) |
| 53 | + def __init__(self, driver: str, options=None, **kwargs): |
| 54 | + if options is None: |
| 55 | + options = {"pool_size": 20, "max_overflow": 20} |
| 56 | + url = URL.create(drivername=driver, **kwargs) |
52 | 57 | self._engine = create_async_engine(url, echo=True, pool_pre_ping=True, pool_recycle=300, **options)
|
53 |
| - self.Base = declarative_base() |
54 |
| - self._session: AsyncSession = sessionmaker(self._engine, expire_on_commit=False, class_=AsyncSession)() |
| 58 | + self._session = async_sessionmaker(self._engine, expire_on_commit=False)() |
55 | 59 |
|
56 | 60 | async def create_tables(self):
|
57 | 61 | """Creates all Model Tables"""
|
58 | 62 | async with self._engine.begin() as conn:
|
59 |
| - await conn.run_sync(self.Base.metadata.create_all) |
| 63 | + await conn.run_sync(Base.metadata.create_all) |
60 | 64 |
|
61 | 65 | async def add(self, obj: T) -> T:
|
62 |
| - """Adds an Row to the Database""" |
| 66 | + """Adds a Row to the Database""" |
63 | 67 | self._session.add(obj)
|
64 | 68 | return obj
|
65 | 69 |
|
@@ -153,4 +157,3 @@ async def __call__(self) -> DB:
|
153 | 157 |
|
154 | 158 |
|
155 | 159 | database_dependency: DatabaseDependency = DatabaseDependency()
|
156 |
| -Base: DeclarativeMeta = database_dependency.db.Base |
|
0 commit comments