diff --git a/.github/workflows/poetry-test.yml b/.github/workflows/poetry-test.yml index 96dfc8e45..1f01701a9 100644 --- a/.github/workflows/poetry-test.yml +++ b/.github/workflows/poetry-test.yml @@ -38,6 +38,12 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 + - name: Create virtual audio device + run: | + apt-get update + DEBIAN_FRONTEND=noninteractive apt-get --yes install jackd + jackd -d dummy -r 44100 & + - name: Install python dependencies run: poetry install --with openset,nomad @@ -63,4 +69,4 @@ jobs: run: | source /opt/ros/${{ matrix.ros_distro }}/setup.bash source install/setup.bash - poetry run pytest + poetry run pytest -m "not billable" diff --git a/pyproject.toml b/pyproject.toml index 20e7eb32f..4457ea4cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,8 @@ profile = "black" [tool.pytest.ini_options] markers = [ "billable: marks test as billable (deselect with '-m \"not billable\"')", + "ci_only: marks test as ci only (deselect with '-m \"not ci_only\"')", ] -addopts = "-m 'not billable' --ignore=src" +addopts = "-m 'not billable and not ci_only' --ignore=src" log_cli = true log_cli_level = "DEBUG" diff --git a/src/rai/rai/agents/__init__.py b/src/rai/rai/agents/__init__.py index ef74fc891..dc101282b 100644 --- a/src/rai/rai/agents/__init__.py +++ b/src/rai/rai/agents/__init__.py @@ -11,3 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from rai.agents.conversational_agent import create_conversational_agent +from rai.agents.state_based import create_state_based_agent +from rai.agents.tool_runner import ToolRunner + +__all__ = [ + "ToolRunner", + "create_conversational_agent", + "create_state_based_agent", +] diff --git a/src/rai/rai/communication/__init__.py b/src/rai/rai/communication/__init__.py new file mode 100644 index 000000000..f22b87447 --- /dev/null +++ b/src/rai/rai/communication/__init__.py @@ -0,0 +1,23 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base_connector import BaseConnector, BaseMessage +from .sound_device_connector import SoundDeviceError, StreamingAudioInputDevice + +__all__ = [ + "BaseMessage", + "BaseConnector", + "StreamingAudioInputDevice", + "SoundDeviceError", +] diff --git a/src/rai/rai/communication/base_connector.py b/src/rai/rai/communication/base_connector.py new file mode 100644 index 000000000..fe01097fc --- /dev/null +++ b/src/rai/rai/communication/base_connector.py @@ -0,0 +1,49 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Callable
from uuid import uuid4


class BaseMessage(ABC):
    """Marker base class for the messages exchanged through a connector."""


class BaseConnector(ABC):
    """Abstract interface that every communication connector implements.

    Concrete connectors must provide all five abstract methods below; only
    ``_generate_handle`` has a shared implementation.
    """

    def _generate_handle(self) -> str:
        """Return a fresh random UUID4, rendered as a string, to identify an action."""
        handle = uuid4()
        return str(handle)

    @abstractmethod
    def send_message(self, msg: BaseMessage, target: str) -> None:
        """Send ``msg`` to ``target``; the transport is defined by the subclass."""

    @abstractmethod
    def receive_message(self, source: str) -> BaseMessage:
        """Return a message obtained from ``source``."""

    @abstractmethod
    def send_and_wait(self, target: str) -> BaseMessage:
        """Address ``target`` and return the message produced in response."""

    @abstractmethod
    def start_action(
        self, target: str, on_feedback: Callable, on_finish: Callable = lambda _: None
    ) -> str:
        """Start an action on ``target`` and return a handle identifying it.

        ``on_feedback`` and ``on_finish`` are callbacks supplied by the caller;
        when and with what arguments they are invoked is up to the subclass.
        """

    @abstractmethod
    def terminate_action(self, action_handle: str):
        """Terminate the action identified by ``action_handle``."""


# diff --git a/src/rai/rai/communication/sound_device_connector.py b/src/rai/rai/communication/sound_device_connector.py
# new file mode 100644
# index 000000000..f88619c2b
# --- /dev/null
# +++ b/src/rai/rai/communication/sound_device_connector.py
# @@ -0,0 +1,141 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, TypedDict

import numpy as np
import sounddevice as sd
from scipy.signal import resample
from sounddevice import CallbackFlags

from rai.communication.base_connector import BaseConnector, BaseMessage


class SoundDeviceError(Exception):
    """Raised for unsupported operations and for missing or invalid device configuration."""

    def __init__(self, msg: str):
        super().__init__(msg)


class AudioInputDeviceConfig(TypedDict):
    """Settings from which a ``ConfiguredAudioInputDevice`` is built."""

    # Block size; scaled by device rate / consumer_sampling_rate to get the
    # number of samples requested from the device per callback.
    block_size: int
    # Sampling rate that ``block_size`` is expressed in.
    consumer_sampling_rate: int
    # Rate the captured audio is resampled to before it reaches on_feedback.
    target_sampling_rate: int
    # Sample dtype forwarded to ``sd.InputStream``.
    dtype: str
    # Device index; filled in from ``target`` by configure_device when absent.
    device_number: Optional[int]


