From 3b70847d7b07f6a45911aa6a35bc6e0e802673f4 Mon Sep 17 00:00:00 2001 From: chirag gupta <103719146+chiruu12@users.noreply.github.com> Date: Tue, 18 Mar 2025 23:20:22 +0530 Subject: [PATCH] Adding Groq Backend support --- src/rago/generation/__init__.py | 2 + src/rago/generation/groq.py | 82 +++++++++++++++++++++++++++++++++ tests/.env.tpl | 1 + tests/conftest.py | 11 +++++ tests/test_generation.py | 8 ++++ 5 files changed, 104 insertions(+) create mode 100644 src/rago/generation/groq.py diff --git a/src/rago/generation/__init__.py b/src/rago/generation/__init__.py index bc10725..bee32b4 100644 --- a/src/rago/generation/__init__.py +++ b/src/rago/generation/__init__.py @@ -7,6 +7,7 @@ from rago.generation.deepseek import DeepSeekGen from rago.generation.fireworks import FireworksGen from rago.generation.gemini import GeminiGen +from rago.generation.groq import GroqGen from rago.generation.hugging_face import HuggingFaceGen from rago.generation.hugging_face_inf import HuggingFaceInfGen from rago.generation.llama import LlamaGen @@ -19,6 +20,7 @@ 'FireworksGen', 'GeminiGen', 'GenerationBase', + 'GroqGen', 'HuggingFaceGen', 'HuggingFaceInfGen', 'LlamaGen', diff --git a/src/rago/generation/groq.py b/src/rago/generation/groq.py new file mode 100644 index 0000000..4839a49 --- /dev/null +++ b/src/rago/generation/groq.py @@ -0,0 +1,82 @@ +"""Groq class for text generation.""" + +from __future__ import annotations + +from typing import cast + +import instructor +import openai + +from pydantic import BaseModel +from typeguard import typechecked + +from rago.generation.base import GenerationBase + + +@typechecked +class GroqGen(GenerationBase): + """Groq generation model for text generation.""" + + default_model_name = 'gemma2-9b-it' + default_api_params = { # noqa: RUF012 + 'top_p': 1.0, + } + + def _setup(self) -> None: + """Set up the Groq client.""" + groq_api_key = self.api_key + if not groq_api_key: + raise Exception('GROQ_API_KEY environment variable is not set') + + # Can use Groq client as well. + groq_client = openai.OpenAI( + base_url='https://api.groq.com/openai/v1', api_key=groq_api_key + ) + + # Optionally use instructor if structured output is needed + self.model = ( + instructor.from_openai(groq_client) + if self.structured_output + else groq_client + ) + + def generate( + self, + query: str, + context: list[str], + ) -> str | BaseModel: + """Generate text using the Groq AP.""" + input_text = self.prompt_template.format( + query=query, context=' '.join(context) + ) + + if not self.model: + raise Exception('The model was not created.') + + api_params = ( + self.api_params if self.api_params else self.default_api_params + ) + + messages = [] + if self.system_message: + messages.append({'role': 'system', 'content': self.system_message}) + messages.append({'role': 'user', 'content': input_text}) + + model_params = dict( + model=self.model_name, + messages=messages, + max_completion_tokens=self.output_max_length, + temperature=self.temperature, + **api_params, + ) + + if self.structured_output: + model_params['response_model'] = self.structured_output + + response = self.model.chat.completions.create(**model_params) + self.logs['model_params'] = model_params + + if hasattr(response, 'choices') and isinstance(response.choices, list): + return cast(str, response.choices[0].message.content.strip()) + + return cast(BaseModel, response) diff --git a/tests/.env.tpl b/tests/.env.tpl index 9abc40b..b4e03f9 100644 --- a/tests/.env.tpl +++ b/tests/.env.tpl @@ -5,3 +5,4 @@ TOKENIZERS_PARALLELISM=false COHERE_API_KEY=${COHERE_API_KEY} FIREWORKS_API_KEY=${FIREWORKS_API_KEY} TOGETHER_API_KEY=${TOGETHER_API_KEY} +GROQ_API_KEY=${GROQ_API_KEY} diff --git a/tests/conftest.py b/tests/conftest.py index 9b249d3..5820821 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,3 +94,14 @@ def api_key_together(env) -> str: 'Please set the TOGETHER_API_KEY environment variable.' ) return key + + +@pytest.fixture +def api_key_groq(env) -> str: + """Fixture for GROQ API key from environment.""" + key = os.getenv('GROQ_API_KEY') + if not key: + raise EnvironmentError( + 'Please set the GROQ_API_KEY environment variable.' + ) + return key diff --git a/tests/test_generation.py b/tests/test_generation.py index c03ea70..171afc8 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -10,6 +10,7 @@ DeepSeekGen, FireworksGen, GeminiGen, + GroqGen, HuggingFaceGen, HuggingFaceInfGen, LlamaGen, @@ -32,6 +33,7 @@ CohereGen: 'api_key_cohere', FireworksGen: 'api_key_fireworks', TogetherGen: 'api_key_together', + GroqGen: 'api_key_groq', } gen_models = [ @@ -85,6 +87,10 @@ partial( HuggingFaceInfGen, ), + # model 9 + partial( + GroqGen, + ), ] @@ -98,6 +104,7 @@ def test_generation_simple_output( api_key_gemini: str, api_key_together: str, api_key_hugging_face: str, + api_key_groq: str, partial_model: partial, ) -> None: """Test RAG pipeline with model generation.""" @@ -150,6 +157,7 @@ def test_generation_structure_output( api_key_gemini: str, api_key_together: str, api_key_hugging_face: str, + api_key_groq: str, animals_data: list[str], question: str, partial_model: partial,