Skip to content

Commit 9e22cb3

Browse files
committed
search rag endpoint
1 parent 5ed8916 commit 9e22cb3

File tree

5 files changed

+255
-20
lines changed

5 files changed

+255
-20
lines changed

swirl/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
class RagError(Exception):
2+
"""Exception raised for errors related to RAG."""
3+
def __init__(self, message="Error with RAG"):
4+
self.message = message
5+
super().__init__(self.message)

swirl/serializers.py

Lines changed: 150 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,198 @@
1-
'''
1+
"""
22
@author: Sid Probstein
33
4-
'''
4+
"""
55

6-
from django.contrib.auth.models import User, Group
6+
from django.contrib.auth.models import Group, User
77
from rest_framework import serializers
8-
from swirl.models import SearchProvider, Search, Result,QueryTransform
8+
9+
from swirl.models import QueryTransform, Result, Search, SearchProvider
10+
911

1012
class UserSerializer(serializers.HyperlinkedModelSerializer):
1113
class Meta:
1214
model = User
13-
fields = ['url', 'username', 'email', 'groups']
15+
fields = ["url", "username", "email", "groups"]
16+
1417

1518
class GroupSerializer(serializers.HyperlinkedModelSerializer):
1619
class Meta:
1720
model = Group
18-
fields = ['url', 'name']
21+
fields = ["url", "name"]
22+
1923

2024
class SearchProviderSerializer(serializers.ModelSerializer):
21-
owner = serializers.ReadOnlyField(source='owner.username')
25+
owner = serializers.ReadOnlyField(source="owner.username")
26+
2227
class Meta:
2328
model = SearchProvider
24-
fields = ['id', 'name', 'owner', 'shared', 'date_created', 'date_updated', 'active', 'default', 'authenticator','connector', 'url', 'query_template', 'query_template_json', 'post_query_template', 'http_request_headers', 'page_fetch_config_json', 'query_processors', 'query_mappings', 'result_grouping_field', 'result_processors', 'response_mappings', 'result_mappings', 'results_per_query', 'credentials', 'eval_credentials', 'tags']
29+
fields = [
30+
"id",
31+
"name",
32+
"owner",
33+
"shared",
34+
"date_created",
35+
"date_updated",
36+
"active",
37+
"default",
38+
"authenticator",
39+
"connector",
40+
"url",
41+
"query_template",
42+
"query_template_json",
43+
"post_query_template",
44+
"http_request_headers",
45+
"page_fetch_config_json",
46+
"query_processors",
47+
"query_mappings",
48+
"result_grouping_field",
49+
"result_processors",
50+
"response_mappings",
51+
"result_mappings",
52+
"results_per_query",
53+
"credentials",
54+
"eval_credentials",
55+
"tags",
56+
]
57+
2558

2659
class SearchProviderNoCredentialsSerializer(serializers.ModelSerializer):
27-
owner = serializers.ReadOnlyField(source='owner.username')
60+
owner = serializers.ReadOnlyField(source="owner.username")
61+
2862
class Meta:
2963
model = SearchProvider
30-
fields = ['id', 'name', 'owner', 'shared', 'date_created', 'date_updated', 'active', 'default', 'authenticator', 'connector', 'url', 'query_template', 'query_template_json', 'post_query_template', 'http_request_headers', 'page_fetch_config_json', 'query_processors', 'query_mappings', 'result_processors', 'response_mappings', 'result_mappings', 'results_per_query', 'tags']
64+
fields = [
65+
"id",
66+
"name",
67+
"owner",
68+
"shared",
69+
"date_created",
70+
"date_updated",
71+
"active",
72+
"default",
73+
"authenticator",
74+
"connector",
75+
"url",
76+
"query_template",
77+
"query_template_json",
78+
"post_query_template",
79+
"http_request_headers",
80+
"page_fetch_config_json",
81+
"query_processors",
82+
"query_mappings",
83+
"result_processors",
84+
"response_mappings",
85+
"result_mappings",
86+
"results_per_query",
87+
"tags",
88+
]
89+
3190

3291
class SearchSerializer(serializers.ModelSerializer):
33-
owner = serializers.ReadOnlyField(source='owner.username')
92+
owner = serializers.ReadOnlyField(source="owner.username")
93+
3494
class Meta:
3595
model = Search
36-
fields = ['id', 'owner', 'date_created', 'date_updated', 'query_string', 'query_string_processed', 'sort', 'results_requested', 'searchprovider_list', 'subscribe', 'status', 'pre_query_processors', 'post_result_processors', 'result_url', 'new_result_url', 'messages', 'result_mixer', 'retention', 'tags']
96+
fields = [
97+
"id",
98+
"owner",
99+
"date_created",
100+
"date_updated",
101+
"query_string",
102+
"query_string_processed",
103+
"sort",
104+
"results_requested",
105+
"searchprovider_list",
106+
"subscribe",
107+
"status",
108+
"pre_query_processors",
109+
"post_result_processors",
110+
"result_url",
111+
"new_result_url",
112+
"messages",
113+
"result_mixer",
114+
"retention",
115+
"tags",
116+
]
117+
37118

38119
class ResultSerializer(serializers.ModelSerializer):
39-
owner = serializers.ReadOnlyField(source='owner.username')
120+
owner = serializers.ReadOnlyField(source="owner.username")
121+
40122
class Meta:
41123
model = Result
42-
fields = ['id', 'owner', 'date_created', 'date_updated', 'search_id', 'searchprovider', 'query_to_provider', 'query_processors', 'result_processors', 'result_processor_json_feedback', 'messages', 'status', 'retrieved', 'found', 'time', 'json_results', 'tags']
124+
fields = [
125+
"id",
126+
"owner",
127+
"date_created",
128+
"date_updated",
129+
"search_id",
130+
"searchprovider",
131+
"query_to_provider",
132+
"query_processors",
133+
"result_processors",
134+
"result_processor_json_feedback",
135+
"messages",
136+
"status",
137+
"retrieved",
138+
"found",
139+
"time",
140+
"json_results",
141+
"tags",
142+
]
143+
43144

44145
class QueryTransformSerializer(serializers.ModelSerializer):
45-
owner = serializers.ReadOnlyField(source='owner.username')
146+
owner = serializers.ReadOnlyField(source="owner.username")
147+
46148
class Meta:
47149
model = QueryTransform
48-
fields = ['id', 'name','owner','shared','date_created','date_updated','qrx_type','config_content']
150+
fields = [
151+
"id",
152+
"name",
153+
"owner",
154+
"shared",
155+
"date_created",
156+
"date_updated",
157+
"qrx_type",
158+
"config_content",
159+
]
160+
49161

50162
class QueryTransformNoCredentialsSerializer(serializers.ModelSerializer):
51-
owner = serializers.ReadOnlyField(source='owner.username')
163+
owner = serializers.ReadOnlyField(source="owner.username")
164+
52165
class Meta:
53166
model = QueryTransform
54-
fields = ['id', 'name','owner','shared', 'date_created','date_updated','qrx_type','config_content']
167+
fields = [
168+
"id",
169+
"name",
170+
"owner",
171+
"shared",
172+
"date_created",
173+
"date_updated",
174+
"qrx_type",
175+
"config_content",
176+
]
177+
178+
179+
class DetailSearchRagSerializer(serializers.Serializer):
180+
message = serializers.CharField(required=True, allow_blank=True)
181+
182+
class Meta:
183+
fields = ["message"]
184+
55185

56-
###
57186
# Minimal Serializers for drf-spectacular OpenAPI documentation only
58187
class LoginRequestSerializer(serializers.Serializer):
59188
username = serializers.CharField()
60189
password = serializers.CharField()
61190

191+
62192
class AuthResponseSerializer(serializers.Serializer):
63193
token = serializers.CharField()
64194
user = serializers.CharField()
65195

196+
66197
class StatusResponseSerializer(serializers.Serializer):
67198
status = serializers.CharField()

swirl/urls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
path('querytransforms/delete/<int:pk>/', views.QueryTransformViewSet.as_view({'delete': 'destroy'}), name='delete'),
4343

4444
path('search/search', views.SearchViewSet.as_view({'get': 'list'}), name='search'),
45+
path('sapi/detail-search-rag/', views.DetailSearchRagView.as_view(), name='detail-search-rag'),
4546

4647
path('', views.index, name='index'),
4748
path('index.html', views.index, name='index'),

swirl/views.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,10 @@
3939
from swirl.models import *
4040
from swirl.serializers import *
4141
from swirl.models import SearchProvider, Search, Result, QueryTransform, Authenticator as AuthenticatorModel, OauthToken
42-
from swirl.serializers import UserSerializer, GroupSerializer, SearchProviderSerializer, SearchSerializer, ResultSerializer, QueryTransformSerializer, QueryTransformNoCredentialsSerializer, LoginRequestSerializer, StatusResponseSerializer, AuthResponseSerializer
42+
from swirl.serializers import UserSerializer, DetailSearchRagSerializer, GroupSerializer, SearchProviderSerializer, SearchSerializer, ResultSerializer, QueryTransformSerializer, QueryTransformNoCredentialsSerializer, LoginRequestSerializer, StatusResponseSerializer, AuthResponseSerializer
4343
from swirl.authenticators.authenticator import Authenticator
4444
from swirl.authenticators import *
45+
from swirl.views_helpers.search_rag import SearchRag
4546

4647
module_name = 'views.py'
4748

@@ -642,6 +643,23 @@ def partial_update(self, request, pk=None):
642643
########################################
643644
########################################
644645

646+
class DetailSearchRagView(APIView):
647+
serializer_class = DetailSearchRagSerializer
648+
authentication_classes = [SessionAuthentication, BasicAuthentication]
649+
permission_classes = []
650+
651+
def get(self, request):
652+
logger.debug(f"DetailSearchRagView: {request.GET}")
653+
search_rag = SearchRag(request)
654+
result = search_rag.process_rag()
655+
serializer = DetailSearchRagSerializer(result)
656+
return Response(serializer.data, status=status.HTTP_200_OK)
657+
658+
659+
########################################
660+
########################################
661+
662+
645663
class ResultViewSet(viewsets.ModelViewSet):
646664
"""
647665
API endpoint for managing Result objects, including Mixed Results

