Skip to content

Commit db852ae

Browse files
authored
feat: Add support for fireworks-ai (osl-incubator#74)
1 parent a19cb93 commit db852ae

File tree

10 files changed

+238
-8
lines changed

10 files changed

+238
-8
lines changed

poetry.lock

Lines changed: 59 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ dependencies = [
3838
"torch >=2,<2.6",
3939
"eval-type-backport >=0.2 ; python_version < '3.10'",
4040
"joblib (>=1.4.2,<2.0.0)",
41-
"cohere >=5.13.4"
41+
"cohere >=5.13.4",
42+
"fireworks-ai>0.15.10"
4243
]
4344

4445
[build-system]

src/rago/augmented/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44

55
from rago.augmented.base import AugmentedBase
66
from rago.augmented.cohere import CohereAug
7+
from rago.augmented.fireworks import FireworksAug
78
from rago.augmented.openai import OpenAIAug
89
from rago.augmented.sentence_transformer import SentenceTransformerAug
910
from rago.augmented.spacy import SpaCyAug
1011

1112
__all__ = [
1213
'AugmentedBase',
1314
'CohereAug',
15+
'FireworksAug',
1416
'OpenAIAug',
1517
'SentenceTransformerAug',
1618
'SpaCyAug',

src/rago/augmented/fireworks.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Classes for augmentation with Fireworks embeddings."""
2+
3+
from __future__ import annotations
4+
5+
from hashlib import sha256
6+
from typing import cast
7+
8+
import numpy as np
9+
import openai # fireworks client doesn't have query
10+
11+
12+
# embeddings model feature yet
13+
from typeguard import typechecked
14+
15+
from rago.augmented.base import AugmentedBase, EmbeddingType
16+
17+
18+
@typechecked
class FireworksAug(AugmentedBase):
    """Augmentation step that embeds text with Fireworks-hosted models.

    Embeddings are requested through an ``openai.OpenAI`` client whose
    ``base_url`` points at the Fireworks inference endpoint.
    """

    default_model_name = 'nomic-ai/nomic-embed-text-v1.5'  # embedding model
    default_top_k = 3

    def _setup(self) -> None:
        """Create the OpenAI-compatible client used for embedding calls.

        Raises
        ------
        ValueError
            If ``self.api_key`` is empty or unset.
        """
        if not self.api_key:
            raise ValueError('API key for Fireworks is required.')
        self.openai_client = openai.OpenAI(
            base_url='https://api.fireworks.ai/inference/v1',
            api_key=self.api_key,
        )

    def get_embedding(self, content: list[str]) -> EmbeddingType:
        """Return the embeddings for ``content``, using the cache if possible.

        Parameters
        ----------
        content : list[str]
            Texts to embed; one embedding row is produced per item.

        Returns
        -------
        EmbeddingType
            The cached value when present, otherwise a ``float32`` array
            built from the embeddings API response.
        """
        # Hash every item together with its byte length. A plain
        # ``''.join(content)`` would give ['ab', 'c'] and ['a', 'bc'] the
        # same key and hand back embeddings for the wrong document list.
        digest = sha256()
        for text in content:
            encoded = text.encode('utf-8')
            digest.update(f'{len(encoded)}:'.encode('utf-8'))
            digest.update(encoded)
        cache_key = digest.hexdigest()

        cached = self._get_cache(cache_key)
        if cached is not None:
            return cast(EmbeddingType, cached)

        # Using the OpenAI embeddings API call for fireworks
        response = self.openai_client.embeddings.create(
            model=self.model_name,
            input=content,
        )
        result = np.array(
            [data.embedding for data in response.data], dtype=np.float32
        )
        self._save_cache(cache_key, result)
        return result

    def search(
        self, query: str, documents: list[str], top_k: int = 0
    ) -> list[str]:
        """Return the documents that best match ``query``.

        Parameters
        ----------
        query : str
            Text to look up.
        documents : list[str]
            Candidate documents; they are embedded and loaded into
            ``self.db`` on every call.
        top_k : int, default 0
            Number of results requested. A falsy value falls back to
            ``self.top_k``, then ``self.default_top_k``, then ``1``.

        Returns
        -------
        list[str]
            Documents selected by the indices that ``self.db.search``
            returns.

        Raises
        ------
        Exception
            If ``self.db`` is missing or falsy.
        """
        if not hasattr(self, 'db') or not self.db:
            raise Exception('Vector database (db) is not initialized.')

        document_encoded = self.get_embedding(documents)
        query_encoded = self.get_embedding([query])
        top_k = top_k or self.top_k or self.default_top_k or 1

        self.db.embed(document_encoded)
        scores, indices = self.db.search(query_encoded, top_k=top_k)

        # Keep the raw search output available for later inspection.
        self.logs['indices'] = indices
        self.logs['scores'] = scores
        self.logs['search_params'] = {
            'query_encoded': query_encoded,
            'top_k': top_k,
        }

        # Negative indices are skipped; presumably the backend uses them
        # to mark "no result" slots -- TODO confirm against the db class.
        retrieved_docs = [documents[i] for i in indices if i >= 0]

        return retrieved_docs

src/rago/generation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from rago.generation.base import GenerationBase
66
from rago.generation.cohere import CohereGen
77
from rago.generation.deepseek import DeepSeekGen
8+
from rago.generation.fireworks import FireworksGen
89
from rago.generation.gemini import GeminiGen
910
from rago.generation.hugging_face import HuggingFaceGen
1011
from rago.generation.llama import LlamaGen
@@ -13,6 +14,7 @@
1314
__all__ = [
1415
'CohereGen',
1516
'DeepSeekGen',
17+
'FireworksGen',
1618
'GeminiGen',
1719
'GenerationBase',
1820
'HuggingFaceGen',

src/rago/generation/fireworks.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""FireworksGen class for text generation using Fireworks API."""
2+
3+
from __future__ import annotations
4+
5+
from typing import cast
6+
7+
import instructor
8+
9+
from fireworks.client import Fireworks
10+
from pydantic import BaseModel
11+
from typeguard import typechecked
12+
13+
from rago.generation.base import GenerationBase
14+
15+
16+
@typechecked
class FireworksGen(GenerationBase):
    """Text generation step backed by the Fireworks AI chat API."""

    default_model_name: str = 'accounts/fireworks/models/llama-v3-8b-instruct'
    default_api_params = {  # noqa: RUF012
        'top_p': 0.9,
    }

    def _setup(self) -> None:
        """Create the Fireworks client used by ``generate``.

        When ``self.structured_output`` is set, the client is wrapped with
        ``instructor.from_fireworks`` in ``FIREWORKS_JSON`` mode; otherwise
        the plain client is stored as ``self.model``.
        """
        client = Fireworks(api_key=self.api_key)

        self.model = (
            instructor.from_fireworks(
                client=client,
                mode=instructor.Mode.FIREWORKS_JSON,
            )
            if self.structured_output
            else client
        )

    def generate(self, query: str, context: list[str]) -> str | BaseModel:
        """Generate an answer for ``query`` given ``context``.

        Parameters
        ----------
        query : str
            The user question inserted into ``self.prompt_template``.
        context : list[str]
            Context passages; they are joined with single spaces before
            being inserted into the prompt template.

        Returns
        -------
        str | BaseModel
            The raw response object when ``self.structured_output`` is set,
            otherwise the stripped text of the first choice (an empty
            string when the API returns no content).
        """
        input_text = self.prompt_template.format(
            query=query, context=' '.join(context)
        )

        api_params = self.api_params or self.default_api_params

        messages = []
        if self.system_message:
            messages.append({'role': 'system', 'content': self.system_message})
        messages.append({'role': 'user', 'content': input_text})

        # ``api_params`` is unpacked last so that it can override the
        # defaults derived from the instance attributes.
        model_params = {
            'model': self.model_name,
            'messages': messages,
            'max_tokens': self.output_max_length,
            'temperature': self.temperature,
            **api_params,
        }

        if self.structured_output:
            model_params['response_model'] = self.structured_output

        response = self.model.chat.completions.create(**model_params)
        self.logs['model_params'] = model_params

        if self.structured_output:
            return cast(BaseModel, response)

        # ``content`` may be ``None``; calling ``.strip()`` on it directly
        # would raise ``AttributeError``.
        content = response.choices[0].message.content
        return cast(str, (content or '').strip())

tests/.env.tpl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ OPENAI_API_KEY=${OPENAI_API_KEY}
33
GEMINI_API_KEY=${GEMINI_API_KEY}
44
TOKENIZERS_PARALLELISM=false
55
COHERE_API_KEY=${COHERE_API_KEY}
6-
#FIREWORKS_API_KEY=${FIREWORKS_API_KEY}
6+
FIREWORKS_API_KEY=${FIREWORKS_API_KEY}
77
#TOGETHER_API_KEY=${TOGETHER_API_KEY}

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,14 @@ def api_key_cohere(env) -> str:
7474
'Please set the COHERE_API_KEY environment variable.'
7575
)
7676
return key
77+
78+
79+
@pytest.fixture
def api_key_fireworks(env) -> str:
    """Provide the Fireworks API key read from ``FIREWORKS_API_KEY``.

    Raises ``EnvironmentError`` when the variable is unset or empty. The
    ``env`` argument is requested only as a fixture dependency (presumably
    it loads the test environment file -- confirm in ``conftest.py``).
    """
    api_key = os.getenv('FIREWORKS_API_KEY')
    if api_key:
        return api_key
    raise EnvironmentError(
        'Please set the FIREWORKS_API_KEY environment variable.'
    )

tests/test_augmentation.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44

55
import pytest
66

7-
from rago.augmented import CohereAug, OpenAIAug, SpaCyAug
7+
from rago.augmented import CohereAug, FireworksAug, OpenAIAug, SpaCyAug
88

99
API_MAP = {
1010
OpenAIAug: 'api_key_openai',
1111
CohereAug: 'api_key_cohere',
12+
FireworksAug: 'api_key_fireworks',
1213
}
1314

1415
gen_models = [
@@ -26,9 +27,14 @@
2627
model_name='text-embedding-3-small',
2728
),
2829
),
30+
# model 2
2931
partial(
3032
CohereAug,
3133
),
34+
# model 3
35+
partial(
36+
FireworksAug,
37+
),
3238
]
3339

3440

@@ -52,6 +58,7 @@ def test_aug_spacy(
5258
api_key_openai: str,
5359
api_key_cohere: str,
5460
api_key_gemini: str,
61+
api_key_fireworks: str,
5562
api_key_hugging_face: str,
5663
partial_model: partial,
5764
) -> None:

0 commit comments

Comments
 (0)