Skip to content

feat(generation): add backend support for Groq #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/rago/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from rago.generation.deepseek import DeepSeekGen
from rago.generation.fireworks import FireworksGen
from rago.generation.gemini import GeminiGen
from rago.generation.groq import GroqGen
from rago.generation.hugging_face import HuggingFaceGen
from rago.generation.hugging_face_inf import HuggingFaceInfGen
from rago.generation.llama import LlamaGen
Expand All @@ -19,6 +20,7 @@
'FireworksGen',
'GeminiGen',
'GenerationBase',
'GroqGen',
'HuggingFaceGen',
'HuggingFaceInfGen',
'LlamaGen',
Expand Down
82 changes: 82 additions & 0 deletions src/rago/generation/groq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Groq class for text generation."""

from __future__ import annotations

from typing import cast

import instructor
import openai

from pydantic import BaseModel
from typeguard import typechecked

from rago.generation.base import GenerationBase


@typechecked
class GroqGen(GenerationBase):
    """Groq generation model for text generation.

    The Groq service is reached through its OpenAI-compatible endpoint,
    so the regular ``openai`` client is used with a custom ``base_url``.
    When ``structured_output`` is set, the client is wrapped with
    ``instructor`` so the response is parsed into that model.
    """

    default_model_name = 'gemma2-9b-it'
    default_api_params = {  # noqa: RUF012
        'top_p': 1.0,
    }

    def _setup(self) -> None:
        """Set up the Groq client.

        Raises
        ------
        Exception
            If no API key is available on the instance.
        """
        groq_api_key = self.api_key
        if not groq_api_key:
            raise Exception('GROQ_API_KEY environment variable is not set')

        # Can use Groq client as well.
        groq_client = openai.OpenAI(
            base_url='https://api.groq.com/openai/v1', api_key=groq_api_key
        )

        # Optionally use instructor if structured output is needed
        self.model = (
            instructor.from_openai(groq_client)
            if self.structured_output
            else groq_client
        )

    def generate(
        self,
        query: str,
        context: list[str],
    ) -> str | BaseModel:
        """Generate text using the Groq API.

        Parameters
        ----------
        query : str
            The user question; substituted into ``prompt_template``.
        context : list[str]
            Context passages; joined with a single space and substituted
            into ``prompt_template``.

        Returns
        -------
        str | BaseModel
            The stripped text of the first choice for plain completions
            (an empty string when the message carries no text content),
            otherwise the response object returned by the client (the
            structured-output path).

        Raises
        ------
        Exception
            If the client has not been created.
        """
        input_text = self.prompt_template.format(
            query=query, context=' '.join(context)
        )

        if not self.model:
            raise Exception('The model was not created.')

        # Fall back to the class defaults only when no parameters were given.
        api_params = (
            self.api_params if self.api_params else self.default_api_params
        )

        messages = []
        if self.system_message:
            messages.append({'role': 'system', 'content': self.system_message})
        messages.append({'role': 'user', 'content': input_text})

        model_params = dict(
            model=self.model_name,
            messages=messages,
            max_completion_tokens=self.output_max_length,
            temperature=self.temperature,
            **api_params,
        )

        if self.structured_output:
            model_params['response_model'] = self.structured_output

        # Record the parameters before the request so they are still
        # available in the logs when the API call itself fails.
        self.logs['model_params'] = model_params
        response = self.model.chat.completions.create(**model_params)

        if hasattr(response, 'choices') and isinstance(response.choices, list):
            # ``content`` is optional in the chat-completions schema
            # (e.g. tool-call responses); treat a missing value as empty
            # text instead of failing on ``None.strip()``.
            content = response.choices[0].message.content or ''
            return cast(str, content.strip())

        return cast(BaseModel, response)
1 change: 1 addition & 0 deletions tests/.env.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ TOKENIZERS_PARALLELISM=false
COHERE_API_KEY=${COHERE_API_KEY}
FIREWORKS_API_KEY=${FIREWORKS_API_KEY}
TOGETHER_API_KEY=${TOGETHER_API_KEY}
GROQ_API_KEY=${GROQ_API_KEY}
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,14 @@ def api_key_together(env) -> str:
'Please set the TOGETHER_API_KEY environment variable.'
)
return key


@pytest.fixture
def api_key_groq(env) -> str:
    """Provide the Groq API key taken from the ``GROQ_API_KEY`` variable.

    Raises ``EnvironmentError`` when the variable is unset or empty.
    """
    groq_key = os.environ.get('GROQ_API_KEY')
    if groq_key:
        return groq_key
    raise EnvironmentError(
        'Please set the GROQ_API_KEY environment variable.'
    )
8 changes: 8 additions & 0 deletions tests/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DeepSeekGen,
FireworksGen,
GeminiGen,
GroqGen,
HuggingFaceGen,
HuggingFaceInfGen,
LlamaGen,
Expand All @@ -32,6 +33,7 @@
CohereGen: 'api_key_cohere',
FireworksGen: 'api_key_fireworks',
TogetherGen: 'api_key_together',
GroqGen: 'api_key_groq',
}

gen_models = [
Expand Down Expand Up @@ -85,6 +87,10 @@
partial(
HuggingFaceInfGen,
),
# model 9
partial(
GroqGen,
),
]


Expand All @@ -98,6 +104,7 @@ def test_generation_simple_output(
api_key_gemini: str,
api_key_together: str,
api_key_hugging_face: str,
api_key_groq: str,
partial_model: partial,
) -> None:
"""Test RAG pipeline with model generation."""
Expand Down Expand Up @@ -150,6 +157,7 @@ def test_generation_structure_output(
api_key_gemini: str,
api_key_together: str,
api_key_hugging_face: str,
api_key_groq: str,
animals_data: list[str],
question: str,
partial_model: partial,
Expand Down
Loading