17
17
import re
18
18
import subprocess
19
19
from pathlib import Path
20
- from typing import Tuple
20
+ from typing import Literal , Tuple
21
21
22
22
import numpy as np
23
23
from kokoro_onnx import Kokoro
@@ -49,17 +49,19 @@ class KokoroTTS(TTSModel):
49
49
50
50
"""
51
51
52
- MODEL_URL = "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/kokoro-v0_19.onnx"
52
+ BASE_MODEL_URL = (
53
+ "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/"
54
+ )
53
55
VOICES_URL = "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/voices.json"
54
56
55
- MODEL_FILENAME = "kokoro-v0_19.onnx"
56
57
VOICES_FILENAME = "voices.json"
57
58
58
59
def __init__ (
59
60
self ,
60
61
voice : str = "af_sarah" ,
61
62
language : str = "en-us" ,
62
63
speed : float = 1.0 ,
64
+ model_size : Literal ["small" , "medium" , "large" ] = "large" ,
63
65
cache_dir : str | Path = Path .home () / ".cache/rai/kokoro/" ,
64
66
):
65
67
self .voice = voice
@@ -69,6 +71,7 @@ def __init__(
69
71
70
72
os .makedirs (self .cache_dir , exist_ok = True )
71
73
74
+ self .model_size = model_size
72
75
self .model_path = self ._ensure_model_exists ()
73
76
self .voices_path = self ._ensure_voices_exists ()
74
77
@@ -93,11 +96,13 @@ def _ensure_model_exists(self) -> Path:
93
96
TTSModelError
94
97
If the model cannot be downloaded or accessed.
95
98
"""
96
- model_path = self .cache_dir / self .MODEL_FILENAME
99
+ model_filename = self ._get_model_filename ()
100
+ model_path = self .cache_dir / model_filename
97
101
if model_path .exists () and model_path .is_file ():
98
102
return model_path
99
103
100
- self ._download_file (self .MODEL_URL , model_path )
104
+ model_url = self ._get_model_url ()
105
+ self ._download_file (model_url , model_path )
101
106
return model_path
102
107
103
108
def _ensure_voices_exists (self ) -> Path :
@@ -282,3 +287,32 @@ def get_available_voices(self) -> list[str]:
282
287
return list (self .kokoro .get_voices ())
283
288
except Exception as e :
284
289
raise TTSModelError (f"Failed to retrieve voice names: { e } " )
290
+
291
+ def _get_model_filename (self ) -> str :
292
+ """
293
+ Gets the model filename based on the model size.
294
+
295
+ Returns
296
+ -------
297
+ str
298
+ The model filename for the specified model size.
299
+ """
300
+ if self .model_size == "large" :
301
+ return "kokoro-v0_19.onnx"
302
+ elif self .model_size == "medium" :
303
+ return "kokoro-v0_19.fp16.onnx"
304
+ elif self .model_size == "small" :
305
+ return "kokoro-v0_19.int8.onnx"
306
+ else :
307
+ raise TTSModelError (f"Unsupported model size: { self .model_size } " )
308
+
309
+ def _get_model_url (self ) -> str :
310
+ """
311
+ Gets the full model URL based on the model size.
312
+
313
+ Returns
314
+ -------
315
+ str
316
+ The full URL for downloading the model.
317
+ """
318
+ return self .BASE_MODEL_URL + self ._get_model_filename ()
0 commit comments