diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c312f93e1..2caf3cb76 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,4 +44,4 @@ repos: rev: 7.1.0 hooks: - id: flake8 - args: ["--ignore=E501,E731,W503,W504"] + args: ["--ignore=E501,E731,W503,W504,E203"] diff --git a/poetry.lock b/poetry.lock index 31034f817..91c2abbca 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "absl-py" @@ -893,6 +893,40 @@ mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.11.1)", "types-Pil test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] test-no-images = ["pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "wurlitzer"] +[[package]] +name = "ctranslate2" +version = "4.5.0" +description = "Fast inference engine for Transformer models" +optional = false +python-versions = ">=3.8" +files = [ + {file = "ctranslate2-4.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:241da685f8f7cb10b7afceeb3d879f778b56e6a1d55fc2964ddc949c80c9c7bb"}, + {file = "ctranslate2-4.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5328ec73b430ba1a99a85bc3b038291e7bbedc0c9987b354b3c8ca395a3b7e06"}, + {file = "ctranslate2-4.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b97ee9b15f75f84c35827df97ebe9c676f96c2e5118a2ed4d3efcf3c3e04a599"}, + {file = "ctranslate2-4.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:5d9ec0a201d3c33ada1bb00929b3ff3d80642b34ca0d94465556dfa197d127c4"}, + {file = "ctranslate2-4.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1bc072da977abdd4b09f0d50a45de745818a247608aa3f2865ef9a579ff11851"}, + {file = "ctranslate2-4.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c56ccf1aa723ba85f4ea56b4d945dc7d2ea7f074b5eb716c85be0c8e0311c24"}, + {file = "ctranslate2-4.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89db5b18dfc7f7bf84cafaf7cc36e885aafcaeac936977eefd3e4768fd7b2879"}, + {file = "ctranslate2-4.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:253993fbbe20cd7e2602de81e6159b259dadb47b9b59486d928396bd4a4ecdaa"}, + {file = "ctranslate2-4.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1a0509f172edc994aec6870fe0a90c799d85fd7ddf564059d25b60932ab2e2c4"}, + {file = "ctranslate2-4.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c158f2ada6e3347388ad13c69e4a6a729ba40c035a400dd447995950ecf5e62f"}, + {file = "ctranslate2-4.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3c5877fce31a0fcf3b5edbc8d4e6e22fd94a86c6b49680740ef41130efffc1"}, + {file = "ctranslate2-4.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:a16a784ec7924166bdf3e86754feda0441f04d9851fc3412f34f1e2de7cbd51b"}, + {file = "ctranslate2-4.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7c221153ecdda81e24679a07f0b577926879325a0347a89f8afaf2593641cb9b"}, + {file = "ctranslate2-4.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:919a5feab74f33694b66c0a5637f07ba7cf4995af87d960aca50e4cbe53b4054"}, + {file = "ctranslate2-4.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45a45dabca3f9d8eb718685a792f9a7fc10af7362d318271181f16ebf54669b8"}, + {file = "ctranslate2-4.5.0-cp38-cp38-win_amd64.whl", hash = 
"sha256:5924e9adeff8b30ca0851e0f5ff13639d08e47d1219d27f615c0936a3cdedb57"}, + {file = "ctranslate2-4.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d9e8120817c51515175ab163655dc14b4e21eb381d7196fd43b843b0d50efaf1"}, + {file = "ctranslate2-4.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f790e77458b83e109a743d0f07e9e5c023208314f5c824c26d1e3ebc62a12f71"}, + {file = "ctranslate2-4.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af82185aa961869362c06ce33443b5207237790233b1614ccf92307a671aa72"}, + {file = "ctranslate2-4.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:ccbccbdddb02e7c3b24666f2bc52cd475ca666fda8a317d23a97645eafd66dbe"}, +] + +[package.dependencies] +numpy = "*" +pyyaml = ">=5.3,<7" +setuptools = "*" + [[package]] name = "cycler" version = "0.12.1" @@ -1386,6 +1420,29 @@ files = [ {file = "fasteners-0.19.tar.gz", hash = "sha256:b4f37c3ac52d8a445af3a66bce57b33b5e90b97c696b7b984f530cf8f0ded09c"}, ] +[[package]] +name = "faster-whisper" +version = "1.1.1" +description = "Faster Whisper transcription with CTranslate2" +optional = false +python-versions = ">=3.9" +files = [ + {file = "faster-whisper-1.1.1.tar.gz", hash = "sha256:50d27571970c1be0c2b2680a2593d5d12f9f5d2f10484f242a1afbe7cb946604"}, + {file = "faster_whisper-1.1.1-py3-none-any.whl", hash = "sha256:5808dc334fb64fb4336921450abccfe5e313a859b31ba61def0ac7f639383d90"}, +] + +[package.dependencies] +av = ">=11" +ctranslate2 = ">=4.0,<5" +huggingface-hub = ">=0.13" +onnxruntime = ">=1.14,<2" +tokenizers = ">=0.13,<1" +tqdm = "*" + +[package.extras] +conversion = ["transformers[torch] (>=4.23)"] +dev = ["black (==23.*)", "flake8 (==6.*)", "isort (==5.*)", "pytest (==7.*)"] + [[package]] name = "filelock" version = "3.16.1" @@ -2763,8 +2820,8 @@ langchain-core = ">=0.3.29,<0.4.0" langchain-text-splitters = ">=0.3.3,<0.4.0" langsmith = ">=0.1.17,<0.3" numpy = [ - {version = ">=1.22.4,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""}, + {version = ">=1.22.4,<2", markers = "python_version < \"3.12\""}, ] pydantic = ">=2.7.4,<3.0.0" PyYAML = ">=5.3" @@ -2787,8 +2844,8 @@ files = [ boto3 = ">=1.35.74" langchain-core = ">=0.3.27,<0.4.0" numpy = [ - {version = ">=1,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.0,<3", markers = "python_version >= \"3.12\""}, + {version = ">=1,<2", markers = "python_version < \"3.12\""}, ] pydantic = ">=2,<3" @@ -2811,8 +2868,8 @@ langchain = ">=0.3.14,<0.4.0" langchain-core = ">=0.3.29,<0.4.0" langsmith = ">=0.1.125,<0.3" numpy = [ - {version = ">=1.22.4,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""}, + {version = ">=1.22.4,<2", markers = "python_version < \"3.12\""}, ] pydantic-settings = ">=2.4.0,<3.0.0" PyYAML = ">=5.3" @@ -4165,10 +4222,10 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -4189,10 +4246,10 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= 
\"3.12\""}, {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -4396,9 +4453,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -6841,8 +6898,8 @@ files = [ contourpy = {version = ">=1.0.7", markers = "python_version >= \"3.8\" and python_version < \"3.13\""} defusedxml = ">=0.7.1,<0.8.0" matplotlib = [ - {version = ">=3.6.0", markers = "python_version >= \"3.9\" and python_version < \"3.12\""}, {version = ">=3.7.3", markers = "python_version >= \"3.12\""}, + {version = ">=3.6.0", markers = "python_version >= \"3.9\" and python_version < \"3.12\""}, ] numpy = {version = ">=1.21.2", markers = "python_version < \"3.13\""} opencv-python = ">=4.5.5.64" @@ -8175,4 +8232,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10, <3.13" -content-hash = "ee424289e94a1e02622089d2226e5b97a4ca2d54e9de0787487e81353d11814e" +content-hash = "da1a7720082bf43b4efc7cd972b63f39882fc7e0d69340bbc436f18e889e55b2" diff --git a/pyproject.toml b/pyproject.toml index 4457ea4cb..3534c688d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ tomli = "^2.0.1" openwakeword = { git = "https://github.com/maciejmajek/openWakeWord.git", branch = "chore/remove-tflite-backend" } pytest-timeout = "^2.3.1" tomli-w = "^1.1.0" +faster-whisper = "^1.1.1" [tool.poetry.group.dev.dependencies] ipykernel = "^6.29.4" diff --git a/src/rai/rai/agents/__init__.py b/src/rai/rai/agents/__init__.py index dc101282b..2b7d4461a 100644 --- a/src/rai/rai/agents/__init__.py +++ b/src/rai/rai/agents/__init__.py @@ -15,9 +15,11 @@ from rai.agents.conversational_agent import create_conversational_agent from rai.agents.state_based import create_state_based_agent from rai.agents.tool_runner import ToolRunner +from rai.agents.voice_agent import VoiceRecognitionAgent __all__ = [ "ToolRunner", "create_conversational_agent", "create_state_based_agent", + "VoiceRecognitionAgent", ] diff --git a/src/rai/rai/agents/base.py b/src/rai/rai/agents/base.py new file mode 100644 index 000000000..c2dd4fe50 --- /dev/null +++ b/src/rai/rai/agents/base.py @@ -0,0 +1,32 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
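+# BaseAgent: common base class for RAI agents. It stores a name -> connector
+# mapping (defaulting to an empty dict) and requires subclasses to implement run().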
+ + +from abc import ABC, abstractmethod +from typing import Optional + +from rai.communication import BaseConnector + + +class BaseAgent(ABC): + def __init__( + self, connectors: Optional[dict[str, BaseConnector]] = None, *args, **kwargs + ): + if connectors is None: + connectors = {} + self.connectors: dict[str, BaseConnector] = connectors + + @abstractmethod + def run(self, *args, **kwargs): + pass diff --git a/src/rai/rai/agents/voice_agent.py b/src/rai/rai/agents/voice_agent.py new file mode 100644 index 000000000..3fc770258 --- /dev/null +++ b/src/rai/rai/agents/voice_agent.py @@ -0,0 +1,199 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import time +from threading import Event, Lock, Thread +from typing import Any, List, Optional, TypedDict +from uuid import uuid4 + +import numpy as np +from numpy.typing import NDArray + +from rai.agents.base import BaseAgent +from rai.communication import ( + AudioInputDeviceConfig, + ROS2ARIConnector, + ROS2ARIMessage, + StreamingAudioInputDevice, +) +from rai_asr.models import BaseTranscriptionModel, BaseVoiceDetectionModel + + +class ThreadData(TypedDict): + thread: Thread + event: Event + transcription: str + joined: bool + + +class VoiceRecognitionAgent(BaseAgent): + def __init__( + self, + microphone_device_id: int, # TODO: Change to name based instead of id based identification + microphone_config: AudioInputDeviceConfig, + ros2_name: str, + transcription_model: BaseTranscriptionModel, + vad: BaseVoiceDetectionModel, + grace_period: float = 1.0, + logger: Optional[logging.Logger] = None, + ): + if logger is None: + self.logger = logging.getLogger(__name__) + else: + self.logger = logger + microphone = StreamingAudioInputDevice() + microphone.configure_device( + target=str(microphone_device_id), config=microphone_config + ) + ros2_connector = ROS2ARIConnector(ros2_name) + super().__init__(connectors={"microphone": microphone, "ros2": ros2_connector}) + self.microphone_device_id = str(microphone_device_id) + self.should_record_pipeline: List[BaseVoiceDetectionModel] = [] + self.should_stop_pipeline: List[BaseVoiceDetectionModel] = [] + + self.transcription_model = transcription_model + self.transcription_lock = Lock() + + self.vad: BaseVoiceDetectionModel = vad + + self.grace_period = grace_period + self.grace_period_start = 0 + + self.recording_started = False + self.ran_setup = False + + self.sample_buffer = [] + self.sample_buffer_lock = Lock() + self.active_thread = "" + self.transcription_threads: dict[str, ThreadData] = {} + self.transcription_buffers: dict[str, list[NDArray]] = {} + + def __call__(self): + self.run() + + def add_detection_model( + self, model: BaseVoiceDetectionModel, pipeline: str = "record" + ): + if pipeline == "record": + self.should_record_pipeline.append(model) + elif pipeline == "stop": + self.should_stop_pipeline.append(model) + else: + raise ValueError("Pipeline should be either 'record' or 'stop'") + + def run(self): + self.running = True + 
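+        # Start non-blocking capture: the microphone connector delivers each
+        # incoming audio block to on_new_sample via the feedback callback.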
self.listener_handle = self.connectors["microphone"].start_action( + action_data=None, + target=self.microphone_device_id, + on_feedback=self.on_new_sample, + on_done=lambda: None, + ) + + def stop(self): + self.logger.info("Stopping voice agent") + self.running = False + self.connectors["microphone"].terminate_action(self.listener_handle) + while not all( + [thread["joined"] for thread in self.transcription_threads.values()] + ): + for thread_id in self.transcription_threads: + if self.transcription_threads[thread_id]["event"].is_set(): + self.transcription_threads[thread_id]["thread"].join() + self.transcription_threads[thread_id]["joined"] = True + else: + self.logger.info( + f"Waiting for transcription of {thread_id} to finish..." + ) + self.logger.info("Voice agent stopped") + + def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]): + sample_time = time.time() + with self.sample_buffer_lock: + self.sample_buffer.append(indata) + if not self.recording_started and len(self.sample_buffer) > 5: + self.sample_buffer = self.sample_buffer[-5:] + + # attempt to join finished threads: + for thread_id in self.transcription_threads: + if self.transcription_threads[thread_id]["event"].is_set(): + self.transcription_threads[thread_id]["thread"].join() + self.transcription_threads[thread_id]["joined"] = True + + voice_detected, output_parameters = self.vad(indata, {}) + should_record = False + # TODO: second condition is temporary + if voice_detected and not self.recording_started: + should_record = self.should_record(indata, output_parameters) + + if should_record: + self.logger.info("starting recording...") + self.recording_started = True + thread_id = str(uuid4())[0:8] + transcription_thread = Thread( + target=self.transcription_thread, + args=[thread_id], + ) + transcription_finished = Event() + self.active_thread = thread_id + self.transcription_threads[thread_id] = { + "thread": transcription_thread, + "event": transcription_finished, + "transcription": "", + "joined": False, + } + + if voice_detected: + self.logger.debug("Voice detected... resetting grace period") + self.grace_period_start = sample_time + + if ( + self.recording_started + and sample_time - self.grace_period_start > self.grace_period + ): + self.logger.info( + "Grace period ended... stopping recording, starting transcription" + ) + self.recording_started = False + self.grace_period_start = 0 + with self.sample_buffer_lock: + self.transcription_buffers[self.active_thread] = self.sample_buffer + self.sample_buffer = [] + self.transcription_threads[self.active_thread]["thread"].start() + self.active_thread = "" + + def should_record( + self, audio_data: NDArray, input_parameters: dict[str, Any] + ) -> bool: + for model in self.should_record_pipeline: + detected, output = model(audio_data, input_parameters) + if detected: + return True + return False + + def transcription_thread(self, identifier: str): + self.logger.info(f"transcription thread {identifier} started") + audio_data = np.concatenate(self.transcription_buffers[identifier]) + with self.transcription_lock: # this is only necessary for the local model... 
TODO: fix this somehow + transcription = self.transcription_model.transcribe(audio_data) + self.connectors["ros2"].send_message( + ROS2ARIMessage( + {"data": transcription}, {"msg_type": "std_msgs/msg/String"} + ), + "/from_human", + ) + self.transcription_threads[identifier]["transcription"] = transcription + self.transcription_threads[identifier]["event"].set() diff --git a/src/rai/rai/communication/__init__.py b/src/rai/rai/communication/__init__.py index 04c1fc4f9..5134c2d93 100644 --- a/src/rai/rai/communication/__init__.py +++ b/src/rai/rai/communication/__init__.py @@ -15,7 +15,12 @@ from .ari_connector import ARIConnector, ARIMessage from .base_connector import BaseConnector, BaseMessage from .hri_connector import HRIConnector, HRIMessage, HRIPayload -from .sound_device_connector import SoundDeviceError, StreamingAudioInputDevice +from .ros2.connectors import ROS2ARIConnector, ROS2ARIMessage +from .sound_device_connector import ( + AudioInputDeviceConfig, + SoundDeviceError, + StreamingAudioInputDevice, +) __all__ = [ "ARIConnector", @@ -25,6 +30,9 @@ "HRIConnector", "HRIMessage", "HRIPayload", + "ROS2ARIConnector", + "ROS2ARIMessage", "StreamingAudioInputDevice", "SoundDeviceError", + "AudioInputDeviceConfig", ] diff --git a/src/rai/rai/communication/sound_device_connector.py b/src/rai/rai/communication/sound_device_connector.py index 449c0890f..40edad01a 100644 --- a/src/rai/rai/communication/sound_device_connector.py +++ b/src/rai/rai/communication/sound_device_connector.py @@ -30,7 +30,6 @@ def __init__(self, msg: str): class AudioInputDeviceConfig(TypedDict): block_size: int consumer_sampling_rate: int - target_sampling_rate: int dtype: str device_number: Optional[int] @@ -44,7 +43,6 @@ class ConfiguredAudioInputDevice: sample_rate (int): Device sample rate consumer_sampling_rate (int): The sampling rate of the consumer window_size_samples (int): The size of the window in samples - target_sampling_rate (int): The target sampling rate dtype (str): The data type of the audio samples """ @@ -58,7 +56,6 @@ def __init__(self, config: AudioInputDeviceConfig): self.window_size_samples = int( config["block_size"] * self.sample_rate / config["consumer_sampling_rate"] ) - self.target_sampling_rate = int(config["target_sampling_rate"]) self.dtype = config["dtype"] @@ -132,9 +129,9 @@ def start_action( def callback(indata: np.ndarray, frames: int, _, status: CallbackFlags): indata = indata.flatten() - sample_time_length = len(indata) / target_device.target_sampling_rate - if target_device.sample_rate != target_device.target_sampling_rate: - indata = resample(indata, int(sample_time_length * target_device.target_sampling_rate)) # type: ignore + sample_time_length = len(indata) / target_device.sample_rate + if target_device.sample_rate != target_device.consumer_sampling_rate: + indata = resample(indata, int(sample_time_length * target_device.consumer_sampling_rate)) # type: ignore flag_dict = { "input_overflow": status.input_overflow, "input_underflow": status.input_underflow, diff --git a/src/rai_asr/rai_asr/asr_clients.py b/src/rai_asr/rai_asr/asr_clients.py index df538509e..e08d0afd0 100644 --- a/src/rai_asr/rai_asr/asr_clients.py +++ b/src/rai_asr/rai_asr/asr_clients.py @@ -24,6 +24,8 @@ from scipy.io import wavfile from whisper.transcribe import transcribe +# WARN: This file is going to be removed in favour of rai_asr.models + class ASRModel: def __init__(self, model_name: str, sample_rate: int, language: str = "en"): diff --git a/src/rai_asr/rai_asr/models/__init__.py 
b/src/rai_asr/rai_asr/models/__init__.py new file mode 100644 index 000000000..1d1a7e9de --- /dev/null +++ b/src/rai_asr/rai_asr/models/__init__.py @@ -0,0 +1,28 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rai_asr.models.base import BaseTranscriptionModel, BaseVoiceDetectionModel +from rai_asr.models.local_whisper import LocalWhisper +from rai_asr.models.open_ai_whisper import OpenAIWhisper +from rai_asr.models.open_wake_word import OpenWakeWord +from rai_asr.models.silero_vad import SileroVAD + +__all__ = [ + "BaseVoiceDetectionModel", + "SileroVAD", + "OpenWakeWord", + "BaseTranscriptionModel", + "LocalWhisper", + "OpenAIWhisper", +] diff --git a/src/rai_asr/rai_asr/models/base.py b/src/rai_asr/rai_asr/models/base.py new file mode 100644 index 000000000..13142df87 --- /dev/null +++ b/src/rai_asr/rai_asr/models/base.py @@ -0,0 +1,47 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC, abstractmethod +from typing import Any, Tuple + +import numpy as np +from numpy._typing import NDArray + + +class BaseVoiceDetectionModel(ABC): + + def __call__( + self, audio_data: NDArray, input_parameters: dict[str, Any] + ) -> Tuple[bool, dict[str, Any]]: + return self.detect(audio_data, input_parameters) + + @abstractmethod + def detect( + self, audio_data: NDArray, input_parameters: dict[str, Any] + ) -> Tuple[bool, dict[str, Any]]: + pass + + +class BaseTranscriptionModel(ABC): + def __init__(self, model_name: str, sample_rate: int, language: str = "en"): + self.model_name = model_name + self.sample_rate = sample_rate + self.language = language + + self.latest_transcription = "" + + @abstractmethod + def transcribe(self, data: NDArray[np.int16]) -> str: + pass diff --git a/src/rai_asr/rai_asr/models/local_whisper.py b/src/rai_asr/rai_asr/models/local_whisper.py new file mode 100644 index 000000000..f8292e339 --- /dev/null +++ b/src/rai_asr/rai_asr/models/local_whisper.py @@ -0,0 +1,64 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import cast + +import numpy as np +import torch +import whisper +from faster_whisper import WhisperModel +from numpy._typing import NDArray + +from rai_asr.models.base import BaseTranscriptionModel + + +class LocalWhisper(BaseTranscriptionModel): + def __init__( + self, model_name: str, sample_rate: int, language: str = "en", **kwargs + ): + super().__init__(model_name, sample_rate, language) + if torch.cuda.is_available(): + self.whisper = whisper.load_model(self.model_name, device="cuda", **kwargs) + else: + self.whisper = whisper.load_model(self.model_name, **kwargs) + + self.logger = logging.getLogger(__name__) + + def transcribe(self, data: NDArray[np.int16]) -> str: + normalized_data = data.astype(np.float32) / 32768.0 + result = whisper.transcribe( + self.whisper, normalized_data + ) # TODO: handling of additional transcribe arguments (perhaps in model init) + transcription = result["text"] + self.logger.info("transcription: %s", transcription) + transcription = cast(str, transcription) + self.latest_transcription = transcription + return transcription + + +class FasterWhisper(BaseTranscriptionModel): + def __init__( + self, model_name: str, sample_rate: int, language: str = "en", **kwargs + ): + super().__init__(model_name, sample_rate, language) + self.model = WhisperModel(model_name, **kwargs) + self.logger = logging.getLogger(__name__) + + def transcribe(self, data: NDArray[np.int16]) -> str: + normalized_data = data.astype(np.float32) / 32768.0 + segments, _ = self.model.transcribe(normalized_data) + transcription = " ".join(segment.text for segment in segments) + self.logger.info("transcription: %s", transcription) + return transcription diff --git a/src/rai_asr/rai_asr/models/open_ai_whisper.py b/src/rai_asr/rai_asr/models/open_ai_whisper.py new file mode 100644 index 000000000..0f74dd093 --- /dev/null +++ b/src/rai_asr/rai_asr/models/open_ai_whisper.py @@ -0,0 +1,55 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
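+# OpenAIWhisper: remote transcription via the OpenAI audio API. Samples are
+# normalized to float32 and wrapped in an in-memory WAV buffer before upload.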
+ +import io +import logging +import os +from functools import partial + +import numpy as np +from numpy.typing import NDArray +from openai import OpenAI +from scipy.io import wavfile + +from rai_asr.models.base import BaseTranscriptionModel + + +class OpenAIWhisper(BaseTranscriptionModel): + def __init__( + self, model_name: str, sample_rate: int, language: str = "en", **kwargs + ): + super().__init__(model_name, sample_rate, language) + api_key = os.getenv("OPENAI_API_KEY") + if api_key is None: + raise ValueError("OPENAI_API_KEY environment variable is not set.") + self.api_key = api_key + self.openai_client = OpenAI() + self.model = partial( + self.openai_client.audio.transcriptions.create, + model=self.model_name, + **kwargs, + ) + self.logger = logging.getLogger(__name__) + self.samples = [] + + def transcribe(self, data: NDArray[np.int16]) -> str: + normalized_data = data.astype(np.float32) / 32768.0 + with io.BytesIO() as temp_wav_buffer: + wavfile.write(temp_wav_buffer, self.sample_rate, normalized_data) + temp_wav_buffer.seek(0) + temp_wav_buffer.name = "temp.wav" + response = self.model(file=temp_wav_buffer, language=self.language) + transcription = response.text + self.logger.info("transcription: %s", transcription) + return transcription diff --git a/src/rai_asr/rai_asr/models/open_wake_word.py b/src/rai_asr/rai_asr/models/open_wake_word.py new file mode 100644 index 000000000..1fb4211e0 --- /dev/null +++ b/src/rai_asr/rai_asr/models/open_wake_word.py @@ -0,0 +1,47 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Tuple + +from numpy.typing import NDArray +from openwakeword.model import Model as OWWModel +from openwakeword.utils import download_models + +from rai_asr.models import BaseVoiceDetectionModel + + +class OpenWakeWord(BaseVoiceDetectionModel): + def __init__(self, wake_word_model_path: str, threshold: float = 0.5): + super(OpenWakeWord, self).__init__() + self.model_name = "open_wake_word" + download_models() + self.model = OWWModel( + wakeword_models=[ + wake_word_model_path, + ], + inference_framework="onnx", + ) + self.threshold = threshold + + def detect( + self, audio_data: NDArray, input_parameters: dict[str, Any] + ) -> Tuple[bool, dict[str, Any]]: + predictions = self.model.predict(audio_data) + ret = input_parameters.copy() + ret.update({self.model_name: {"predictions": predictions}}) + for key, value in predictions.items(): + if value > self.threshold: + self.model.reset() + return True, ret + return False, ret diff --git a/src/rai_asr/rai_asr/models/silero_vad.py b/src/rai_asr/rai_asr/models/silero_vad.py new file mode 100644 index 000000000..fdecb8b5b --- /dev/null +++ b/src/rai_asr/rai_asr/models/silero_vad.py @@ -0,0 +1,61 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Literal, Tuple + +import numpy as np +import torch +from numpy.typing import NDArray + +from rai_asr.models import BaseVoiceDetectionModel + + +class SileroVAD(BaseVoiceDetectionModel): + def __init__(self, sampling_rate: Literal[8000, 16000] = 16000, threshold=0.5): + super(SileroVAD, self).__init__() + self.model_name = "silero_vad" + self.model, _ = torch.hub.load( + repo_or_dir="snakers4/silero-vad", + model=self.model_name, + ) # type: ignore + # NOTE: See silero vad implementation: https://github.com/snakers4/silero-vad/blob/9060f664f20eabb66328e4002a41479ff288f14c/src/silero_vad/utils_vad.py#L61 + if sampling_rate == 16000: + self.sampling_rate = 16000 + self.window_size = 512 + elif sampling_rate == 8000: + self.sampling_rate = 8000 + self.window_size = 256 + else: + raise ValueError( + "Only 8000 and 16000 sampling rates are supported" + ) # TODO: consider if this should be a ValueError or something else + self.threshold = threshold + + def int2float(self, sound: NDArray[np.int16]): + converted_sound = sound.astype("float32") + converted_sound *= 1 / 32768 + converted_sound = converted_sound.squeeze() + return converted_sound + + def detect( + self, audio_data: NDArray, input_parameters: dict[str, Any] + ) -> Tuple[bool, dict[str, Any]]: + vad_confidence = self.model( + torch.tensor(self.int2float(audio_data[-self.window_size :])), + self.sampling_rate, + ).item() + ret = input_parameters.copy() + ret.update({self.model_name: {"vad_confidence": vad_confidence}}) + + return vad_confidence > self.threshold, ret diff --git a/tests/communication/test_sound_device_connector.py b/tests/communication/test_sound_device_connector.py index 6eb2b0bd4..a98e251a3 100644 --- a/tests/communication/test_sound_device_connector.py +++ b/tests/communication/test_sound_device_connector.py @@ -31,7 +31,6 @@ def device_config(): return { "block_size": 1024, "consumer_sampling_rate": 44100, - "target_sampling_rate": 16000, "dtype": "float32", } @@ -57,7 +56,6 @@ def test_configure( audio_input_device.configred_devices[device_id].consumer_sampling_rate == 44100 ) assert audio_input_device.configred_devices[device_id].window_size_samples == 1024 - assert audio_input_device.configred_devices[device_id].target_sampling_rate == 16000 assert audio_input_device.configred_devices[device_id].dtype == "float32"
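
Usage sketch (editor's note, not part of the patch): a minimal example of wiring the new pieces together, based only on the signatures introduced above. The device index, block size, model size ("tiny"), wake-word model path, and sleep duration are illustrative assumptions, not values taken from this change.

    import time

    from rai.agents import VoiceRecognitionAgent
    from rai.communication import AudioInputDeviceConfig
    from rai_asr.models import LocalWhisper, OpenWakeWord, SileroVAD

    # Illustrative capture settings: 16 kHz int16 blocks from sounddevice index 0.
    config = AudioInputDeviceConfig(
        block_size=1280,
        consumer_sampling_rate=16000,
        dtype="int16",
        device_number=0,
    )

    agent = VoiceRecognitionAgent(
        microphone_device_id=0,
        microphone_config=config,
        ros2_name="rai_voice_agent",
        transcription_model=LocalWhisper("tiny", 16000),
        vad=SileroVAD(sampling_rate=16000, threshold=0.5),
    )
    # Gate recording on a wake word; VAD alone only refreshes the grace period.
    agent.add_detection_model(OpenWakeWord("wake_word.onnx"), pipeline="record")

    agent.run()       # starts the microphone stream and returns; work happens in callbacks
    time.sleep(30.0)  # speak to the assistant here
    agent.stop()      # terminates capture and joins outstanding transcription threads

With this wiring, recording starts only after the wake-word model fires, stops once the grace period (1.0 s by default) elapses without detected voice, and each transcription is published as a std_msgs/msg/String on /from_human.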