class ConfiguredAudioInputDevice:
    """
    A class to store the configuration of an audio device

    Attributes
    ----------
    sample_rate (int): Device sample rate
    consumer_sampling_rate (int): The sampling rate of the consumer
    window_size_samples (int): The size of the window in samples
    target_sampling_rate (int): The target sampling rate
    dtype (str): The data type of the audio samples
    """

    def __init__(self, config: AudioInputDeviceConfig):
        # The device's default sample rate is queried from sounddevice.
        device_info = sd.query_devices(device=config["device_number"], kind="input")
        self.sample_rate = device_info["default_samplerate"]  # type: ignore
        self.consumer_sampling_rate = config["consumer_sampling_rate"]
        # Convert the consumer-rate block size into a device-rate block size.
        self.window_size_samples = int(
            config["block_size"] * self.sample_rate / config["consumer_sampling_rate"]
        )
        self.target_sampling_rate = int(config["target_sampling_rate"])
        self.dtype = config["dtype"]


def _resample_to_rate(
    samples: np.ndarray, source_rate: float, target_rate: float
) -> np.ndarray:
    """Return ``samples`` converted from ``source_rate`` to ``target_rate``."""
    if source_rate == target_rate:
        return samples
    # The block lasts len / source_rate seconds, so at target_rate it must
    # hold len * target_rate / source_rate samples. (Dividing by the target
    # rate instead would yield len again and make the resample a no-op.)
    resampled_length = int(round(len(samples) * target_rate / source_rate))
    return resample(samples, resampled_length)  # type: ignore


