
Commit 8939ad8

feat: update messages endpoint to return a conversation summary
Modify the messages endpoint to return just a conversation summary, which simplifies the current queries. Add a separate endpoint that returns the list of conversations for a given prompt ID.
1 parent c25c5c7 commit 8939ad8
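
For context, a minimal sketch of how a client might call the two endpoints after this change. The paths and the page/page_size parameters are taken from the handlers in the diff below; the /api/v1 prefix, host, and port are assumptions about the deployment, and httpx is used purely for illustration.

# Sketch only: exercising the reworked messages endpoints.
import httpx

BASE_URL = "http://localhost:8989/api/v1"  # assumed mount point and port

with httpx.Client(base_url=BASE_URL) as client:
    # 1) Paginated conversation summaries for a workspace.
    page = client.get(
        "/workspaces/my-workspace/messages",
        params={"page": 1, "page_size": 50},
    ).json()
    for summary in page["data"]:
        print(summary["chat_id"], summary["total_alerts"])

    # 2) Full conversation detail for a single prompt id (the new endpoint).
    if page["data"]:
        prompt_id = page["data"][0]["chat_id"]
        conversations = client.get(
            f"/workspaces/my-workspace/messages/{prompt_id}"
        ).json()
        print(f"{len(conversations)} conversation(s) returned")

The list endpoint now returns lightweight summaries only, so a UI can render the conversation list from a single query and fetch the full conversation lazily via the second call.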

File tree

4 files changed: +155 −126 lines changed


src/codegate/api/v1.py

+81 −36
@@ -414,11 +414,11 @@ async def get_workspace_alerts_summary(workspace_name: str) -> v1_models.AlertSu
         raise HTTPException(status_code=500, detail="Internal server error")

     try:
-        summary = await dbreader.get_alerts_summary_by_workspace(ws.id)
+        summary = await dbreader.get_alerts_summary(workspace_id=ws.id)
         return v1_models.AlertSummary(
-            malicious_packages=summary["codegate_context_retriever_count"],
-            pii=summary["codegate_pii_count"],
-            secrets=summary["codegate_secrets_count"],
+            malicious_packages=summary.total_packages_count,
+            pii=summary.total_pii_count,
+            secrets=summary.total_secrets_count,
         )
     except Exception:
         logger.exception("Error while getting alerts summary")
@@ -447,43 +447,88 @@ async def get_workspace_messages(
         raise HTTPException(status_code=500, detail="Internal server error")

     offset = (page - 1) * page_size
-    fetched_messages: List[v1_models.Conversation] = []

-    try:
-        while len(fetched_messages) < page_size:
-            messages_batch = await dbreader.get_messages(
-                ws.id,
-                offset,
-                page_size,
-                filter_by_ids,
-                list([AlertSeverity.CRITICAL.value]),  # TODO: Configurable severity
-                filter_by_alert_trigger_types,
-            )
-            if not messages_batch:
-                break
-            parsed_conversations, _ = await v1_processing.parse_messages_in_conversations(
-                messages_batch
-            )
-            fetched_messages.extend(parsed_conversations)
-
-            offset += len(messages_batch)
-
-        final_messages = fetched_messages[:page_size]
-
-        # Fetch total message count
-        total_count = await dbreader.get_total_messages_count_by_workspace_id(
-            ws.id, AlertSeverity.CRITICAL.value
+    prompts = await dbreader.get_prompts(
+        ws.id,
+        offset,
+        page_size,
+        filter_by_ids,
+        list([AlertSeverity.CRITICAL.value]),  # TODO: Configurable severity
+        filter_by_alert_trigger_types,
+    )
+    # Fetch total message count
+    total_count = await dbreader.get_total_messages_count_by_workspace_id(
+        ws.id, AlertSeverity.CRITICAL.value
+    )
+
+    # iterate for all prompts to compose the conversation summary
+    conversation_summaries: List[v1_models.ConversationSummary] = []
+    for prompt in prompts:
+        if not prompt.request:
+            logger.warning(f"Skipping prompt {prompt.id}. Empty request field")
+            continue
+
+        messages, _ = await v1_processing.parse_request(prompt.request)
+        if not messages or len(messages) == 0:
+            logger.warning(f"Skipping prompt {prompt.id}. No messages found")
+            continue
+
+        # message is just the first entry in the request
+        message_obj = v1_models.ChatMessage(
+            message=messages[0], timestamp=prompt.timestamp, message_id=prompt.id
         )
-        return v1_models.PaginatedMessagesResponse(
-            data=final_messages,
-            limit=page_size,
-            offset=(page - 1) * page_size,
-            total=total_count,
+
+        # count total alerts for the prompt
+        total_alerts_row = await dbreader.get_alerts_summary(prompt_id=prompt.id)
+
+        # get token usage for the prompt
+        prompts_outputs = await dbreader.get_prompts_with_output(prompt_id=prompt.id)
+        ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs)
+
+        conversation_summary = v1_models.ConversationSummary(
+            chat_id=prompt.id,
+            prompt=message_obj,
+            provider=prompt.provider,
+            type=prompt.type,
+            conversation_timestamp=prompt.timestamp,
+            total_alerts=total_alerts_row.total_alerts,
+            token_usage_agg=ws_token_usage,
         )
+
+        conversation_summaries.append(conversation_summary)
+
+    return v1_models.PaginatedMessagesResponse(
+        data=conversation_summaries,
+        limit=page_size,
+        offset=(page - 1) * page_size,
+        total=total_count,
+    )
+
+
+@v1.get(
+    "/workspaces/{workspace_name}/messages/{prompt_id}",
+    tags=["Workspaces"],
+    generate_unique_id_function=uniq_name,
+)
+async def get_messages_by_prompt_id(
+    workspace_name: str,
+    prompt_id: str,
+) -> List[v1_models.Conversation]:
+    """Get messages for a workspace."""
+    try:
+        ws = await wscrud.get_workspace_by_name(workspace_name)
+    except crud.WorkspaceDoesNotExistError:
+        raise HTTPException(status_code=404, detail="Workspace does not exist")
     except Exception:
-        logger.exception("Error while getting messages")
+        logger.exception("Error while getting workspace")
         raise HTTPException(status_code=500, detail="Internal server error")

+    prompts_outputs = await dbreader.get_prompts_with_output(
+        workspace_id=ws.id, prompt_id=prompt_id
+    )
+    conversations, _ = await v1_processing.parse_messages_in_conversations(prompts_outputs)
+    return conversations
+

 @v1.get(
     "/workspaces/{workspace_name}/custom-instructions",
@@ -666,7 +711,7 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage
         raise HTTPException(status_code=500, detail="Internal server error")

     try:
-        prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
+        prompts_outputs = await dbreader.get_prompts_with_output(worskpace_id=ws.id)
         ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs)
         return ws_token_usage
     except Exception:
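
The new per-prompt route can be exercised without a live database by stubbing its module-level collaborators. Below is a rough pytest sketch, under several assumptions: v1 is a FastAPI APIRouter that can be included without an extra prefix, codegate.api.v1 imports cleanly in a test environment, and wscrud, dbreader, and v1_processing are patchable module-level names (they appear that way in the diff).

# Rough test sketch only; see the assumptions above.
from types import SimpleNamespace
from unittest.mock import AsyncMock

from fastapi import FastAPI
from fastapi.testclient import TestClient

from codegate.api import v1 as v1_module


def test_get_messages_by_prompt_id_returns_conversations(monkeypatch):
    # Stub workspace lookup, DB access, and parsing so no database is needed.
    monkeypatch.setattr(
        v1_module.wscrud,
        "get_workspace_by_name",
        AsyncMock(return_value=SimpleNamespace(id="ws-1")),
    )
    monkeypatch.setattr(
        v1_module.dbreader, "get_prompts_with_output", AsyncMock(return_value=[])
    )
    monkeypatch.setattr(
        v1_module.v1_processing,
        "parse_messages_in_conversations",
        AsyncMock(return_value=([], {})),  # the handler unpacks a 2-tuple
    )

    app = FastAPI()
    app.include_router(v1_module.v1)  # assumes v1 is an APIRouter

    client = TestClient(app)
    resp = client.get("/workspaces/my-workspace/messages/some-prompt-id")
    assert resp.status_code == 200
    assert resp.json() == []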

src/codegate/api/v1_models.py

+15 −1
@@ -225,6 +225,20 @@ class Conversation(pydantic.BaseModel):
     alerts: List[Alert] = []


+class ConversationSummary(pydantic.BaseModel):
+    """
+    Represents a conversation summary.
+    """
+
+    chat_id: str
+    prompt: ChatMessage
+    total_alerts: int
+    token_usage_agg: Optional[TokenUsageAggregate]
+    provider: Optional[str]
+    type: QuestionType
+    conversation_timestamp: datetime.datetime
+
+
 class AlertConversation(pydantic.BaseModel):
     """
     Represents an alert with it's respective conversation.

@@ -325,7 +339,7 @@ def __str__(self):


 class PaginatedMessagesResponse(pydantic.BaseModel):
-    data: List[Conversation]
+    data: List[ConversationSummary]
     limit: int
     offset: int
     total: int
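
To make the new response shape concrete, here is an illustrative construction and serialization of the summary model. ChatMessage's fields are taken from the handler in v1.py above; QuestionType.chat is an assumption about that enum's members, and model_dump_json assumes Pydantic v2.

# Illustrative only; see the assumptions above.
import datetime

from codegate.api.v1_models import ChatMessage, ConversationSummary, QuestionType

now = datetime.datetime.now(datetime.timezone.utc)

summary = ConversationSummary(
    chat_id="chat-123",
    prompt=ChatMessage(
        message="How do I pin this dependency?", timestamp=now, message_id="chat-123"
    ),
    total_alerts=2,
    token_usage_agg=None,       # Optional aggregate; omitted here
    provider="openai",
    type=QuestionType.chat,     # assumed enum member
    conversation_timestamp=now,
)

print(summary.model_dump_json(indent=2))  # Pydantic v2; use .json() on v1

Both provider and token_usage_agg are declared Optional in the model above, so a summary can still be built when either piece of information is missing.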
