Skip to content

Commit 9b4b949

Browse files
feat: add automated model and voices download if it is not yet downloaded
1 parent 77f2976 commit 9b4b949

File tree

1 file changed

+97
-9
lines changed

1 file changed

+97
-9
lines changed

src/rai_s2s/rai_s2s/tts/models/kokoro_tts.py

Lines changed: 97 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# limitations under the License.
1414

1515

16+
import os
17+
import subprocess
18+
from pathlib import Path
1619
from typing import Tuple
1720

1821
import numpy as np
@@ -28,42 +31,127 @@ class KokoroTTS(TTSModel):
2831
2932
Parameters
3033
----------
31-
model_path : str, optional
32-
Path to the ONNX model file for Kokoro TTS, by default "kokoro-v0_19.onnx".
33-
voices_path : str, optional
34-
Path to the JSON file containing voice configurations, by default "voices.json".
3534
voice : str, optional
3635
The voice model to use, by default "af_sarah".
3736
language : str, optional
3837
The language code for the TTS model, by default "en-us".
3938
speed : float, optional
4039
The speed of the speech generation, by default 1.0.
40+
cache_dir : str | Path, optional
41+
Directory to cache downloaded models, by default "~/.cache/rai/kokoro/".
42+
4143
Raises
4244
------
4345
TTSModelError
44-
If there is an issue with initializing the Kokoro TTS model.
46+
If there is an issue with initializing the Kokoro TTS model or downloading
47+
required files.
4548
4649
"""
4750

51+
# Remote locations of the ONNX model and the voices file that are fetched
# when no cached copy is present.
MODEL_URL = (
    "https://github.com/thewh1teagle/kokoro-onnx/releases/download/"
    "model-files/kokoro-v0_19.onnx"
)
VOICES_URL = (
    "https://github.com/thewh1teagle/kokoro-onnx/releases/download/"
    "model-files/voices.json"
)

# File names under which the downloads are stored inside ``cache_dir``.
MODEL_FILENAME = "kokoro-v0_19.onnx"
VOICES_FILENAME = "voices.json"
56+
4857
def __init__(
    self,
    voice: str = "af_sarah",
    language: str = "en-us",
    speed: float = 1.0,
    cache_dir: str | Path = Path.home() / ".cache/rai/kokoro/",
):
    """
    Prepares the cache directory, makes sure the model files are present
    and initializes the Kokoro backend.

    Parameters
    ----------
    voice : str, optional
        The voice model to use, by default "af_sarah".
    language : str, optional
        The language code for the TTS model, by default "en-us".
    speed : float, optional
        The speed of the speech generation, by default 1.0.
    cache_dir : str | Path, optional
        Directory to cache downloaded models, by default
        "~/.cache/rai/kokoro/". A leading "~" is expanded to the home
        directory of the current user.

    Raises
    ------
    TTSModelError
        If the required files cannot be obtained or the Kokoro model
        cannot be initialized.
    """
    self.voice = voice
    self.speed = speed
    self.language = language
    # Expand "~" so a string such as "~/.cache/rai/kokoro/" points at the
    # home directory instead of creating a directory literally named "~".
    self.cache_dir = Path(cache_dir).expanduser()

    os.makedirs(self.cache_dir, exist_ok=True)

    # Surface download problems as TTSModelError, the error type this
    # class documents, while letting an existing TTSModelError through.
    try:
        self.model_path = self._ensure_model_exists()
        self.voices_path = self._ensure_voices_exists()
    except TTSModelError:
        raise
    except Exception as e:
        raise TTSModelError(
            f"Failed to prepare Kokoro TTS model files: {e}"
        ) from e

    try:
        self.kokoro = Kokoro(
            model_path=str(self.model_path), voices_path=str(self.voices_path)
        )
    except Exception as e:
        raise TTSModelError(f"Failed to initialize Kokoro TTS model: {e}") from e
6680

81+
def _ensure_model_exists(self) -> Path:
    """
    Returns the location of the cached model, downloading it first when
    no regular file is present at that location.

    Returns
    -------
    Path
        ``cache_dir / MODEL_FILENAME``.

    Raises
    ------
    Exception
        Whatever ``_download_file`` raises when the download fails.
    """
    cached_model = self.cache_dir / self.MODEL_FILENAME
    # is_file() is False for a missing path, so one check covers both the
    # "absent" and the "present but not a regular file" cases.
    if not cached_model.is_file():
        self._download_file(self.MODEL_URL, cached_model)
    return cached_model
101+
102+
def _ensure_voices_exists(self) -> Path:
    """
    Returns the location of the cached voices file, downloading it first
    when no regular file is present at that location.

    Returns
    -------
    Path
        ``cache_dir / VOICES_FILENAME``.

    Raises
    ------
    Exception
        Whatever ``_download_file`` raises when the download fails.
    """
    cached_voices = self.cache_dir / self.VOICES_FILENAME
    # is_file() is False for a missing path, so one check covers both the
    # "absent" and the "present but not a regular file" cases.
    if not cached_voices.is_file():
        self._download_file(self.VOICES_URL, cached_voices)
    return cached_voices
122+
123+
def _download_file(self, url: str, destination: Path) -> None:
    """
    Downloads a file from a URL to a destination path using ``wget``.

    The data is first written to a sibling ``<name>.part`` file and only
    moved onto ``destination`` after ``wget`` exits successfully, so a
    failed or interrupted transfer never leaves a truncated file where
    the existence checks of the callers would accept it as a valid copy.

    Parameters
    ----------
    url : str
        The URL to download from.
    destination : Path
        The destination path to save the file.

    Raises
    ------
    TTSModelError
        If ``wget`` cannot be run, exits with a non-zero status, or the
        downloaded file cannot be moved into place.
    """
    partial = destination.with_name(destination.name + ".part")
    try:
        subprocess.run(
            [
                "wget",
                url,
                "-O",
                str(partial),
                "--progress=dot:giga",
            ],
            check=True,
        )
        # Same-directory move, so the final file appears in one step.
        os.replace(partial, destination)
    except subprocess.CalledProcessError as e:
        raise TTSModelError(
            f"Download failed with exit code {e.returncode}"
        ) from e
    except Exception as e:
        raise TTSModelError(f"Download failed: {e}") from e
    finally:
        # wget creates the output file even when the transfer fails;
        # after a successful move there is nothing left to remove.
        partial.unlink(missing_ok=True)
154+
67155
def get_speech(self, text: str) -> AudioSegment:
68156
"""
69157
Converts text into speech using the Kokoro TTS model.

0 commit comments

Comments
 (0)