swirl/views_helpers/search_rag.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import logging
2+
3+
from rest_framework.request import Request
4+
5+
from swirl.exceptions import RagError
6+
from swirl.models import Result
7+
from swirl.processors import *
8+
9+
logger = logging.getLogger(__name__)
10+
11+
instances = {}
12+
13+
14+
class SearchRag:
15+
16+
def __init__(self, request_data: Request) -> None:
17+
request_data = request_data.GET.dict()
18+
logger.info(f"{self}: init search rag {request_data}")
19+
self.search_id = request_data.get("search_id", None)
20+
21+
# Parse query parameters
22+
self.rag_query_items = []
23+
rag_query_items = request_data.get("rag_items", [""])[0]
24+
25+
if rag_query_items:
26+
self.rag_query_items = rag_query_items.split(",")
27+
28+
def get_rag_result(self) -> tuple[str, dict[str, str]]:
29+
isRagItemsUpdated = False
30+
try:
31+
rag_result = Result.objects.get(
32+
search_id=self.search_id, searchprovider="ChatGPT"
33+
)
34+
isRagItemsUpdated = True
35+
isRagItemsUpdated = not (
36+
set(rag_result.json_results[0]["rag_query_items"])
37+
== set(self.rag_query_items)
38+
)
39+
except:
40+
pass
41+
try:
42+
rag_result = Result.objects.get(
43+
search_id=self.search_id, searchprovider="ChatGPT"
44+
)
45+
isRagItemsUpdated = not (
46+
set(rag_result.json_results[0]["rag_query_items"])
47+
== set(self.rag_query_items)
48+
)
49+
if rag_result and not isRagItemsUpdated:
50+
if rag_result.json_results[0]["body"][0]:
51+
return rag_result.json_results[0]["body"][0]
52+
return False
53+
except:
54+
pass
55+
rag_processor = RAGPostResultProcessor(
56+
search_id=self.search_id,
57+
request_id="",
58+
should_get_results=True,
59+
rag_query_items=self.rag_query_items,
60+
)
61+
instances[self.search_id] = rag_processor
62+
if rag_processor.validate():
63+
result = rag_processor.process(should_return=True)
64+
try:
65+
if self.search_id in instances:
66+
del instances[self.search_id]
67+
return result.json_results[0]["body"][0]
68+
except:
69+
if self.search_id in instances:
70+
del instances[self.search_id]
71+
return False
72+
73+
def process_rag(self) -> dict[str, str]:
74+
result = ""
75+
try:
76+
result = self.get_rag_result()
77+
except RagError as err:
78+
logger.error(f"{self}: Rag Error {err}")
79+
80+
return {"message": result}

0 commit comments

Comments
 (0)