Skip to content

[DS-4136] Replace websocket w/ REST in Search RAG #1596

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions swirl/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
class RagError(Exception):
"""Exception raised for errors related to RAG."""
def __init__(self, message="Error with RAG"):
self.message = message
super().__init__(self.message)
169 changes: 150 additions & 19 deletions swirl/serializers.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,198 @@
'''
"""
@author: Sid Probstein
@contact: [email protected]
'''
"""

from django.contrib.auth.models import User, Group
from django.contrib.auth.models import Group, User
from rest_framework import serializers
from swirl.models import SearchProvider, Search, Result,QueryTransform

from swirl.models import QueryTransform, Result, Search, SearchProvider


class UserSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = User
fields = ['url', 'username', 'email', 'groups']
fields = ["url", "username", "email", "groups"]


class GroupSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = Group
fields = ['url', 'name']
fields = ["url", "name"]


class SearchProviderSerializer(serializers.ModelSerializer):
owner = serializers.ReadOnlyField(source='owner.username')
owner = serializers.ReadOnlyField(source="owner.username")

class Meta:
model = SearchProvider
fields = ['id', 'name', 'owner', 'shared', 'date_created', 'date_updated', 'active', 'default', 'authenticator','connector', 'url', 'query_template', 'query_template_json', 'post_query_template', 'http_request_headers', 'page_fetch_config_json', 'query_processors', 'query_mappings', 'result_grouping_field', 'result_processors', 'response_mappings', 'result_mappings', 'results_per_query', 'credentials', 'eval_credentials', 'tags']
fields = [
"id",
"name",
"owner",
"shared",
"date_created",
"date_updated",
"active",
"default",
"authenticator",
"connector",
"url",
"query_template",
"query_template_json",
"post_query_template",
"http_request_headers",
"page_fetch_config_json",
"query_processors",
"query_mappings",
"result_grouping_field",
"result_processors",
"response_mappings",
"result_mappings",
"results_per_query",
"credentials",
"eval_credentials",
"tags",
]


class SearchProviderNoCredentialsSerializer(serializers.ModelSerializer):
owner = serializers.ReadOnlyField(source='owner.username')
owner = serializers.ReadOnlyField(source="owner.username")

class Meta:
model = SearchProvider
fields = ['id', 'name', 'owner', 'shared', 'date_created', 'date_updated', 'active', 'default', 'authenticator', 'connector', 'url', 'query_template', 'query_template_json', 'post_query_template', 'http_request_headers', 'page_fetch_config_json', 'query_processors', 'query_mappings', 'result_processors', 'response_mappings', 'result_mappings', 'results_per_query', 'tags']
fields = [
"id",
"name",
"owner",
"shared",
"date_created",
"date_updated",
"active",
"default",
"authenticator",
"connector",
"url",
"query_template",
"query_template_json",
"post_query_template",
"http_request_headers",
"page_fetch_config_json",
"query_processors",
"query_mappings",
"result_processors",
"response_mappings",
"result_mappings",
"results_per_query",
"tags",
]


class SearchSerializer(serializers.ModelSerializer):
owner = serializers.ReadOnlyField(source='owner.username')
owner = serializers.ReadOnlyField(source="owner.username")

class Meta:
model = Search
fields = ['id', 'owner', 'date_created', 'date_updated', 'query_string', 'query_string_processed', 'sort', 'results_requested', 'searchprovider_list', 'subscribe', 'status', 'pre_query_processors', 'post_result_processors', 'result_url', 'new_result_url', 'messages', 'result_mixer', 'retention', 'tags']
fields = [
"id",
"owner",
"date_created",
"date_updated",
"query_string",
"query_string_processed",
"sort",
"results_requested",
"searchprovider_list",
"subscribe",
"status",
"pre_query_processors",
"post_result_processors",
"result_url",
"new_result_url",
"messages",
"result_mixer",
"retention",
"tags",
]


class ResultSerializer(serializers.ModelSerializer):
owner = serializers.ReadOnlyField(source='owner.username')
owner = serializers.ReadOnlyField(source="owner.username")

class Meta:
model = Result
fields = ['id', 'owner', 'date_created', 'date_updated', 'search_id', 'searchprovider', 'query_to_provider', 'query_processors', 'result_processors', 'result_processor_json_feedback', 'messages', 'status', 'retrieved', 'found', 'time', 'json_results', 'tags']
fields = [
"id",
"owner",
"date_created",
"date_updated",
"search_id",
"searchprovider",
"query_to_provider",
"query_processors",
"result_processors",
"result_processor_json_feedback",
"messages",
"status",
"retrieved",
"found",
"time",
"json_results",
"tags",
]


class QueryTransformSerializer(serializers.ModelSerializer):
owner = serializers.ReadOnlyField(source='owner.username')
owner = serializers.ReadOnlyField(source="owner.username")

class Meta:
model = QueryTransform
fields = ['id', 'name','owner','shared','date_created','date_updated','qrx_type','config_content']
fields = [
"id",
"name",
"owner",
"shared",
"date_created",
"date_updated",
"qrx_type",
"config_content",
]


class QueryTransformNoCredentialsSerializer(serializers.ModelSerializer):
owner = serializers.ReadOnlyField(source='owner.username')
owner = serializers.ReadOnlyField(source="owner.username")

class Meta:
model = QueryTransform
fields = ['id', 'name','owner','shared', 'date_created','date_updated','qrx_type','config_content']
fields = [
"id",
"name",
"owner",
"shared",
"date_created",
"date_updated",
"qrx_type",
"config_content",
]


class DetailSearchRagSerializer(serializers.Serializer):
message = serializers.CharField(required=True, allow_blank=True)

class Meta:
fields = ["message"]


###
# Minimal Serializers for drf-spectacular OpenAPI documentation only
class LoginRequestSerializer(serializers.Serializer):
username = serializers.CharField()
password = serializers.CharField()


class AuthResponseSerializer(serializers.Serializer):
token = serializers.CharField()
user = serializers.CharField()


class StatusResponseSerializer(serializers.Serializer):
status = serializers.CharField()
1 change: 1 addition & 0 deletions swirl/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
path('querytransforms/delete/<int:pk>/', views.QueryTransformViewSet.as_view({'delete': 'destroy'}), name='delete'),

path('search/search', views.SearchViewSet.as_view({'get': 'list'}), name='search'),
path('sapi/detail-search-rag/', views.DetailSearchRagView.as_view(), name='detail-search-rag'),

path('', views.index, name='index'),
path('index.html', views.index, name='index'),
Expand Down
20 changes: 19 additions & 1 deletion swirl/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@
from swirl.models import *
from swirl.serializers import *
from swirl.models import SearchProvider, Search, Result, QueryTransform, Authenticator as AuthenticatorModel, OauthToken
from swirl.serializers import UserSerializer, GroupSerializer, SearchProviderSerializer, SearchSerializer, ResultSerializer, QueryTransformSerializer, QueryTransformNoCredentialsSerializer, LoginRequestSerializer, StatusResponseSerializer, AuthResponseSerializer
from swirl.serializers import UserSerializer, DetailSearchRagSerializer, GroupSerializer, SearchProviderSerializer, SearchSerializer, ResultSerializer, QueryTransformSerializer, QueryTransformNoCredentialsSerializer, LoginRequestSerializer, StatusResponseSerializer, AuthResponseSerializer
from swirl.authenticators.authenticator import Authenticator
from swirl.authenticators import *
from swirl.views_helpers.search_rag import SearchRag

module_name = 'views.py'

Expand Down Expand Up @@ -642,6 +643,23 @@ def partial_update(self, request, pk=None):
########################################
########################################

class DetailSearchRagView(APIView):
serializer_class = DetailSearchRagSerializer
authentication_classes = [SessionAuthentication, BasicAuthentication]
permission_classes = []

def get(self, request):
logger.debug(f"DetailSearchRagView: {request.GET}")
search_rag = SearchRag(request)
result = search_rag.process_rag()
serializer = DetailSearchRagSerializer(result)
return Response(serializer.data, status=status.HTTP_200_OK)


########################################
########################################


class ResultViewSet(viewsets.ModelViewSet):
"""
API endpoint for managing Result objects, including Mixed Results
Expand Down
80 changes: 80 additions & 0 deletions swirl/views_helpers/search_rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import logging

