Skip to content

Commit bfa0850

Browse files
authored
Add model and tests (#22)
1 parent 9d4aa1f commit bfa0850

File tree

6 files changed

+203
-70
lines changed

6 files changed

+203
-70
lines changed

app/datasources/db/models.py

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,80 @@
1-
from typing import Optional
1+
from sqlmodel import (
2+
JSON,
3+
Column,
4+
Field,
5+
Relationship,
6+
SQLModel,
7+
UniqueConstraint,
8+
select,
9+
)
210

3-
from sqlmodel import Field, SQLModel
411

12+
class SqlQueryBase:
513

6-
class Contract(SQLModel, table=True):
7-
address: bytes = Field(nullable=False, primary_key=True)
14+
@classmethod
15+
async def get_all(cls, session):
16+
result = await session.exec(select(cls))
17+
return result.all()
18+
19+
async def _save(self, session):
20+
session.add(self)
21+
await session.commit()
22+
return self
23+
24+
async def update(self, session):
25+
return await self._save(session)
26+
27+
async def create(self, session):
28+
return await self._save(session)
29+
30+
31+
class AbiSource(SqlQueryBase, SQLModel, table=True):
32+
id: int | None = Field(default=None, primary_key=True)
33+
name: str = Field(nullable=False)
34+
url: str = Field(nullable=False)
35+
36+
abis: list["Abi"] = Relationship(back_populates="source")
37+
38+
39+
class Abi(SqlQueryBase, SQLModel, table=True):
40+
id: int | None = Field(default=None, primary_key=True)
41+
abi_hash: bytes = Field(nullable=False, index=True, unique=True)
42+
relevance: int | None = Field(nullable=False, default=0)
43+
abi_json: dict = Field(default_factory=dict, sa_column=Column(JSON))
44+
source_id: int | None = Field(
45+
nullable=True, default=None, foreign_key="abisource.id"
46+
)
47+
48+
source: AbiSource | None = Relationship(back_populates="abis")
49+
contracts: list["Contract"] = Relationship(back_populates="abi")
50+
51+
52+
class Project(SqlQueryBase, SQLModel, table=True):
53+
id: int | None = Field(default=None, primary_key=True)
54+
description: str = Field(nullable=False)
55+
logo_file: str = Field(nullable=False)
56+
contracts: list["Contract"] = Relationship(back_populates="project")
57+
58+
59+
class Contract(SqlQueryBase, SQLModel, table=True):
60+
__table_args__ = (
61+
UniqueConstraint("address", "chain_id", name="address_chain_unique"),
62+
)
63+
64+
id: int | None = Field(default=None, primary_key=True)
65+
address: bytes = Field(nullable=False, index=True)
866
name: str = Field(nullable=False)
9-
description: Optional[str] = None
67+
display_name: str | None = None
68+
description: str | None = None
69+
trusted_for_delegate: bool = Field(nullable=False, default=False)
70+
proxy: bool = Field(nullable=False, default=False)
71+
fetch_retries: int = Field(nullable=False, default=0)
72+
abi_id: bytes | None = Field(
73+
nullable=True, default=None, foreign_key="abi.abi_hash"
74+
)
75+
abi: Abi | None = Relationship(back_populates="contracts")
76+
project_id: int | None = Field(
77+
nullable=True, default=None, foreign_key="project.id"
78+
)
79+
project: Project | None = Relationship(back_populates="contracts")
80+
chain_id: int = Field(default=None)

app/services/contract.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Sequence
22

3-
from sqlmodel import select
43
from sqlmodel.ext.asyncio.session import AsyncSession
54

65
from app.datasources.db.models import Contract
@@ -16,18 +15,4 @@ async def get_all(session: AsyncSession) -> Sequence[Contract]:
1615
:param session: passed by the decorator
1716
:return:
1817
"""
19-
result = await session.exec(select(Contract))
20-
return result.all()
21-
22-
@staticmethod
23-
async def create(contract: Contract, session: AsyncSession) -> Contract:
24-
"""
25-
Create a new contract
26-
27-
:param contract:
28-
:param session:
29-
:return:
30-
"""
31-
session.add(contract)
32-
await session.commit()
33-
return contract
18+
return await Contract.get_all(session)

app/tests/db/test_model.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,45 @@
1-
from sqlmodel import select
21
from sqlmodel.ext.asyncio.session import AsyncSession
32

43
from app.datasources.db.database import database_session
5-
from app.datasources.db.models import Contract
4+
from app.datasources.db.models import Abi, AbiSource, Contract, Project
65
from app.tests.db.db_async_conn import DbAsyncConn
76

87

98
class TestModel(DbAsyncConn):
109
@database_session
1110
async def test_contract(self, session: AsyncSession):
12-
contract = Contract(address=b"a", name="A Test Contracts")
13-
session.add(contract)
14-
await session.commit()
15-
statement = select(Contract).where(Contract.address == b"a")
16-
result = await session.exec(statement)
17-
self.assertEqual(result.one(), contract)
11+
contract = Contract(address=b"a", name="A test contract", chain_id=1)
12+
await contract.create(session)
13+
await contract.create(session)
14+
result = await contract.get_all(session)
15+
self.assertEqual(result[0], contract)
16+
17+
@database_session
18+
async def test_project(self, session: AsyncSession):
19+
project = Project(description="A Test Project", logo_file="logo.jpg")
20+
await project.create(session)
21+
result = await project.get_all(session)
22+
self.assertEqual(result[0], project)
23+
24+
@database_session
25+
async def test_abi(self, session: AsyncSession):
26+
abi = Abi(abi_hash=b"A Test Abi", abi_json={"name": "A Test Project"})
27+
await abi.create(session)
28+
result = await abi.get_all(session)
29+
self.assertEqual(result[0], abi)
30+
31+
@database_session
32+
async def test_abi_source(self, session: AsyncSession):
33+
abi_source = AbiSource(name="A Test Source", url="https://test.com")
34+
await abi_source.create(session)
35+
result = await abi_source.get_all(session)
36+
self.assertEqual(result[0], abi_source)
37+
abi = Abi(
38+
abi_hash=b"A Test Abi",
39+
abi_json={"name": "A Test Project"},
40+
source_id=abi_source.id,
41+
)
42+
await abi.create(session)
43+
result = await abi.get_all(session)
44+
self.assertEqual(result[0], abi)
45+
self.assertEqual(result[0].source, abi_source)

app/tests/routers/test_contracts.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from ...datasources.db.database import database_session
66
from ...datasources.db.models import Contract
77
from ...main import app
8-
from ...services.contract import ContractService
98
from ..db.db_async_conn import DbAsyncConn
109

1110

@@ -18,13 +17,12 @@ def setUpClass(cls):
1817

1918
@database_session
2019
async def test_view_contracts(self, session: AsyncSession):
21-
contract = Contract(address=b"a", name="A Test Contracts")
20+
contract = Contract(address=b"a", name="A Test Contracts", chain_id=1)
2221
expected_response = {
2322
"name": "A Test Contracts",
2423
"description": None,
2524
"address": "a",
2625
}
27-
await ContractService.create(contract=contract, session=session)
26+
await contract.create(session)
2827
response = self.client.get("/api/v1/contracts")
2928
self.assertEqual(response.status_code, 200)
30-
self.assertDictEqual(response.json()[0], expected_response)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""init
2+
3+
Revision ID: 9912fd3fc9ce
4+
Revises:
5+
Create Date: 2024-12-13 11:23:10.023773
6+
7+
"""
8+
9+
from typing import Sequence, Union
10+
11+
import sqlalchemy as sa
12+
import sqlmodel
13+
from alembic import op
14+
15+
# revision identifiers, used by Alembic.
16+
revision: str = "9912fd3fc9ce"
17+
down_revision: Union[str, None] = None
18+
branch_labels: Union[str, Sequence[str], None] = None
19+
depends_on: Union[str, Sequence[str], None] = None
20+
21+
22+
def upgrade() -> None:
23+
# ### commands auto generated by Alembic - please adjust! ###
24+
op.create_table(
25+
"abisource",
26+
sa.Column("id", sa.Integer(), nullable=False),
27+
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
28+
sa.Column("url", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
29+
sa.PrimaryKeyConstraint("id"),
30+
)
31+
op.create_table(
32+
"project",
33+
sa.Column("id", sa.Integer(), nullable=False),
34+
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
35+
sa.Column("logo_file", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
36+
sa.PrimaryKeyConstraint("id"),
37+
)
38+
op.create_table(
39+
"abi",
40+
sa.Column("id", sa.Integer(), nullable=False),
41+
sa.Column("abi_hash", sa.LargeBinary(), nullable=False),
42+
sa.Column("relevance", sa.Integer(), nullable=False),
43+
sa.Column("abi_json", sa.JSON(), nullable=True),
44+
sa.Column("source_id", sa.Integer(), nullable=False),
45+
sa.ForeignKeyConstraint(
46+
["source_id"],
47+
["abisource.id"],
48+
),
49+
sa.PrimaryKeyConstraint("id"),
50+
)
51+
op.create_index(op.f("ix_abi_abi_hash"), "abi", ["abi_hash"], unique=True)
52+
op.create_table(
53+
"contract",
54+
sa.Column("id", sa.Integer(), nullable=False),
55+
sa.Column("address", sa.LargeBinary(), nullable=False),
56+
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
57+
sa.Column("display_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
58+
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
59+
sa.Column("trusted_for_delegate", sa.Boolean(), nullable=False),
60+
sa.Column("proxy", sa.Boolean(), nullable=False),
61+
sa.Column("fetch_retries", sa.Integer(), nullable=False),
62+
sa.Column("abi_id", sa.LargeBinary(), nullable=True),
63+
sa.Column("project_id", sa.Integer(), nullable=True),
64+
sa.Column("chain_id", sa.Integer(), nullable=False),
65+
sa.ForeignKeyConstraint(
66+
["abi_id"],
67+
["abi.abi_hash"],
68+
),
69+
sa.ForeignKeyConstraint(
70+
["project_id"],
71+
["project.id"],
72+
),
73+
sa.PrimaryKeyConstraint("id"),
74+
sa.UniqueConstraint("address", "chain_id", name="address_chain_unique"),
75+
)
76+
op.create_index(op.f("ix_contract_address"), "contract", ["address"], unique=False)
77+
# ### end Alembic commands ###
78+
79+
80+
def downgrade() -> None:
81+
# ### commands auto generated by Alembic - please adjust! ###
82+
op.drop_index(op.f("ix_contract_address"), table_name="contract")
83+
op.drop_table("contract")
84+
op.drop_index(op.f("ix_abi_abi_hash"), table_name="abi")
85+
op.drop_table("abi")
86+
op.drop_table("project")
87+
op.drop_table("abisource")
88+
# ### end Alembic commands ###

migrations/versions/d0c5d72aa50b_init.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

0 commit comments

Comments
 (0)