Skip to content

Commit b2dfc69

Browse files
committed
refactor: using ready funtion in service validator
1 parent cfdf1ee commit b2dfc69

File tree

1 file changed

+3
-44
lines changed

1 file changed

+3
-44
lines changed

src/rai_bench/rai_bench/tool_calling_agent/mocked_tools.py

Lines changed: 3 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import copy
16-
import importlib
1716
import uuid
1817
from threading import Lock
1918
from typing import Any, Dict, List, Optional, Tuple, Type
@@ -23,6 +22,7 @@
2322
import numpy.typing as npt
2423
from langchain_core.tools import BaseTool
2524
from pydantic import BaseModel, ValidationError, computed_field
25+
from rai.communication.ros2.api.conversion import import_message_from_str
2626
from rai.communication.ros2.connectors import ROS2Connector
2727
from rai.communication.ros2.messages import ROS2Message
2828
from rai.messages import MultimodalArtifact, preprocess_image
@@ -270,7 +270,7 @@ def _run(self, msg_type: str) -> str:
270270

271271
class ServiceValidator:
272272
"""
273-
Validator that is responsible for checking if gicen service type exists
273+
Validator that is responsible for checking if given service type exists
274274
and if it is used correctly.
275275
Validator uses ROS 2 native types when available,
276276
falls back to Pydantic models of custom interfaces when not.
@@ -280,47 +280,6 @@ def __init__(self, custom_models: Dict[str, Type[BaseModel]]):
280280
self.custom_models = custom_models
281281
self.ros2_services_cache: Dict[str, Any] = {}
282282

283-
def get_ros2_service_class(self, service_type: str):
284-
"""
285-
Dynamically import ROS 2 service class.
286-
287-
Parameters
288-
----------
289-
service_type : str
290-
The ROS 2 service type in format "package_name/srv/ServiceName".
291-
292-
Returns
293-
-------
294-
class
295-
The dynamically imported ROS 2 service class.
296-
297-
Raises
298-
------
299-
ValueError
300-
If the service_type format is invalid (not in format
301-
"package_name/srv/ServiceName").
302-
303-
Notes
304-
-----
305-
Results are cached in ros2_services_cache to avoid repeated imports
306-
of the same service type.
307-
"""
308-
if service_type in self.ros2_services_cache:
309-
return self.ros2_services_cache[service_type]
310-
311-
# Parse service type: "package_name/srv/ServiceName"
312-
parts = service_type.split("/")
313-
if len(parts) != 3 or parts[1] != "srv":
314-
raise ValueError(f"Service type: {service_type} is invalid")
315-
316-
package_name, _, service_name = parts
317-
318-
module = importlib.import_module(f"{package_name}.srv")
319-
service_class = getattr(module, service_name)
320-
321-
self.ros2_services_cache[service_type] = service_class
322-
return service_class
323-
324283
def validate_with_ros2(self, service_type: str, args: Dict[str, Any]):
325284
"""Validate using installed ROS2 packages services definition
326285
@@ -335,7 +294,7 @@ def validate_with_ros2(self, service_type: str, args: Dict[str, Any]):
335294
TypeError
336295
When service type does not exist in ROS2 installed packages
337296
"""
338-
service_class = self.get_ros2_service_class(service_type)
297+
service_class = import_message_from_str(service_type)
339298
if not service_class:
340299
raise TypeError(f"Service type: {service_type} does not exist.")
341300

0 commit comments

Comments
 (0)