diff --git a/swirl/exceptions.py b/swirl/exceptions.py new file mode 100644 index 000000000..ba724dcc8 --- /dev/null +++ b/swirl/exceptions.py @@ -0,0 +1,5 @@ +class RagError(Exception): + """Exception raised for errors related to RAG.""" + def __init__(self, message="Error with RAG"): + self.message = message + super().__init__(self.message) diff --git a/swirl/serializers.py b/swirl/serializers.py index df13204ff..01ff3f3bf 100644 --- a/swirl/serializers.py +++ b/swirl/serializers.py @@ -1,67 +1,198 @@ -''' +""" @author: Sid Probstein @contact: sid@swirl.today -''' +""" -from django.contrib.auth.models import User, Group +from django.contrib.auth.models import Group, User from rest_framework import serializers -from swirl.models import SearchProvider, Search, Result,QueryTransform + +from swirl.models import QueryTransform, Result, Search, SearchProvider + class UserSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = User - fields = ['url', 'username', 'email', 'groups'] + fields = ["url", "username", "email", "groups"] + class GroupSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = Group - fields = ['url', 'name'] + fields = ["url", "name"] + class SearchProviderSerializer(serializers.ModelSerializer): - owner = serializers.ReadOnlyField(source='owner.username') + owner = serializers.ReadOnlyField(source="owner.username") + class Meta: model = SearchProvider - fields = ['id', 'name', 'owner', 'shared', 'date_created', 'date_updated', 'active', 'default', 'authenticator','connector', 'url', 'query_template', 'query_template_json', 'post_query_template', 'http_request_headers', 'page_fetch_config_json', 'query_processors', 'query_mappings', 'result_grouping_field', 'result_processors', 'response_mappings', 'result_mappings', 'results_per_query', 'credentials', 'eval_credentials', 'tags'] + fields = [ + "id", + "name", + "owner", + "shared", + "date_created", + "date_updated", + "active", + "default", + "authenticator", + "connector", + "url", + "query_template", + "query_template_json", + "post_query_template", + "http_request_headers", + "page_fetch_config_json", + "query_processors", + "query_mappings", + "result_grouping_field", + "result_processors", + "response_mappings", + "result_mappings", + "results_per_query", + "credentials", + "eval_credentials", + "tags", + ] + class SearchProviderNoCredentialsSerializer(serializers.ModelSerializer): - owner = serializers.ReadOnlyField(source='owner.username') + owner = serializers.ReadOnlyField(source="owner.username") + class Meta: model = SearchProvider - fields = ['id', 'name', 'owner', 'shared', 'date_created', 'date_updated', 'active', 'default', 'authenticator', 'connector', 'url', 'query_template', 'query_template_json', 'post_query_template', 'http_request_headers', 'page_fetch_config_json', 'query_processors', 'query_mappings', 'result_processors', 'response_mappings', 'result_mappings', 'results_per_query', 'tags'] + fields = [ + "id", + "name", + "owner", + "shared", + "date_created", + "date_updated", + "active", + "default", + "authenticator", + "connector", + "url", + "query_template", + "query_template_json", + "post_query_template", + "http_request_headers", + "page_fetch_config_json", + "query_processors", + "query_mappings", + "result_processors", + "response_mappings", + "result_mappings", + "results_per_query", + "tags", + ] + class SearchSerializer(serializers.ModelSerializer): - owner = serializers.ReadOnlyField(source='owner.username') + owner = serializers.ReadOnlyField(source="owner.username") + class Meta: model = Search - fields = ['id', 'owner', 'date_created', 'date_updated', 'query_string', 'query_string_processed', 'sort', 'results_requested', 'searchprovider_list', 'subscribe', 'status', 'pre_query_processors', 'post_result_processors', 'result_url', 'new_result_url', 'messages', 'result_mixer', 'retention', 'tags'] + fields = [ + "id", + "owner", + "date_created", + "date_updated", + "query_string", + "query_string_processed", + "sort", + "results_requested", + "searchprovider_list", + "subscribe", + "status", + "pre_query_processors", + "post_result_processors", + "result_url", + "new_result_url", + "messages", + "result_mixer", + "retention", + "tags", + ] + class ResultSerializer(serializers.ModelSerializer): - owner = serializers.ReadOnlyField(source='owner.username') + owner = serializers.ReadOnlyField(source="owner.username") + class Meta: model = Result - fields = ['id', 'owner', 'date_created', 'date_updated', 'search_id', 'searchprovider', 'query_to_provider', 'query_processors', 'result_processors', 'result_processor_json_feedback', 'messages', 'status', 'retrieved', 'found', 'time', 'json_results', 'tags'] + fields = [ + "id", + "owner", + "date_created", + "date_updated", + "search_id", + "searchprovider", + "query_to_provider", + "query_processors", + "result_processors", + "result_processor_json_feedback", + "messages", + "status", + "retrieved", + "found", + "time", + "json_results", + "tags", + ] + class QueryTransformSerializer(serializers.ModelSerializer): - owner = serializers.ReadOnlyField(source='owner.username') + owner = serializers.ReadOnlyField(source="owner.username") + class Meta: model = QueryTransform - fields = ['id', 'name','owner','shared','date_created','date_updated','qrx_type','config_content'] + fields = [ + "id", + "name", + "owner", + "shared", + "date_created", + "date_updated", + "qrx_type", + "config_content", + ] + class QueryTransformNoCredentialsSerializer(serializers.ModelSerializer): - owner = serializers.ReadOnlyField(source='owner.username') + owner = serializers.ReadOnlyField(source="owner.username") + class Meta: model = QueryTransform - fields = ['id', 'name','owner','shared', 'date_created','date_updated','qrx_type','config_content'] + fields = [ + "id", + "name", + "owner", + "shared", + "date_created", + "date_updated", + "qrx_type", + "config_content", + ] + + +class DetailSearchRagSerializer(serializers.Serializer): + message = serializers.CharField(required=True, allow_blank=True) + + class Meta: + fields = ["message"] + -### # Minimal Serializers for drf-spectacular OpenAPI documentation only class LoginRequestSerializer(serializers.Serializer): username = serializers.CharField() password = serializers.CharField() + class AuthResponseSerializer(serializers.Serializer): token = serializers.CharField() user = serializers.CharField() + class StatusResponseSerializer(serializers.Serializer): status = serializers.CharField() diff --git a/swirl/urls.py b/swirl/urls.py index d518d6119..631983cde 100644 --- a/swirl/urls.py +++ b/swirl/urls.py @@ -42,6 +42,7 @@ path('querytransforms/delete//', views.QueryTransformViewSet.as_view({'delete': 'destroy'}), name='delete'), path('search/search', views.SearchViewSet.as_view({'get': 'list'}), name='search'), + path('sapi/detail-search-rag/', views.DetailSearchRagView.as_view(), name='detail-search-rag'), path('', views.index, name='index'), path('index.html', views.index, name='index'), diff --git a/swirl/views.py b/swirl/views.py index 34c34cd39..84fe1883b 100644 --- a/swirl/views.py +++ b/swirl/views.py @@ -39,9 +39,10 @@ from swirl.models import * from swirl.serializers import * from swirl.models import SearchProvider, Search, Result, QueryTransform, Authenticator as AuthenticatorModel, OauthToken -from swirl.serializers import UserSerializer, GroupSerializer, SearchProviderSerializer, SearchSerializer, ResultSerializer, QueryTransformSerializer, QueryTransformNoCredentialsSerializer, LoginRequestSerializer, StatusResponseSerializer, AuthResponseSerializer +from swirl.serializers import UserSerializer, DetailSearchRagSerializer, GroupSerializer, SearchProviderSerializer, SearchSerializer, ResultSerializer, QueryTransformSerializer, QueryTransformNoCredentialsSerializer, LoginRequestSerializer, StatusResponseSerializer, AuthResponseSerializer from swirl.authenticators.authenticator import Authenticator from swirl.authenticators import * +from swirl.views_helpers.search_rag import SearchRag module_name = 'views.py' @@ -642,6 +643,23 @@ def partial_update(self, request, pk=None): ######################################## ######################################## +class DetailSearchRagView(APIView): + serializer_class = DetailSearchRagSerializer + authentication_classes = [SessionAuthentication, BasicAuthentication] + permission_classes = [] + + def get(self, request): + logger.debug(f"DetailSearchRagView: {request.GET}") + search_rag = SearchRag(request) + result = search_rag.process_rag() + serializer = DetailSearchRagSerializer(result) + return Response(serializer.data, status=status.HTTP_200_OK) + + +######################################## +######################################## + + class ResultViewSet(viewsets.ModelViewSet): """ API endpoint for managing Result objects, including Mixed Results diff --git a/swirl/views_helpers/search_rag.py b/swirl/views_helpers/search_rag.py new file mode 100644 index 000000000..184fc06ba --- /dev/null +++ b/swirl/views_helpers/search_rag.py @@ -0,0 +1,80 @@ +import logging + +from rest_framework.request import Request + +from swirl.exceptions import RagError +from swirl.models import Result +from swirl.processors import * + +logger = logging.getLogger(__name__) + +instances = {} + + +class SearchRag: + + def __init__(self, request_data: Request) -> None: + request_data = request_data.GET.dict() + logger.info(f"{self}: init search rag {request_data}") + self.search_id = request_data.get("search_id", None) + + # Parse query parameters + self.rag_query_items = [] + rag_query_items = request_data.get("rag_items", [""])[0] + + if rag_query_items: + self.rag_query_items = rag_query_items.split(",") + + def get_rag_result(self) -> tuple[str, dict[str, str]]: + isRagItemsUpdated = False + try: + rag_result = Result.objects.get( + search_id=self.search_id, searchprovider="ChatGPT" + ) + isRagItemsUpdated = True + isRagItemsUpdated = not ( + set(rag_result.json_results[0]["rag_query_items"]) + == set(self.rag_query_items) + ) + except: + pass + try: + rag_result = Result.objects.get( + search_id=self.search_id, searchprovider="ChatGPT" + ) + isRagItemsUpdated = not ( + set(rag_result.json_results[0]["rag_query_items"]) + == set(self.rag_query_items) + ) + if rag_result and not isRagItemsUpdated: + if rag_result.json_results[0]["body"][0]: + return rag_result.json_results[0]["body"][0] + return False + except: + pass + rag_processor = RAGPostResultProcessor( + search_id=self.search_id, + request_id="", + should_get_results=True, + rag_query_items=self.rag_query_items, + ) + instances[self.search_id] = rag_processor + if rag_processor.validate(): + result = rag_processor.process(should_return=True) + try: + if self.search_id in instances: + del instances[self.search_id] + return result.json_results[0]["body"][0] + except: + if self.search_id in instances: + del instances[self.search_id] + return False + + def process_rag(self) -> dict[str, str]: + result = "" + try: + result = self.get_rag_result() + except RagError as err: + logger.error(f"{self}: Rag Error {err}") + + return {"message": result}