class StreamingAudioInputDevice(BaseConnector):
    """Connector that streams blocks from sounddevice input devices to a callback.

    Only the action API (``start_action`` / ``terminate_action``) is supported;
    the message methods always raise ``SoundDeviceError``.
    """

    def __init__(self):
        # Open streams keyed by the handle returned from start_action.
        self.streams: dict[str, sd.InputStream] = {}
        # NOTE: this changes sounddevice's module-wide default latency.
        sd.default.latency = ("low", "low")
        # Attribute name (including its spelling) is part of the public API.
        self.configred_devices: dict[str, ConfiguredAudioInputDevice] = {}

    def configure_device(self, target: str, config: AudioInputDeviceConfig):
        """Register the configuration for the input device numbered ``target``.

        Parameters
        ----------
        target : str
            Device index written as a decimal string.
        config : AudioInputDeviceConfig
            Device settings. ``device_number`` may be omitted or ``None``;
            otherwise it must equal ``int(target)``. The dict is not modified.

        Raises
        ------
        SoundDeviceError
            If ``target`` is not a device number or disagrees with
            ``config["device_number"]``.
        """
        # isdecimal() rather than isdigit(): isdigit() accepts characters such
        # as superscript digits that int() rejects with ValueError.
        if not target.isdecimal():
            raise SoundDeviceError("target must be a device number!")
        device_number = int(target)
        requested = config.get("device_number")
        if requested is not None and requested != device_number:
            raise SoundDeviceError("device_number in config must be the same as target")
        # Work on a copy so the caller's dict is left untouched.
        resolved = dict(config)
        resolved["device_number"] = device_number
        self.configred_devices[target] = ConfiguredAudioInputDevice(resolved)  # type: ignore

    def send_message(self, msg: BaseMessage, target: str) -> None:
        """Unsupported; always raises ``SoundDeviceError``."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not support sending messages"
        )

    def receive_message(self, source: str) -> BaseMessage:
        """Unsupported; always raises ``SoundDeviceError``."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not support receiving messages"
        )

    def send_and_wait(self, target: str) -> BaseMessage:
        """Unsupported; always raises ``SoundDeviceError``."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not support sending messages"
        )

    def start_action(
        self,
        target: str,
        on_feedback: Callable[[np.ndarray, dict[str, Any]], None],
        # sounddevice calls finished_callback without arguments, so the
        # default accepts any arity (a one-argument lambda would fail there).
        on_finish: Callable = lambda *_: None,
    ) -> str:
        """Open and start an input stream on a previously configured device.

        Parameters
        ----------
        target : str
            Device index (decimal string) passed earlier to ``configure_device``.
        on_feedback : Callable[[np.ndarray, dict[str, Any]], None]
            Called for every captured block with the flattened samples
            (resampled to the configured target rate when it differs from the
            device rate) and a dict of the stream status flags.
        on_finish : Callable
            Passed to ``sd.InputStream`` as ``finished_callback``.

        Returns
        -------
        str
            Handle to pass to ``terminate_action``.

        Raises
        ------
        SoundDeviceError
            If the device has not been configured, or the stream constructor
            raises ``AttributeError``.
        """
        target_device = self.configred_devices.get(target)
        if target_device is None:
            raise SoundDeviceError(f"Device {target} has not been configured")

        source_rate = target_device.sample_rate
        target_rate = target_device.target_sampling_rate

        def callback(indata: np.ndarray, frames: int, _, status: CallbackFlags):
            samples = _resample_to_rate(indata.flatten(), source_rate, target_rate)
            flag_dict = {
                "input_overflow": status.input_overflow,
                "input_underflow": status.input_underflow,
                "output_overflow": status.output_overflow,
                "output_underflow": status.output_underflow,
                "priming_output": status.priming_output,
            }
            on_feedback(samples, flag_dict)

        handle = self._generate_handle()
        try:
            stream = sd.InputStream(
                samplerate=target_device.sample_rate,
                channels=1,
                device=int(target),
                dtype=target_device.dtype,
                blocksize=target_device.window_size_samples,
                callback=callback,
                finished_callback=on_finish,
            )
        except AttributeError as err:
            # Keep the original exception as the cause for easier debugging.
            raise SoundDeviceError(
                f"Device {target} has not been correctly configured"
            ) from err
        stream.start()
        self.streams[handle] = stream
        return handle

    def terminate_action(self, action_handle: str):
        """Stop and close the stream registered under ``action_handle``.

        The handle is removed from ``self.streams`` so the stream is released
        rather than kept alive after it has been stopped.

        Raises
        ------
        KeyError
            If ``action_handle`` does not refer to an open stream.
        """
        stream = self.streams.pop(action_handle)
        try:
            stream.stop()
        finally:
            # Release the underlying audio resources even if stop() fails.
            stream.close()
# diff --git a/tests/communication/test_sound_device_connector.py b/tests/communication/test_sound_device_connector.py
# new file mode 100644
# index 000000000..cb6d72101
# --- /dev/null
# +++ b/tests/communication/test_sound_device_connector.py
# @@ -0,0 +1,114 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import pytest
import sounddevice as sd

from rai.communication import SoundDeviceError, StreamingAudioInputDevice


def _default_input_device_id() -> str:
    """Return the index of an available input device as a string.

    Shared by the tests that need a real device so the lookup lives in one
    place. Raises ``AssertionError`` when the query result has an unexpected
    shape.
    """
    device = sd.query_devices(kind="input")
    if isinstance(device, dict):
        return str(device["index"])
    if isinstance(device, list):
        return str(device[0]["index"])  # type: ignore
    raise AssertionError("No input device found")


@pytest.fixture
def setup_mock_input_stream():
    """Patch ``sounddevice.InputStream`` so no real audio stream is opened."""
    with mock.patch("sounddevice.InputStream") as mock_input_stream:
        yield mock_input_stream


@pytest.fixture
def device_config():
    """Return a device configuration without an explicit ``device_number``."""
    return {
        "block_size": 1024,
        "consumer_sampling_rate": 44100,
        "target_sampling_rate": 16000,
        "dtype": "float32",
    }


@pytest.mark.ci_only
def test_configure(
    setup_mock_input_stream,
    device_config,
):
    """configure_device stores the settings under the device id."""
    audio_input_device = StreamingAudioInputDevice()
    device_id = _default_input_device_id()

    audio_input_device.configure_device(device_id, device_config)

    configured = audio_input_device.configred_devices[device_id]
    assert configured.consumer_sampling_rate == 44100
    # Holds only when the device's default rate equals the consumer rate
    # (the CI workflow starts its dummy jackd device at 44100 Hz).
    assert configured.window_size_samples == 1024
    assert configured.target_sampling_rate == 16000
    assert configured.dtype == "float32"


@pytest.mark.ci_only
def test_start_action_failed_init(
    setup_mock_input_stream,
):
    """start_action rejects a device that was never configured."""
    audio_input_device = StreamingAudioInputDevice()

    feedback_callback = mock.MagicMock()
    finish_callback = mock.MagicMock()

    recording_device = 0
    with pytest.raises(SoundDeviceError, match="Device 0 has not been configured"):
        _ = audio_input_device.start_action(
            str(recording_device), feedback_callback, finish_callback
        )
    # The stream constructor must not be reached for an unconfigured device.
    assert setup_mock_input_stream.call_count == 0


@pytest.mark.ci_only
def test_start_action(
    setup_mock_input_stream,
    device_config,
):
    """start_action opens, starts and registers exactly one input stream."""
    mock_input_stream = setup_mock_input_stream
    audio_input_device = StreamingAudioInputDevice()

    feedback_callback = mock.MagicMock()
    finish_callback = mock.MagicMock()

    device_id = _default_input_device_id()
    audio_input_device.configure_device(device_id, device_config)

    stream_handle = audio_input_device.start_action(
        device_id, feedback_callback, finish_callback
    )

    assert mock_input_stream.call_count == 1
    init_args = mock_input_stream.call_args.kwargs
    assert init_args["device"] == int(device_id)
    assert init_args["finished_callback"] == finish_callback

    # The created stream must have been started and stored under the handle.
    mock_input_stream.return_value.start.assert_called_once()
    assert audio_input_device.streams.get(stream_handle) is not None