Skip to content

Commit 5f2e4eb

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents 909039d + db852ae commit 5f2e4eb

File tree

14 files changed

+1019
-544
lines changed

14 files changed

+1019
-544
lines changed

.github/workflows/main.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ jobs:
4040
# - "3.13"
4141
os:
4242
- "ubuntu"
43+
- "macos"
4344

4445
runs-on: ${{ matrix.os }}-latest
4546
timeout-minutes: 20

conda/dev.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ channels:
33
- nodefaults
44
- conda-forge
55
dependencies:
6-
- python <3.13
6+
- python 3.9.*
77
- pip
88
- poetry >=2
99
- nodejs # used by semantic-release

docs/changelog.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
11
# Release Notes
22
---
33

4+
# [0.13.0](https://github.com/osl-incubator/rago/compare/0.12.0...0.13.0) (2025-03-13)
5+
6+
7+
### Bug Fixes
8+
9+
* **pkg:** Add support for MacOS ([#75](https://github.com/osl-incubator/rago/issues/75)) ([e6a33b0](https://github.com/osl-incubator/rago/commit/e6a33b0d967d21130cf005789989dc9e1c28c4fb))
10+
11+
12+
### Features
13+
14+
* Add cohere backend support ([#62](https://github.com/osl-incubator/rago/issues/62)) ([6817ba0](https://github.com/osl-incubator/rago/commit/6817ba08607e5366b4b36dba7b91644eced7edb7))
15+
* **generation:** add backend for DeepSeek's generation class ([#49](https://github.com/osl-incubator/rago/issues/49)) ([47947d6](https://github.com/osl-incubator/rago/commit/47947d65105db88c7a021ae7da8b48cff8ce58d1))
16+
417
# [0.12.0](https://github.com/osl-incubator/rago/compare/0.11.3...0.12.0) (2025-02-11)
518

619

poetry.lock

Lines changed: 804 additions & 517 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rago"
3-
version = "0.12.0" # semantic-release
3+
version = "0.13.0" # semantic-release
44
description = "Rago is a lightweight framework for RAG"
55
readme = "README.md"
66
authors = [
@@ -19,6 +19,8 @@ requires-python = ">=3.9,<3.13"
1919

2020
dependencies = [
2121
"typeguard >=4.0",
22+
"numpy >=2,<=2.2 ; python_version > '3.9'",
23+
"numpy >=2,<2.1 ; python_version < '3.10'",
2224
"faiss-cpu >=1.9.0",
2325
"sentencepiece >=0.2.0",
2426
"sentence-transformers >=3.2.0",
@@ -29,39 +31,37 @@ dependencies = [
2931
"pypdf >=5",
3032
"langchain >=0.3.7",
3133
"langchain-community >=0.3.7",
32-
"spacy >=3",
34+
"spacy >=3.8.0 ; python_version > '3.9'",
35+
"spacy >=3.8.0,<3.8.4 ; python_version < '3.10'",
3336
"instructor >=1",
3437
"pydantic >=2",
35-
"torch >=2.5",
36-
"torchvision >=0.20",
38+
"torch >=2,<2.6",
3739
"eval-type-backport >=0.2 ; python_version < '3.10'",
3840
"joblib (>=1.4.2,<2.0.0)",
39-
"cohere >=5.13.4"
41+
"cohere >=5.13.4",
42+
"fireworks-ai>0.15.10"
4043
]
4144

4245
[build-system]
4346
requires = ["poetry-core>=2", "poetry>=2"]
4447
build-backend = "poetry.core.masonry.api"
4548

46-
[tool.poetry.extras]
47-
cpu = ["torch", "torchvision"]
48-
gpu = ["torch", "torchvision"]
49+
# [project.optional-dependencies]
50+
# cpu = ["torch"]
51+
# gpu = ["torch"]
4952

50-
[[tool.poetry.source]]
51-
name = "pytorch-cpu"
52-
url = "https://download.pytorch.org/whl/cpu"
53-
priority = "supplemental"
53+
# [[tool.poetry.source]]
54+
# name = "pytorch-cpu"
55+
# url = "https://download.pytorch.org/whl/cpu"
56+
# priority = "explicit"
5457

55-
[tool.poetry.dependencies]
56-
python = ">=3.9,<3.13"
57-
torch = [
58-
{version = ">=2.5.0", markers="extra!='gpu'", source="pytorch-cpu"},
59-
{version = ">=2.5.0", markers="extra=='gpu'"},
60-
]
61-
torchvision = [
62-
{version = ">=0.20.0", markers="extra!='gpu'", source="pytorch-cpu"},
63-
{version = ">=0.20.0", markers="extra=='gpu'"},
64-
]
58+
# [tool.poetry.dependencies]
59+
# python = ">=3.9,<3.13"
60+
# torch = [
61+
# { version = ">=2,<2.6", platform = "darwin" },
62+
# { version = ">=2,<2.6", platform = "linux", source="pytorch-cpu", markers="extra!='gpu'" },
63+
# { version = ">=2,<2.6", platform = "linux", markers="extra=='gpu'" },
64+
# ]
6565

6666
[tool.poetry.group.dev.dependencies]
6767
pytest = ">=7.3.2"

src/rago/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def get_version() -> str:
1010
try:
1111
return importlib_metadata.version(__name__)
1212
except importlib_metadata.PackageNotFoundError: # pragma: no cover
13-
return '0.12.0' # semantic-release
13+
return '0.13.0' # semantic-release
1414

1515

1616
version = get_version()

src/rago/augmented/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44

55
from rago.augmented.base import AugmentedBase
66
from rago.augmented.cohere import CohereAug
7+
from rago.augmented.fireworks import FireworksAug
78
from rago.augmented.openai import OpenAIAug
89
from rago.augmented.sentence_transformer import SentenceTransformerAug
910
from rago.augmented.spacy import SpaCyAug
1011

1112
__all__ = [
1213
'AugmentedBase',
1314
'CohereAug',
15+
'FireworksAug',
1416
'OpenAIAug',
1517
'SentenceTransformerAug',
1618
'SpaCyAug',

src/rago/augmented/fireworks.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Classes for augmentation with Fireworks embeddings."""
2+
3+
from __future__ import annotations
4+
5+
from hashlib import sha256
6+
from typing import cast
7+
8+
import numpy as np
9+
import openai # fireworks client doesn't have query
10+
11+
12+
# embeddings model feature yet
13+
from typeguard import typechecked
14+
15+
from rago.augmented.base import AugmentedBase, EmbeddingType
16+
17+
18+
@typechecked
19+
class FireworksAug(AugmentedBase):
20+
"""Class for augmentation with Fireworks embeddings."""
21+
22+
default_model_name = 'nomic-ai/nomic-embed-text-v1.5' # embedding model
23+
default_top_k = 3
24+
25+
def _setup(self) -> None:
26+
"""Set up the object with initial parameters."""
27+
if not self.api_key:
28+
raise ValueError('API key for Fireworks is required.')
29+
self.openai_client = openai.OpenAI(
30+
base_url='https://api.fireworks.ai/inference/v1',
31+
api_key=self.api_key,
32+
)
33+
34+
def get_embedding(self, content: list[str]) -> EmbeddingType:
35+
"""Retrieve the embedding for given texts using the OpenAI client."""
36+
cache_key = sha256(''.join(content).encode('utf-8')).hexdigest()
37+
cached = self._get_cache(cache_key)
38+
if cached is not None:
39+
return cast(EmbeddingType, cached)
40+
41+
# Using the OpenAI embeddings API call for fireworks
42+
response = self.openai_client.embeddings.create(
43+
model=self.model_name,
44+
input=content,
45+
)
46+
result = np.array(
47+
[data.embedding for data in response.data], dtype=np.float32
48+
)
49+
self._save_cache(cache_key, result)
50+
return result
51+
52+
def search(
53+
self, query: str, documents: list[str], top_k: int = 0
54+
) -> list[str]:
55+
"""Search an encoded query into vector database."""
56+
if not hasattr(self, 'db') or not self.db:
57+
raise Exception('Vector database (db) is not initialized.')
58+
59+
document_encoded = self.get_embedding(documents)
60+
query_encoded = self.get_embedding([query])
61+
top_k = top_k or self.top_k or self.default_top_k or 1
62+
63+
self.db.embed(document_encoded)
64+
scores, indices = self.db.search(query_encoded, top_k=top_k)
65+
66+
self.logs['indices'] = indices
67+
self.logs['scores'] = scores
68+
self.logs['search_params'] = {
69+
'query_encoded': query_encoded,
70+
'top_k': top_k,
71+
}
72+
73+
retrieved_docs = [documents[i] for i in indices if i >= 0]
74+
75+
return retrieved_docs

src/rago/generation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from rago.generation.base import GenerationBase
66
from rago.generation.cohere import CohereGen
77
from rago.generation.deepseek import DeepSeekGen
8+
from rago.generation.fireworks import FireworksGen
89
from rago.generation.gemini import GeminiGen
910
from rago.generation.hugging_face import HuggingFaceGen
1011
from rago.generation.llama import LlamaGen
@@ -13,6 +14,7 @@
1314
__all__ = [
1415
'CohereGen',
1516
'DeepSeekGen',
17+
'FireworksGen',
1618
'GeminiGen',
1719
'GenerationBase',
1820
'HuggingFaceGen',

src/rago/generation/fireworks.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""FireworksGen class for text generation using Fireworks API."""
2+
3+
from __future__ import annotations
4+
5+
from typing import cast
6+
7+
import instructor
8+
9+
from fireworks.client import Fireworks
10+
from pydantic import BaseModel
11+
from typeguard import typechecked
12+
13+
from rago.generation.base import GenerationBase
14+
15+
16+
@typechecked
class FireworksGen(GenerationBase):
    """Fireworks AI generation model for text generation."""

    default_model_name: str = 'accounts/fireworks/models/llama-v3-8b-instruct'
    default_api_params = {  # noqa: RUF012
        'top_p': 0.9,
    }

    def _setup(self) -> None:
        """Set up the object with the initial parameters.

        Creates a ``Fireworks`` client and, when ``self.structured_output``
        is set, wraps it with ``instructor`` in ``FIREWORKS_JSON`` mode.
        """
        client = Fireworks(api_key=self.api_key)

        self.model = (
            instructor.from_fireworks(
                client=client,
                mode=instructor.Mode.FIREWORKS_JSON,
            )
            if self.structured_output
            else client
        )

    def generate(self, query: str, context: list[str]) -> str | BaseModel:
        """Generate text using Fireworks AI's API.

        Parameters
        ----------
        query : str
            The user query inserted into the prompt template.
        context : list[str]
            Context passages; they are joined with a single space before
            being inserted into the prompt template.

        Returns
        -------
        str | BaseModel
            The response object as returned by the client when
            ``self.structured_output`` is set; otherwise the stripped text
            of the first choice (an empty string if it has no content).
        """
        input_text = self.prompt_template.format(
            query=query, context=' '.join(context)
        )

        api_params = self.api_params or self.default_api_params

        messages = []
        if self.system_message:
            messages.append({'role': 'system', 'content': self.system_message})
        messages.append({'role': 'user', 'content': input_text})

        # api_params is unpacked last so that its entries take precedence
        # over the defaults listed before it.
        model_params = {
            'model': self.model_name,
            'messages': messages,
            'max_tokens': self.output_max_length,
            'temperature': self.temperature,
            **api_params,
        }

        if self.structured_output:
            model_params['response_model'] = self.structured_output

        response = self.model.chat.completions.create(**model_params)
        self.logs['model_params'] = model_params

        if self.structured_output:
            return cast(BaseModel, response)

        # The message content may be None; fall back to an empty string
        # instead of failing on ``None.strip()``.
        content = response.choices[0].message.content
        return cast(str, (content or '').strip())

tests/.env.tpl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ OPENAI_API_KEY=${OPENAI_API_KEY}
33
GEMINI_API_KEY=${GEMINI_API_KEY}
44
TOKENIZERS_PARALLELISM=false
55
COHERE_API_KEY=${COHERE_API_KEY}
6-
#FIREWORKS_API_KEY=${FIREWORKS_API_KEY}
6+
FIREWORKS_API_KEY=${FIREWORKS_API_KEY}
77
#TOGETHER_API_KEY=${TOGETHER_API_KEY}

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,14 @@ def api_key_cohere(env) -> str:
7474
'Please set the COHERE_API_KEY environment variable.'
7575
)
7676
return key
77+
78+
79+
@pytest.fixture
def api_key_fireworks(env) -> str:
    """Provide the Fireworks API key read from the environment.

    Raises ``EnvironmentError`` when ``FIREWORKS_API_KEY`` is unset or empty.
    """
    fireworks_key = os.getenv('FIREWORKS_API_KEY')
    if fireworks_key:
        return fireworks_key
    raise EnvironmentError(
        'Please set the FIREWORKS_API_KEY environment variable.'
    )

tests/test_augmentation.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44

55
import pytest
66

7-
from rago.augmented import CohereAug, OpenAIAug, SpaCyAug
7+
from rago.augmented import CohereAug, FireworksAug, OpenAIAug, SpaCyAug
88

99
API_MAP = {
1010
OpenAIAug: 'api_key_openai',
1111
CohereAug: 'api_key_cohere',
12+
FireworksAug: 'api_key_fireworks',
1213
}
1314

1415
gen_models = [
@@ -26,9 +27,14 @@
2627
model_name='text-embedding-3-small',
2728
),
2829
),
30+
# model 2
2931
partial(
3032
CohereAug,
3133
),
34+
# model 3
35+
partial(
36+
FireworksAug,
37+
),
3238
]
3339

3440

@@ -52,6 +58,7 @@ def test_aug_spacy(
5258
api_key_openai: str,
5359
api_key_cohere: str,
5460
api_key_gemini: str,
61+
api_key_fireworks: str,
5562
api_key_hugging_face: str,
5663
partial_model: partial,
5764
) -> None:

0 commit comments

Comments
 (0)