Skip to content

Commit c292137

Browse files
committed
feat: support speech2text and tts for gpustack
1 parent d6d573e commit c292137

File tree

7 files changed

+197
-0
lines changed

7 files changed

+197
-0
lines changed

api/core/model_runtime/model_providers/gpustack/gpustack.yaml

+18
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ supported_model_types:
99
- llm
1010
- text-embedding
1111
- rerank
12+
- speech2text
13+
- tts
1214
configurate_methods:
1315
- customizable-model
1416
model_credential_schema:
@@ -118,3 +120,19 @@ model_credential_schema:
118120
label:
119121
en_US: Not Support
120122
zh_Hans: 不支持
123+
- variable: voices
124+
show_on:
125+
- variable: __model_type
126+
value: tts
127+
label:
128+
en_US: Available Voices (comma-separated)
129+
zh_Hans: 可用声音(用英文逗号分隔)
130+
type: text-input
131+
required: false
132+
default: "Chinese Female"
133+
placeholder:
134+
en_US: "Chinese Female, Chinese Male, Japanese Male, Cantonese Female, English Female, English Male, Korean Female"
135+
zh_Hans: "Chinese Female, Chinese Male, Japanese Male, Cantonese Female, English Female, English Male, Korean Female"
136+
help:
137+
en_US: "List voice names separated by commas. First voice will be used as default."
138+
zh_Hans: "用英文逗号分隔的声音列表。第一个声音将作为默认值。"

api/core/model_runtime/model_providers/gpustack/speech2text/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import IO, Optional
2+
3+
from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel
4+
5+
6+
class GPUStackSpeech2TextModel(OAICompatSpeech2TextModel):
    """
    Model class for GPUStack Speech to text model.

    GPUStack serves an OpenAI-compatible API under the ``/v1-openai`` path
    prefix, so this class only normalizes the endpoint URL before delegating
    to the OpenAI-compatible speech2text implementation.
    """

    def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        compatible_credentials = self._get_compatible_credentials(credentials)
        # Fix: forward `user` to the parent implementation — it was silently
        # dropped before (the sibling TTS model forwards it).
        return super()._invoke(model, compatible_credentials, file, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :raises CredentialsValidateFailedError: if the credentials are invalid
        """
        compatible_credentials = self._get_compatible_credentials(credentials)
        super().validate_credentials(model, compatible_credentials)

    def _get_compatible_credentials(self, credentials: dict) -> dict:
        """
        Return a copy of the credentials with a normalized endpoint URL.

        Strips a trailing slash and any existing ``/v1-openai`` suffix from
        ``endpoint_url``, then re-appends ``/v1-openai`` once, so the URL is
        canonical regardless of how the user entered it.

        :param credentials: model credentials
        :return: compatible credentials
        """
        compatible_credentials = credentials.copy()
        base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
        compatible_credentials["endpoint_url"] = f"{base_url}/v1-openai"
        return compatible_credentials

api/core/model_runtime/model_providers/gpustack/tts/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from typing import Any, Optional
2+
3+
from core.model_runtime.model_providers.openai_api_compatible.tts.tts import OAICompatText2SpeechModel
4+
5+
6+
class GPUStackText2SpeechModel(OAICompatText2SpeechModel):
    """
    Model class for GPUStack Text to Speech model.

    Delegates all work to the OpenAI-compatible TTS implementation after
    pointing the endpoint URL at GPUStack's ``/v1-openai`` API prefix.
    """

    def _invoke(
        self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
    ) -> Any:
        """
        Invoke text2speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :param user: unique user id
        :return: text translated to audio file
        """
        return super()._invoke(
            model=model,
            tenant_id=tenant_id,
            credentials=self._get_compatible_credentials(credentials),
            content_text=content_text,
            voice=voice,
            user=user,
        )

    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :param user: unique user id
        """
        super().validate_credentials(model, self._get_compatible_credentials(credentials))

    def _get_compatible_credentials(self, credentials: dict) -> dict:
        """
        Return a copy of *credentials* whose ``endpoint_url`` ends with the
        ``/v1-openai`` API prefix exactly once (trailing slashes and an
        existing prefix are stripped first).

        :param credentials: model credentials
        :return: compatible credentials
        """
        endpoint = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
        return {**credentials, "endpoint_url": f"{endpoint}/v1-openai"}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import os
2+
from pathlib import Path
3+
4+
import pytest
5+
6+
from core.model_runtime.errors.validate import CredentialsValidateFailedError
7+
from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel
8+
9+
10+
def test_validate_credentials():
    """Invalid credentials must be rejected; real ones must validate cleanly."""
    model = GPUStackSpeech2TextModel()

    bad_credentials = {
        "endpoint_url": "invalid_url",
        "api_key": "invalid_api_key",
    }
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="faster-whisper-medium", credentials=bad_credentials)

    live_credentials = {
        "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
        "api_key": os.environ.get("GPUSTACK_API_KEY"),
    }
    model.validate_credentials(model="faster-whisper-medium", credentials=live_credentials)
29+
30+
31+
def test_invoke_model():
    """Transcribe the bundled sample audio and check the expected transcript."""
    model = GPUStackSpeech2TextModel()

    # Resolve <this test's dir>/../assets/audio.mp3 with pathlib instead of
    # chained os.path calls (Path is already imported in this file).
    audio_file_path = Path(__file__).resolve().parent.parent / "assets" / "audio.mp3"
    file = audio_file_path.read_bytes()

    result = model.invoke(
        model="faster-whisper-medium",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
        },
        file=file,
    )

    assert isinstance(result, str)
    assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import os
2+
3+
from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel
4+
5+
6+
def test_invoke_model():
    """Synthesize speech and verify non-empty audio bytes are streamed back."""
    model = GPUStackText2SpeechModel()

    result = model.invoke(
        model="cosyvoice-300m-sft",
        tenant_id="test",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
        },
        content_text="Hello world",
        voice="Chinese Female",
    )

    # b"".join consumes the chunk iterator in one linear pass instead of
    # quadratic bytes concatenation with +=.
    content = b"".join(result)

    assert content != b""

0 commit comments

Comments
 (0)