Skip to content

Commit 1623df1

Browse files
feat: support for quantized models
1 parent 248a8f7 commit 1623df1

File tree

1 file changed

+39
-5
lines changed

1 file changed

+39
-5
lines changed

src/rai_s2s/rai_s2s/tts/models/kokoro_tts.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import re
1818
import subprocess
1919
from pathlib import Path
20-
from typing import Tuple
20+
from typing import Literal, Tuple
2121

2222
import numpy as np
2323
from kokoro_onnx import Kokoro
@@ -49,17 +49,19 @@ class KokoroTTS(TTSModel):
4949
5050
"""
5151

52-
MODEL_URL = "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/kokoro-v0_19.onnx"
52+
BASE_MODEL_URL = (
53+
"https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/"
54+
)
5355
VOICES_URL = "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/voices.json"
5456

55-
MODEL_FILENAME = "kokoro-v0_19.onnx"
5657
VOICES_FILENAME = "voices.json"
5758

5859
def __init__(
5960
self,
6061
voice: str = "af_sarah",
6162
language: str = "en-us",
6263
speed: float = 1.0,
64+
model_size: Literal["small", "medium", "large"] = "large",
6365
cache_dir: str | Path = Path.home() / ".cache/rai/kokoro/",
6466
):
6567
self.voice = voice
@@ -69,6 +71,7 @@ def __init__(
6971

7072
os.makedirs(self.cache_dir, exist_ok=True)
7173

74+
self.model_size = model_size
7275
self.model_path = self._ensure_model_exists()
7376
self.voices_path = self._ensure_voices_exists()
7477

@@ -93,11 +96,13 @@ def _ensure_model_exists(self) -> Path:
9396
TTSModelError
9497
If the model cannot be downloaded or accessed.
9598
"""
96-
model_path = self.cache_dir / self.MODEL_FILENAME
99+
model_filename = self._get_model_filename()
100+
model_path = self.cache_dir / model_filename
97101
if model_path.exists() and model_path.is_file():
98102
return model_path
99103

100-
self._download_file(self.MODEL_URL, model_path)
104+
model_url = self._get_model_url()
105+
self._download_file(model_url, model_path)
101106
return model_path
102107

103108
def _ensure_voices_exists(self) -> Path:
@@ -282,3 +287,32 @@ def get_available_voices(self) -> list[str]:
282287
return list(self.kokoro.get_voices())
283288
except Exception as e:
284289
raise TTSModelError(f"Failed to retrieve voice names: {e}")
290+
291+
def _get_model_filename(self) -> str:
292+
"""
293+
Gets the model filename based on the model size.
294+
295+
Returns
296+
-------
297+
str
298+
The model filename for the specified model size.
299+
"""
300+
if self.model_size == "large":
301+
return "kokoro-v0_19.onnx"
302+
elif self.model_size == "medium":
303+
return "kokoro-v0_19.fp16.onnx"
304+
elif self.model_size == "small":
305+
return "kokoro-v0_19.int8.onnx"
306+
else:
307+
raise TTSModelError(f"Unsupported model size: {self.model_size}")
308+
309+
def _get_model_url(self) -> str:
310+
"""
311+
Gets the full model URL based on the model size.
312+
313+
Returns
314+
-------
315+
str
316+
The full URL for downloading the model.
317+
"""
318+
return self.BASE_MODEL_URL + self._get_model_filename()

0 commit comments

Comments
 (0)