Support TTS and Speech2Text for Model Provider GPUStack #12381

Merged (2 commits) on Jan 7, 2025

18 changes: 18 additions & 0 deletions api/core/model_runtime/model_providers/gpustack/gpustack.yaml
@@ -9,6 +9,8 @@ supported_model_types:
- llm
- text-embedding
- rerank
- speech2text
- tts
configurate_methods:
- customizable-model
model_credential_schema:
@@ -118,3 +120,19 @@ model_credential_schema:
label:
en_US: Not Support
zh_Hans: 不支持
- variable: voices
show_on:
- variable: __model_type
value: tts
label:
en_US: Available Voices (comma-separated)
zh_Hans: 可用声音(用英文逗号分隔)
type: text-input
required: false
default: "Chinese Female"
placeholder:
en_US: "Chinese Female, Chinese Male, Japanese Male, Cantonese Female, English Female, English Male, Korean Female"
zh_Hans: "Chinese Female, Chinese Male, Japanese Male, Cantonese Female, English Female, English Male, Korean Female"
help:
en_US: "List voice names separated by commas. First voice will be used as default."
zh_Hans: "用英文逗号分隔的声音列表。第一个声音将作为默认值。"
16 changes: 10 additions & 6 deletions api/core/model_runtime/model_providers/gpustack/llm/llm.py
@@ -1,7 +1,5 @@
from collections.abc import Generator

from yarl import URL

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
@@ -24,9 +22,10 @@ def _invoke(
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
compatible_credentials = self._get_compatible_credentials(credentials)
return super()._invoke(
model,
credentials,
compatible_credentials,
prompt_messages,
model_parameters,
tools,
@@ -36,10 +35,15 @@ def _invoke(
)

def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
compatible_credentials = self._get_compatible_credentials(credentials)
super().validate_credentials(model, compatible_credentials)

def _get_compatible_credentials(self, credentials: dict) -> dict:
credentials = credentials.copy()
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
credentials["endpoint_url"] = f"{base_url}/v1-openai"
return credentials

@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
credentials["mode"] = "chat"
Empty file: api/core/model_runtime/model_providers/gpustack/speech2text/__init__.py
43 changes: 43 additions & 0 deletions api/core/model_runtime/model_providers/gpustack/speech2text/speech2text.py
@@ -0,0 +1,43 @@
from typing import IO, Optional

from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel


class GPUStackSpeech2TextModel(OAICompatSpeech2TextModel):
"""
Model class for GPUStack Speech to text model.
"""

def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
compatible_credentials = self._get_compatible_credentials(credentials)
return super()._invoke(model, compatible_credentials, file)

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials

:param model: model name
:param credentials: model credentials
"""
compatible_credentials = self._get_compatible_credentials(credentials)
super().validate_credentials(model, compatible_credentials)

def _get_compatible_credentials(self, credentials: dict) -> dict:
"""
Get compatible credentials

:param credentials: model credentials
:return: compatible credentials
"""
compatible_credentials = credentials.copy()
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
compatible_credentials["endpoint_url"] = f"{base_url}/v1-openai"
return compatible_credentials
api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py
@@ -1,7 +1,5 @@
from typing import Optional

from yarl import URL

from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.text_embedding_entities import (
TextEmbeddingResult,
@@ -24,12 +22,15 @@ def _invoke(
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
return super()._invoke(model, credentials, texts, user, input_type)
compatible_credentials = self._get_compatible_credentials(credentials)
return super()._invoke(model, compatible_credentials, texts, user, input_type)

def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
compatible_credentials = self._get_compatible_credentials(credentials)
super().validate_credentials(model, compatible_credentials)

@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
def _get_compatible_credentials(self, credentials: dict) -> dict:
credentials = credentials.copy()
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
credentials["endpoint_url"] = f"{base_url}/v1-openai"
return credentials
Empty file: api/core/model_runtime/model_providers/gpustack/tts/__init__.py
57 changes: 57 additions & 0 deletions api/core/model_runtime/model_providers/gpustack/tts/tts.py
@@ -0,0 +1,57 @@
from typing import Any, Optional

from core.model_runtime.model_providers.openai_api_compatible.tts.tts import OAICompatText2SpeechModel


class GPUStackText2SpeechModel(OAICompatText2SpeechModel):
"""
Model class for GPUStack Text to Speech model.
"""

def _invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
) -> Any:
"""
Invoke text2speech model

:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:param user: unique user id
:return: text translated to audio file
"""
compatible_credentials = self._get_compatible_credentials(credentials)
return super()._invoke(
model=model,
tenant_id=tenant_id,
credentials=compatible_credentials,
content_text=content_text,
voice=voice,
user=user,
)

def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
"""
Validate model credentials

:param model: model name
:param credentials: model credentials
:param user: unique user id
"""
compatible_credentials = self._get_compatible_credentials(credentials)
super().validate_credentials(model, compatible_credentials)

def _get_compatible_credentials(self, credentials: dict) -> dict:
"""
Get compatible credentials

:param credentials: model credentials
:return: compatible credentials
"""
compatible_credentials = credentials.copy()
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
compatible_credentials["endpoint_url"] = f"{base_url}/v1-openai"

return compatible_credentials
55 changes: 55 additions & 0 deletions api/tests/integration_tests/model_runtime/gpustack/test_speech2text.py
@@ -0,0 +1,55 @@
import os
from pathlib import Path

import pytest

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel


def test_validate_credentials():
model = GPUStackSpeech2TextModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="faster-whisper-medium",
credentials={
"endpoint_url": "invalid_url",
"api_key": "invalid_api_key",
},
)

model.validate_credentials(
model="faster-whisper-medium",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
)


def test_invoke_model():
model = GPUStackSpeech2TextModel()

# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))

# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")

# Construct the path to the audio file
audio_file_path = os.path.join(assets_dir, "audio.mp3")

file = Path(audio_file_path).read_bytes()

result = model.invoke(
model="faster-whisper-medium",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
file=file,
)

assert isinstance(result, str)
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"
24 changes: 24 additions & 0 deletions api/tests/integration_tests/model_runtime/gpustack/test_tts.py
@@ -0,0 +1,24 @@
import os

from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel


def test_invoke_model():
model = GPUStackText2SpeechModel()

result = model.invoke(
model="cosyvoice-300m-sft",
tenant_id="test",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
content_text="Hello world",
voice="Chinese Female",
)

content = b""
for chunk in result:
content += chunk

assert content != b""