from rest_framework.request import Request

from swirl.exceptions import RagError
from swirl.models import Result
from swirl.processors import *

logger = logging.getLogger(__name__)

instances = {}


class SearchRag:

def __init__(self, request_data: Request) -> None:
request_data = request_data.GET.dict()
logger.info(f"{self}: init search rag {request_data}")
self.search_id = request_data.get("search_id", None)

# Parse query parameters
self.rag_query_items = []
rag_query_items = request_data.get("rag_items", [""])[0]

if rag_query_items:
self.rag_query_items = rag_query_items.split(",")

def get_rag_result(self) -> tuple[str, dict[str, str]]:
isRagItemsUpdated = False
try:
rag_result = Result.objects.get(
search_id=self.search_id, searchprovider="ChatGPT"
)
isRagItemsUpdated = True
isRagItemsUpdated = not (
set(rag_result.json_results[0]["rag_query_items"])
== set(self.rag_query_items)
)
except:
pass
try:
rag_result = Result.objects.get(
search_id=self.search_id, searchprovider="ChatGPT"
)
isRagItemsUpdated = not (
set(rag_result.json_results[0]["rag_query_items"])
== set(self.rag_query_items)
)
if rag_result and not isRagItemsUpdated:
if rag_result.json_results[0]["body"][0]:
return rag_result.json_results[0]["body"][0]
return False
except:
pass
rag_processor = RAGPostResultProcessor(
search_id=self.search_id,
request_id="",
should_get_results=True,
rag_query_items=self.rag_query_items,
)
instances[self.search_id] = rag_processor
if rag_processor.validate():
result = rag_processor.process(should_return=True)
try:
if self.search_id in instances:
del instances[self.search_id]
return result.json_results[0]["body"][0]
except:
if self.search_id in instances:
del instances[self.search_id]
return False

def process_rag(self) -> dict[str, str]:
result = ""
try:
result = self.get_rag_result()
except RagError as err:
logger.error(f"{self}: Rag Error {err}")

return {"message": result}