Skip to content

Commit 3798440

Browse files
committed
fixes from rebase
1 parent e9f273a commit 3798440

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

src/codegate/db/connection.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import sqlite3
55
import uuid
66
from pathlib import Path
7-
from typing import Dict, List, Optional, Type
7+
from typing import List, Optional, Type
88

99
import numpy as np
1010
import sqlite_vec_sl_tmp
@@ -28,7 +28,6 @@
2828
GetMessagesRow,
2929
GetWorkspaceByNameConditions,
3030
Instance,
31-
IntermediatePromptWithOutputUsageAlerts,
3231
MuxRule,
3332
Output,
3433
Persona,
@@ -795,6 +794,36 @@ async def get_prompts(
795794
)
796795
return rows
797796

797+
async def get_total_messages_count_by_workspace_id(
798+
self, workspace_id: str, trigger_category: Optional[str] = None
799+
) -> int:
800+
"""
801+
Get total count of unique messages for a given workspace_id,
802+
considering trigger_category.
803+
"""
804+
sql = text(
805+
"""
806+
SELECT COUNT(DISTINCT p.id)
807+
FROM prompts p
808+
LEFT JOIN alerts a ON p.id = a.prompt_id
809+
WHERE p.workspace_id = :workspace_id
810+
"""
811+
)
812+
conditions = {"workspace_id": workspace_id}
813+
814+
if trigger_category:
815+
sql = text(sql.text + " AND a.trigger_category = :trigger_category")
816+
conditions["trigger_category"] = trigger_category
817+
818+
async with self._async_db_engine.begin() as conn:
819+
try:
820+
result = await conn.execute(sql, conditions)
821+
count = result.scalar() # Fetches the integer result directly
822+
return count or 0 # Ensure it returns an integer
823+
except Exception as e:
824+
logger.error(f"Failed to fetch message count. Error: {e}")
825+
return 0 # Return 0 in case of failure
826+
798827
async def get_alerts_by_workspace(
799828
self, workspace_id: str, trigger_category: Optional[str] = None
800829
) -> List[Alert]:

src/codegate/db/models.py

+16
Original file line numberDiff line numberDiff line change
@@ -337,3 +337,19 @@ class PersonaDistance(Persona):
337337
"""
338338

339339
distance: float
340+
341+
342+
class GetMessagesRow(BaseModel):
343+
id: Any
344+
timestamp: Any
345+
provider: Optional[Any]
346+
request: Any
347+
type: Any
348+
output_id: Optional[Any]
349+
output: Optional[Any]
350+
output_timestamp: Optional[Any]
351+
input_tokens: Optional[int]
352+
output_tokens: Optional[int]
353+
input_cost: Optional[float]
354+
output_cost: Optional[float]
355+
alerts: List[Alert] = []

0 commit comments

Comments
 (0)