13
13
# limitations under the License.
14
14
15
15
16
+ import os
17
+ import subprocess
18
+ from pathlib import Path
16
19
from typing import Tuple
17
20
18
21
import numpy as np
@@ -28,42 +31,127 @@ class KokoroTTS(TTSModel):
28
31
29
32
Parameters
30
33
----------
31
- model_path : str, optional
32
- Path to the ONNX model file for Kokoro TTS, by default "kokoro-v0_19.onnx".
33
- voices_path : str, optional
34
- Path to the JSON file containing voice configurations, by default "voices.json".
35
34
voice : str, optional
36
35
The voice model to use, by default "af_sarah".
37
36
language : str, optional
38
37
The language code for the TTS model, by default "en-us".
39
38
speed : float, optional
40
39
The speed of the speech generation, by default 1.0.
40
+ cache_dir : str | Path, optional
41
+ Directory to cache downloaded models, by default "~/.cache/rai/kokoro/".
42
+
41
43
Raises
42
44
------
43
45
TTSModelError
44
- If there is an issue with initializing the Kokoro TTS model.
46
+ If there is an issue with initializing the Kokoro TTS model or downloading
47
+ required files.
45
48
46
49
"""
47
50
51
+ MODEL_URL = "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/kokoro-v0_19.onnx"
52
+ VOICES_URL = "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/voices.json"
53
+
54
+ MODEL_FILENAME = "kokoro-v0_19.onnx"
55
+ VOICES_FILENAME = "voices.json"
56
+
48
57
def __init__ (
49
58
self ,
50
- model_path : str = "kokoro-v0_19.onnx" ,
51
- voices_path : str = "voices.json" ,
52
59
voice : str = "af_sarah" ,
53
60
language : str = "en-us" ,
54
61
speed : float = 1.0 ,
62
+ cache_dir : str | Path = Path .home () / ".cache/rai/kokoro/" ,
55
63
):
56
64
self .voice = voice
57
65
self .speed = speed
58
66
self .language = language
67
+ self .cache_dir = Path (cache_dir )
68
+
69
+ os .makedirs (self .cache_dir , exist_ok = True )
70
+
71
+ self .model_path = self ._ensure_model_exists ()
72
+ self .voices_path = self ._ensure_voices_exists ()
59
73
60
74
try :
61
75
self .kokoro = Kokoro (
62
- model_path = model_path , voices_path = voices_path
63
- ) # TODO (mkotynia) add method to download the model ?
76
+ model_path = str ( self . model_path ) , voices_path = str ( self . voices_path )
77
+ )
64
78
except Exception as e :
65
79
raise TTSModelError (f"Failed to initialize Kokoro TTS model: { e } " ) from e
66
80
81
+ def _ensure_model_exists (self ) -> Path :
82
+ """
83
+ Checks if the model file exists and downloads it if necessary.
84
+
85
+ Returns
86
+ -------
87
+ Path
88
+ The path to the model file.
89
+
90
+ Raises
91
+ ------
92
+ TTSModelError
93
+ If the model cannot be downloaded or accessed.
94
+ """
95
+ model_path = self .cache_dir / self .MODEL_FILENAME
96
+ if model_path .exists () and model_path .is_file ():
97
+ return model_path
98
+
99
+ self ._download_file (self .MODEL_URL , model_path )
100
+ return model_path
101
+
102
+ def _ensure_voices_exists (self ) -> Path :
103
+ """
104
+ Checks if the voices file exists and downloads it if necessary.
105
+
106
+ Returns
107
+ -------
108
+ Path
109
+ The path to the voices file.
110
+
111
+ Raises
112
+ ------
113
+ TTSModelError
114
+ If the voices file cannot be downloaded or accessed.
115
+ """
116
+ voices_path = self .cache_dir / self .VOICES_FILENAME
117
+ if voices_path .exists () and voices_path .is_file ():
118
+ return voices_path
119
+
120
+ self ._download_file (self .VOICES_URL , voices_path )
121
+ return voices_path
122
+
123
+ def _download_file (self , url : str , destination : Path ) -> None :
124
+ """
125
+ Downloads a file from a URL to a destination path.
126
+
127
+ Parameters
128
+ ----------
129
+ url : str
130
+ The URL to download from.
131
+ destination : Path
132
+ The destination path to save the file.
133
+
134
+ Raises
135
+ ------
136
+ Exception
137
+ If the download fails.
138
+ """
139
+ try :
140
+ subprocess .run (
141
+ [
142
+ "wget" ,
143
+ url ,
144
+ "-O" ,
145
+ str (destination ),
146
+ "--progress=dot:giga" ,
147
+ ],
148
+ check = True ,
149
+ )
150
+ except subprocess .CalledProcessError as e :
151
+ raise Exception (f"Download failed with exit code { e .returncode } " ) from e
152
+ except Exception as e :
153
+ raise Exception (f"Download failed: { e } " ) from e
154
+
67
155
def get_speech (self , text : str ) -> AudioSegment :
68
156
"""
69
157
Converts text into speech using the Kokoro TTS model.
0 commit comments