diff --git a/docs/docs.json b/docs/docs.json index cb931269..17d49594 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -125,6 +125,7 @@ { "group": "Patterns", "pages": [ + "patterns/tool-transformation", "patterns/decorating-methods", "patterns/http-requests", "patterns/testing", diff --git a/docs/patterns/tool-transformation.mdx b/docs/patterns/tool-transformation.mdx new file mode 100644 index 00000000..b3c85dc6 --- /dev/null +++ b/docs/patterns/tool-transformation.mdx @@ -0,0 +1,430 @@ +--- +title: Tool Transformation +sidebarTitle: Tool Transformation +description: Create enhanced tool variants with modified schemas, argument mappings, and custom behavior. +icon: wand-magic-sparkles +--- + +import { VersionBadge } from '/snippets/version-badge.mdx' + + + +Tool transformation allows you to create new, enhanced tools from existing ones. This powerful feature enables you to adapt tools for different contexts, simplify complex interfaces, or add custom logic without duplicating code. + +## Why Transform Tools? + +Often, an existing tool is *almost* perfect for your use case, but it might have: +- A confusing description (or no description at all). +- Argument names or descriptions that are not intuitive for an LLM (e.g., `q` instead of `query`). +- Unnecessary parameters that you want to hide from the LLM. +- A need for input validation before the original tool is called. +- A need to modify or format the tool's output. + +Instead of rewriting the tool from scratch, you can **transform** it to fit your needs. + +## Basic Transformation + +The primary way to create a transformed tool is with the `Tool.from_tool()` class method. At its simplest, you can use it to change a tool's top-level metadata like its `name`, `description`, or `tags`. + +In the following simple example, we take a generic `search` tool and adjust its name and description to help an LLM client better understand its purpose. + +```python {13-21} +from fastmcp import FastMCP +from fastmcp.tools import Tool + +mcp = FastMCP() + +# The original, generic tool +@mcp.tool +def search(query: str, category: str = "all") -> list[dict]: + """Searches for items in the database.""" + return database.search(query, category) + +# Create a more domain-specific version by changing its metadata +product_search_tool = Tool.from_tool( + search, + name="find_products", + description=""" + Search for products in the e-commerce catalog. + Use this when customers ask about finding specific items, + checking availability, or browsing product categories. + """, +) + +mcp.add_tool(product_search_tool) +``` +Now, clients see a tool named `find_products` with a clear, domain-specific purpose and relevant tags, even though it still uses the original generic `search` function's logic. + +### Parameters + +The `Tool.from_tool()` class method is the primary way to create a transformed tool. It takes the following parameters: + +- `tool`: The tool to transform. This is the only required argument. +- `name`: An optional name for the new tool. +- `description`: An optional description for the new tool. +- `transform_args`: A dictionary of `ArgTransform` objects, one for each argument you want to modify. +- `transform_fn`: An optional function that will be called instead of the parent tool's logic. +- `tags`: An optional set of tags for the new tool. +- `annotations`: An optional set of `ToolAnnotations` for the new tool. +- `serializer`: An optional function that will be called to serialize the result of the new tool. + +The result is a new `TransformedTool` object that wraps the parent tool and applies the transformations you specify. You can add this tool to your MCP server using its `add_tool()` method. + + + +## Modifying Arguments + +To modify a tool's parameters, provide a dictionary of `ArgTransform` objects to the `transform_args` parameter of `Tool.from_tool()`. Each key is the name of the *original* argument you want to modify. + + +You only need to provide a `transform_args` entry for arguments you want to modify. All other arguments will be passed through unchanged. + + +### The ArgTransform Class + +To modify an argument, you need to create an `ArgTransform` object. This object has the following parameters: + +- `name`: The new name for the argument. +- `description`: The new description for the argument. +- `default`: The new default value for the argument. +- `default_factory`: A function that will be called to generate a default value for the argument. This is useful for arguments that need to be generated for each tool call, such as timestamps or unique IDs. +- `hide`: Whether to hide the argument from the LLM. +- `required`: Whether the argument is required, usually used to make an optional argument be required instead. +- `type`: The new type for the argument. + + +Certain combinations of parameters are not allowed. For example, you can only use `default_factory` with `hide=True`, because dynamic defaults cannot be represented in a JSON schema for the client. You can only set required=True for arguments that do not declare a default value. + + + +### Descriptions + +By far the most common reason to transform a tool, after its own description, is to improve its argument descriptions. A good description is crucial for helping an LLM understand how to use a parameter correctly. This is especially important when wrapping tools from external APIs, whose argument descriptions may be missing or written for developers, not LLMs. + +In this example, we add a helpful description to the `user_id` argument: + +```python {16-19} +from fastmcp import FastMCP +from fastmcp.tools import Tool +from fastmcp.tools.tool_transform import ArgTransform + +mcp = FastMCP() + +@mcp.tool +def find_user(user_id: str): + """Finds a user by their ID.""" + ... + +new_tool = Tool.from_tool( + find_user, + transform_args={ + "user_id": ArgTransform( + description=( + "The unique identifier for the user, " + "usually in the format 'usr-xxxxxxxx'." + ) + ) + } +) +``` + +### Names + +At times, you may want to rename an argument to make it more intuitive for an LLM. + +For example, in the following example, we take a generic `q` argument and expand it to `search_query`: + +```python {15} +from fastmcp import FastMCP +from fastmcp.tools import Tool +from fastmcp.tools.tool_transform import ArgTransform + +mcp = FastMCP() + +@mcp.tool +def search(q: str): + """Searches for items in the database.""" + return database.search(q) + +new_tool = Tool.from_tool( + search, + transform_args={ + "q": ArgTransform(name="search_query") + } +) +``` + +### Default Values + +You can update the default value for any argument using the `default` parameter. Here, we change the default value of the `y` argument to 10: + +```python{15} +from fastmcp import FastMCP +from fastmcp.tools import Tool +from fastmcp.tools.tool_transform import ArgTransform + +mcp = FastMCP() + +@mcp.tool +def add(x: int, y: int) -> int: + """Adds two numbers.""" + return x + y + +new_tool = Tool.from_tool( + add, + transform_args={ + "y": ArgTransform(default=10) + } +) +``` + +Default values are especially useful in combination with hidden arguments. + +### Hiding Arguments + +Sometimes a tool requires arguments that shouldn't be exposed to the LLM, such as API keys, configuration flags, or internal IDs. You can hide these parameters using `hide=True`. Note that you can only hide arguments that have a default value (or for which you provide a new default), because the LLM can't provide a value at call time. + + +To pass a constant value to the parent tool, combine `hide=True` with `default=`. + + +```python {19-20} +import os +from fastmcp import FastMCP +from fastmcp.tools import Tool +from fastmcp.tools.tool_transform import ArgTransform + +mcp = FastMCP() + +@mcp.tool +def send_email(to: str, subject: str, body: str, api_key: str): + """Sends an email.""" + ... + +# Create a simplified version that hides the API key +new_tool = Tool.from_tool( + send_email, + name="send_notification", + transform_args={ + "api_key": ArgTransform( + hide=True, + default=os.environ.get("EMAIL_API_KEY"), + ) + } +) +``` +The LLM now only sees the `to`, `subject`, and `body` parameters. The `api_key` is supplied automatically from an environment variable. + +For values that must be generated for each tool call (like timestamps or unique IDs), use `default_factory`, which is called with no arguments every time the tool is called. For example, + +```python {3-4} +transform_args = { + 'timestamp': ArgTransform( + hide=True, + default_factory=lambda: datetime.now(), + ) +} +``` + + +`default_factory` can only be used with `hide=True`. This is because visible parameters need static defaults that can be represented in a JSON schema for the client. + + +### Required Values + +In rare cases where you want to make an optional argument required, you can set `required=True`. This has no effect if the argument was already required. + +```python {3} +transform_args = { + 'user_id': ArgTransform( + required=True, + ) +} +``` + +## Modifying Tool Behavior + + +With great power comes great responsibility. Modifying tool behavior is a very advanced feature. + + +In addition to changing a tool's schema, advanced users can also modify its behavior. This is useful for adding validation logic, or for post-processing the tool's output. + +The `from_tool()` method takes a `transform_fn` parameter, which is an async function that replaces the parent tool's logic and gives you complete control over the tool's execution. + +### The Transform Function + +The `transform_fn` is an async function that **completely replaces** the parent tool's logic. + +Critically, the transform function's arguments are used to determine the new tool's final schema. Any arguments that are not already present in the parent tool schema OR the `transform_args` will be added to the new tool's schema. Note that when `transform_args` and your function have the same argument name, the `transform_args` metadata will take precedence, if provided. + +```python +async def my_custom_logic(user_input: str, max_length: int = 100) -> str: + # Your custom logic here - this completely replaces the parent tool + return f"Custom result for: {user_input[:max_length]}" + +Tool.from_tool(transform_fn=my_custom_logic) +``` + + +The name / docstring of the `transform_fn` are ignored. Only its arguments are used to determine the final schema. + + +### Calling the Parent Tool + +Most of the time, you don't want to completely replace the parent tool's behavior. Instead, you want to add validation, modify inputs, or post-process outputs while still leveraging the parent tool's core functionality. For this, FastMCP provides the special `forward()` and `forward_raw()` functions. + +Both `forward()` and `forward_raw()` are async functions that let you call the parent tool from within your `transform_fn`: + +- **`forward()`** (recommended): Automatically handles argument mapping based on your `ArgTransform` configurations. Call it with the transformed argument names. +- **`forward_raw()`**: Bypasses all transformation and calls the parent tool directly with its original argument names. This is rarely needed unless you're doing complex argument manipulation, perhaps without `arg_transforms`. + +The most common transformation pattern is to validate (potentially renamed) arguments before calling the parent tool. Here's an example that validates that `x` and `y` are positive before calling the parent tool: + + + +In the simplest case, your parent tool and your transform function have the same arguments. You can call `forward()` with the same argument names as the parent tool: + +```python {15} +from fastmcp import FastMCP +from fastmcp.tools import Tool +from fastmcp.tools.tool_transform import forward + +mcp = FastMCP() + +@mcp.tool +def add(x: int, y: int) -> int: + """Adds two numbers.""" + return x + y + +async def ensure_positive(x: int, y: int) -> int: + if x <= 0 or y <= 0: + raise ValueError("x and y must be positive") + return await forward(x=x, y=y) + +new_tool = Tool.from_tool( + add, + transform_fn=ensure_positive, +) + +mcp.add_tool(new_tool) +``` + + + +When your transformed tool has different argument names than the parent tool, you can call `forward()` with the renamed arguments and it will automatically map the arguments to the parent tool's arguments: + +```python {15, 20-23} +from fastmcp import FastMCP +from fastmcp.tools import Tool +from fastmcp.tools.tool_transform import forward + +mcp = FastMCP() + +@mcp.tool +def add(x: int, y: int) -> int: + """Adds two numbers.""" + return x + y + +async def ensure_positive(a: int, b: int) -> int: + if a <= 0 or b <= 0: + raise ValueError("a and b must be positive") + return await forward(a=a, b=b) + +new_tool = Tool.from_tool( + add, + transform_fn=ensure_positive, + transform_args={ + "x": ArgTransform(name="a"), + "y": ArgTransform(name="b"), + } +) + +mcp.add_tool(new_tool) +``` + + +Finally, you can use `forward_raw()` to bypass all argument mapping and call the parent tool directly with its original argument names. + +```python {15, 20-23} +from fastmcp import FastMCP +from fastmcp.tools import Tool +from fastmcp.tools.tool_transform import forward + +mcp = FastMCP() + +@mcp.tool +def add(x: int, y: int) -> int: + """Adds two numbers.""" + return x + y + +async def ensure_positive(a: int, b: int) -> int: + if a <= 0 or b <= 0: + raise ValueError("a and b must be positive") + return await forward_raw(x=a, y=b) + +new_tool = Tool.from_tool( + add, + transform_fn=ensure_positive, + transform_args={ + "x": ArgTransform(name="a"), + "y": ArgTransform(name="b"), + } +) + +mcp.add_tool(new_tool) +``` + + + +### Passing Arguments with **kwargs + +If your `transform_fn` includes `**kwargs` in its signature, it will receive **all arguments from the parent tool after `ArgTransform` configurations have been applied**. This is powerful for creating flexible validation functions that don't require you to add every argument to the function signature. + +In the following example, we wrap a parent tool that accepts two arguments `x` and `y`. These are renamed to `a` and `b` in the transformed tool, and the transform only validates `a`, passing the other argument through as `**kwargs`. + +```python {12, 15} +from fastmcp import FastMCP +from fastmcp.tools import Tool +from fastmcp.tools.tool_transform import forward + +mcp = FastMCP() + +@mcp.tool +def add(x: int, y: int) -> int: + """Adds two numbers.""" + return x + y + +async def ensure_a_positive(a: int, **kwargs) -> int: + if a <= 0: + raise ValueError("a must be positive") + return await forward(a=a, **kwargs) + +new_tool = Tool.from_tool( + add, + transform_fn=ensure_a_positive, + transform_args={ + "x": ArgTransform(name="a"), + "y": ArgTransform(name="b"), + } +) + +mcp.add_tool(new_tool) +``` + + +In the above example, `**kwargs` receives the renamed argument `b`, not the original argument `y`. It is therefore recommended to use with `forward()`, not `forward_raw()`. + + +## Common Patterns + +Tool transformation is a flexible feature that supports many powerful patterns. Here are a few common use cases to give you ideas. + +### Adapting Remote or Generated Tools +This is one of the most common reasons to use tool transformation. Tools from remote servers (via a [proxy](/servers/proxy)) or generated from an [OpenAPI spec](/servers/openapi) are often too generic for direct use by an LLM. You can use transformation to create a simpler, more intuitive version for your specific needs. + +### Chaining Transformations +You can chain transformations by using an already transformed tool as the parent for a new transformation. This lets you build up complex behaviors in layers, for example, first renaming arguments, and then adding validation logic to the renamed tool. + +### Context-Aware Tool Factories +You can write functions that act as "factories," generating specialized versions of a tool for different contexts. For example, you could create a `get_my_data` tool that is specific to the currently logged-in user by hiding the `user_id` parameter and providing it automatically. diff --git a/src/fastmcp/server/auth/providers/bearer_env.py b/src/fastmcp/server/auth/providers/bearer_env.py index 96cf15cf..308f5fcd 100644 --- a/src/fastmcp/server/auth/providers/bearer_env.py +++ b/src/fastmcp/server/auth/providers/bearer_env.py @@ -1,13 +1,10 @@ +from types import EllipsisType + from pydantic_settings import BaseSettings, SettingsConfigDict from fastmcp.server.auth.providers.bearer import BearerAuthProvider -# Sentinel object to indicate that a setting is not set -class _NotSet: - pass - - class EnvBearerAuthProviderSettings(BaseSettings): """Settings for the BearerAuthProvider.""" @@ -33,11 +30,11 @@ class EnvBearerAuthProvider(BearerAuthProvider): def __init__( self, - public_key: str | None | type[_NotSet] = _NotSet, - jwks_uri: str | None | type[_NotSet] = _NotSet, - issuer: str | None | type[_NotSet] = _NotSet, - audience: str | None | type[_NotSet] = _NotSet, - required_scopes: list[str] | None | type[_NotSet] = _NotSet, + public_key: str | None | EllipsisType = ..., + jwks_uri: str | None | EllipsisType = ..., + issuer: str | None | EllipsisType = ..., + audience: str | None | EllipsisType = ..., + required_scopes: list[str] | None | EllipsisType = ..., ): """ Initialize the provider. @@ -57,6 +54,6 @@ def __init__( "required_scopes": required_scopes, } settings = EnvBearerAuthProviderSettings( - **{k: v for k, v in kwargs.items() if v is not _NotSet} + **{k: v for k, v in kwargs.items() if v is not ...} ) super().__init__(**settings.model_dump()) diff --git a/src/fastmcp/server/openapi.py b/src/fastmcp/server/openapi.py index e156e9bd..58a0d503 100644 --- a/src/fastmcp/server/openapi.py +++ b/src/fastmcp/server/openapi.py @@ -226,7 +226,6 @@ def __init__( tags: set[str] = set(), timeout: float | None = None, annotations: ToolAnnotations | None = None, - exclude_args: list[str] | None = None, serializer: Callable[[Any], str] | None = None, ): super().__init__( @@ -235,7 +234,6 @@ def __init__( parameters=parameters, tags=tags, annotations=annotations, - exclude_args=exclude_args, serializer=serializer, ) self._client = client diff --git a/src/fastmcp/server/server.py b/src/fastmcp/server/server.py index d033aa1b..5cda3cbb 100644 --- a/src/fastmcp/server/server.py +++ b/src/fastmcp/server/server.py @@ -268,6 +268,12 @@ async def get_tools(self) -> dict[str, Tool]: self._cache.set("tools", tools) return tools + async def get_tool(self, key: str) -> Tool: + tools = await self.get_tools() + if key not in tools: + raise NotFoundError(f"Unknown tool: {key}") + return tools[key] + async def get_resources(self) -> dict[str, Resource]: """Get all registered resources, indexed by registered key.""" if (resources := self._cache.get("resources")) is self._cache.NOT_FOUND: diff --git a/src/fastmcp/tools/__init__.py b/src/fastmcp/tools/__init__.py index e659a778..8fa72391 100644 --- a/src/fastmcp/tools/__init__.py +++ b/src/fastmcp/tools/__init__.py @@ -1,4 +1,5 @@ from .tool import Tool, FunctionTool from .tool_manager import ToolManager +from .tool_transform import forward, forward_raw -__all__ = ["Tool", "ToolManager", "FunctionTool"] +__all__ = ["Tool", "ToolManager", "FunctionTool", "forward", "forward_raw"] diff --git a/src/fastmcp/tools/tool.py b/src/fastmcp/tools/tool.py index aaa7ff35..1457bb3e 100644 --- a/src/fastmcp/tools/tool.py +++ b/src/fastmcp/tools/tool.py @@ -4,6 +4,7 @@ import json from abc import ABC, abstractmethod from collections.abc import Callable +from dataclasses import dataclass from typing import TYPE_CHECKING, Annotated, Any import pydantic_core @@ -24,7 +25,7 @@ ) if TYPE_CHECKING: - pass + from fastmcp.tools.tool_transform import ArgTransform, TransformedTool logger = get_logger(__name__) @@ -47,10 +48,6 @@ class Tool(FastMCPBaseModel, ABC): annotations: ToolAnnotations | None = Field( default=None, description="Additional annotations about the tool" ) - exclude_args: list[str] | None = Field( - default=None, - description="Arguments to exclude from the tool schema, such as State, Memory, or Credential", - ) serializer: Callable[[Any], str] | None = Field( default=None, description="Optional custom serializer for tool results" ) @@ -98,6 +95,31 @@ async def run( """Run the tool with arguments.""" raise NotImplementedError("Subclasses must implement run()") + @classmethod + def from_tool( + cls, + tool: Tool, + transform_fn: Callable[..., Any] | None = None, + name: str | None = None, + transform_args: dict[str, ArgTransform] | None = None, + description: str | None = None, + tags: set[str] | None = None, + annotations: ToolAnnotations | None = None, + serializer: Callable[[Any], str] | None = None, + ) -> TransformedTool: + from fastmcp.tools.tool_transform import TransformedTool + + return TransformedTool.from_tool( + tool=tool, + transform_fn=transform_fn, + name=name, + transform_args=transform_args, + description=description, + tags=tags, + annotations=annotations, + serializer=serializer, + ) + class FunctionTool(Tool): fn: Callable[..., Any] @@ -114,62 +136,19 @@ def from_function( serializer: Callable[[Any], str] | None = None, ) -> FunctionTool: """Create a Tool from a function.""" - from fastmcp.server.context import Context - - # Reject functions with *args or **kwargs - sig = inspect.signature(fn) - for param in sig.parameters.values(): - if param.kind == inspect.Parameter.VAR_POSITIONAL: - raise ValueError("Functions with *args are not supported as tools") - if param.kind == inspect.Parameter.VAR_KEYWORD: - raise ValueError("Functions with **kwargs are not supported as tools") - - if exclude_args: - for arg_name in exclude_args: - if arg_name not in sig.parameters: - raise ValueError( - f"Parameter '{arg_name}' in exclude_args does not exist in function." - ) - param = sig.parameters[arg_name] - if param.default == inspect.Parameter.empty: - raise ValueError( - f"Parameter '{arg_name}' in exclude_args must have a default value." - ) - func_name = name or getattr(fn, "__name__", None) or fn.__class__.__name__ + parsed_fn = ParsedFunction.from_function(fn, exclude_args=exclude_args) - if func_name == "": + if name is None and parsed_fn.name == "": raise ValueError("You must provide a name for lambda functions") - func_doc = description or fn.__doc__ - - # if the fn is a callable class, we need to get the __call__ method from here out - if not inspect.isroutine(fn): - fn = fn.__call__ - # if the fn is a staticmethod, we need to work with the underlying function - if isinstance(fn, staticmethod): - fn = fn.__func__ - - type_adapter = get_cached_typeadapter(fn) - schema = type_adapter.json_schema() - - prune_params: list[str] = [] - context_kwarg = find_kwarg_by_type(fn, kwarg_type=Context) - if context_kwarg: - prune_params.append(context_kwarg) - if exclude_args: - prune_params.extend(exclude_args) - - schema = compress_schema(schema, prune_params=prune_params) - return cls( - fn=fn, - name=func_name, - description=func_doc, - parameters=schema, + fn=parsed_fn.fn, + name=name or parsed_fn.name, + description=description or parsed_fn.description, + parameters=parsed_fn.parameters, tags=tags or set(), annotations=annotations, - exclude_args=exclude_args, serializer=serializer, ) @@ -222,6 +201,76 @@ async def run( return _convert_to_content(result, serializer=self.serializer) +@dataclass +class ParsedFunction: + fn: Callable[..., Any] + name: str + description: str | None + parameters: dict[str, Any] + + @classmethod + def from_function( + cls, + fn: Callable[..., Any], + exclude_args: list[str] | None = None, + validate: bool = True, + ) -> ParsedFunction: + from fastmcp.server.context import Context + + if validate: + sig = inspect.signature(fn) + # Reject functions with *args or **kwargs + for param in sig.parameters.values(): + if param.kind == inspect.Parameter.VAR_POSITIONAL: + raise ValueError("Functions with *args are not supported as tools") + if param.kind == inspect.Parameter.VAR_KEYWORD: + raise ValueError( + "Functions with **kwargs are not supported as tools" + ) + + # Reject exclude_args that don't exist in the function or don't have a default value + if exclude_args: + for arg_name in exclude_args: + if arg_name not in sig.parameters: + raise ValueError( + f"Parameter '{arg_name}' in exclude_args does not exist in function." + ) + param = sig.parameters[arg_name] + if param.default == inspect.Parameter.empty: + raise ValueError( + f"Parameter '{arg_name}' in exclude_args must have a default value." + ) + + # collect name and doc before we potentially modify the function + fn_name = getattr(fn, "__name__", None) or fn.__class__.__name__ + fn_doc = fn.__doc__ + + # if the fn is a callable class, we need to get the __call__ method from here out + if not inspect.isroutine(fn): + fn = fn.__call__ + # if the fn is a staticmethod, we need to work with the underlying function + if isinstance(fn, staticmethod): + fn = fn.__func__ + + type_adapter = get_cached_typeadapter(fn) + schema = type_adapter.json_schema() + + prune_params: list[str] = [] + context_kwarg = find_kwarg_by_type(fn, kwarg_type=Context) + if context_kwarg: + prune_params.append(context_kwarg) + if exclude_args: + prune_params.extend(exclude_args) + + schema = compress_schema(schema, prune_params=prune_params) + return cls( + fn=fn, + name=fn_name, + description=fn_doc, + parameters=schema, + ) + + def _convert_to_content( result: Any, serializer: Callable[[Any], str] | None = None, diff --git a/src/fastmcp/tools/tool_transform.py b/src/fastmcp/tools/tool_transform.py new file mode 100644 index 00000000..b193d490 --- /dev/null +++ b/src/fastmcp/tools/tool_transform.py @@ -0,0 +1,663 @@ +from __future__ import annotations + +import inspect +from collections.abc import Callable +from contextvars import ContextVar +from dataclasses import dataclass +from types import EllipsisType +from typing import Any, Literal + +from mcp.types import EmbeddedResource, ImageContent, TextContent, ToolAnnotations +from pydantic import ConfigDict + +from fastmcp.tools.tool import ParsedFunction, Tool +from fastmcp.utilities.logging import get_logger +from fastmcp.utilities.types import get_cached_typeadapter + +logger = get_logger(__name__) + +NotSet = ... + + +# Context variable to store current transformed tool +_current_tool: ContextVar[TransformedTool | None] = ContextVar( + "_current_tool", default=None +) + + +async def forward(**kwargs) -> Any: + """Forward to parent tool with argument transformation applied. + + This function can only be called from within a transformed tool's custom + function. It applies argument transformation (renaming, validation) before + calling the parent tool. + + For example, if the parent tool has args `x` and `y`, but the transformed + tool has args `a` and `b`, and an `transform_args` was provided that maps `x` to + `a` and `y` to `b`, then `forward(a=1, b=2)` will call the parent tool with + `x=1` and `y=2`. + + Args: + **kwargs: Arguments to forward to the parent tool (using transformed names). + + Returns: + The result from the parent tool execution. + + Raises: + RuntimeError: If called outside a transformed tool context. + TypeError: If provided arguments don't match the transformed schema. + """ + tool = _current_tool.get() + if tool is None: + raise RuntimeError("forward() can only be called within a transformed tool") + + # Use the forwarding function that handles mapping + return await tool.forwarding_fn(**kwargs) + + +async def forward_raw(**kwargs) -> Any: + """Forward directly to parent tool without transformation. + + This function bypasses all argument transformation and validation, calling the parent + tool directly with the provided arguments. Use this when you need to call the parent + with its original parameter names and structure. + + For example, if the parent tool has args `x` and `y`, then `forward_raw(x=1, + y=2)` will call the parent tool with `x=1` and `y=2`. + + Args: + **kwargs: Arguments to pass directly to the parent tool (using original names). + + Returns: + The result from the parent tool execution. + + Raises: + RuntimeError: If called outside a transformed tool context. + """ + tool = _current_tool.get() + if tool is None: + raise RuntimeError("forward_raw() can only be called within a transformed tool") + + return await tool.parent_tool.run(kwargs) + + +@dataclass(kw_only=True) +class ArgTransform: + """Configuration for transforming a parent tool's argument. + + This class allows fine-grained control over how individual arguments are transformed + when creating a new tool from an existing one. You can rename arguments, change their + descriptions, add default values, or hide them from clients while passing constants. + + Attributes: + name: New name for the argument. Use None to keep original name, or ... for no change. + description: New description for the argument. Use None to remove description, or ... for no change. + default: New default value for the argument. Use ... for no change. + default_factory: Callable that returns a default value. Cannot be used with default. + type: New type for the argument. Use ... for no change. + hide: If True, hide this argument from clients but pass a constant value to parent. + required: If True, make argument required (remove default). Use ... for no change. + + Examples: + # Rename argument 'old_name' to 'new_name' + ArgTransform(name="new_name") + + # Change description only + ArgTransform(description="Updated description") + + # Add a default value (makes argument optional) + ArgTransform(default=42) + + # Add a default factory (makes argument optional) + ArgTransform(default_factory=lambda: time.time()) + + # Change the type + ArgTransform(type=str) + + # Hide the argument entirely from clients + ArgTransform(hide=True) + + # Hide argument but pass a constant value to parent + ArgTransform(hide=True, default="constant_value") + + # Hide argument but pass a factory-generated value to parent + ArgTransform(hide=True, default_factory=lambda: uuid.uuid4().hex) + + # Make an optional parameter required (removes any default) + ArgTransform(required=True) + + # Combine multiple transformations + ArgTransform(name="new_name", description="New desc", default=None, type=int) + """ + + name: str | EllipsisType = NotSet + description: str | EllipsisType = NotSet + default: Any | EllipsisType = NotSet + default_factory: Callable[[], Any] | EllipsisType = NotSet + type: Any | EllipsisType = NotSet + hide: bool = False + required: Literal[True] | EllipsisType = NotSet + + def __post_init__(self): + """Validate that only one of default or default_factory is provided.""" + has_default = self.default is not NotSet + has_factory = self.default_factory is not NotSet + + if has_default and has_factory: + raise ValueError( + "Cannot specify both 'default' and 'default_factory' in ArgTransform. " + "Use either 'default' for a static value or 'default_factory' for a callable." + ) + + if has_factory and not self.hide: + raise ValueError( + "default_factory can only be used with hide=True. " + "Visible parameters must use static 'default' values since JSON schema " + "cannot represent dynamic factories." + ) + + if self.required is True and (has_default or has_factory): + raise ValueError( + "Cannot specify 'required=True' with 'default' or 'default_factory'. " + "Required parameters cannot have defaults." + ) + + if self.hide and self.required is True: + raise ValueError( + "Cannot specify both 'hide=True' and 'required=True'. " + "Hidden parameters cannot be required since clients cannot provide them." + ) + + if self.required is False: + raise ValueError( + "Cannot specify 'required=False'. Set a default value instead." + ) + + +class TransformedTool(Tool): + """A tool that is transformed from another tool. + + This class represents a tool that has been created by transforming another tool. + It supports argument renaming, schema modification, custom function injection, + and provides context for the forward() and forward_raw() functions. + + The transformation can be purely schema-based (argument renaming, dropping, etc.) + or can include a custom function that uses forward() to call the parent tool + with transformed arguments. + + Attributes: + parent_tool: The original tool that this tool was transformed from. + fn: The function to execute when this tool is called (either the forwarding + function for pure transformations or a custom user function). + forwarding_fn: Internal function that handles argument transformation and + validation when forward() is called from custom functions. + """ + + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + + parent_tool: Tool + fn: Callable[..., Any] + forwarding_fn: Callable[..., Any] # Always present, handles arg transformation + transform_args: dict[str, ArgTransform] + + async def run( + self, arguments: dict[str, Any] + ) -> list[TextContent | ImageContent | EmbeddedResource]: + """Run the tool with context set for forward() functions. + + This method executes the tool's function while setting up the context + that allows forward() and forward_raw() to work correctly within custom + functions. + + Args: + arguments: Dictionary of arguments to pass to the tool's function. + + Returns: + List of content objects (text, image, or embedded resources) representing + the tool's output. + """ + from fastmcp.tools.tool import _convert_to_content + + # Fill in missing arguments with schema defaults to ensure + # ArgTransform defaults take precedence over function defaults + arguments = arguments.copy() + properties = self.parameters.get("properties", {}) + + for param_name, param_schema in properties.items(): + if param_name not in arguments and "default" in param_schema: + # Check if this parameter has a default_factory from transform_args + # We need to call the factory for each run, not use the cached schema value + has_factory_default = False + if self.transform_args: + # Find the original parameter name that maps to this param_name + for orig_name, transform in self.transform_args.items(): + transform_name = ( + transform.name + if transform.name is not NotSet + else orig_name + ) + if ( + transform_name == param_name + and transform.default_factory is not NotSet + ): + # Type check to ensure default_factory is callable + if callable(transform.default_factory): + arguments[param_name] = transform.default_factory() + has_factory_default = True + break + + if not has_factory_default: + arguments[param_name] = param_schema["default"] + + token = _current_tool.set(self) + try: + result = await self.fn(**arguments) + return _convert_to_content(result, serializer=self.serializer) + finally: + _current_tool.reset(token) + + @classmethod + def from_tool( + cls, + tool: Tool, + name: str | None = None, + description: str | None = None, + tags: set[str] | None = None, + transform_fn: Callable[..., Any] | None = None, + transform_args: dict[str, ArgTransform] | None = None, + annotations: ToolAnnotations | None = None, + serializer: Callable[[Any], str] | None = None, + ) -> TransformedTool: + """Create a transformed tool from a parent tool. + + Args: + tool: The parent tool to transform. + transform_fn: Optional custom function. Can use forward() and forward_raw() + to call the parent tool. Functions with **kwargs receive transformed + argument names. + name: New name for the tool. Defaults to parent tool's name. + transform_args: Optional transformations for parent tool arguments. + Only specified arguments are transformed, others pass through unchanged: + - str: Simple rename + - ArgTransform: Complex transformation (rename/description/default/drop) + - None: Drop the argument + description: New description. Defaults to parent's description. + tags: New tags. Defaults to parent's tags. + annotations: New annotations. Defaults to parent's annotations. + serializer: New serializer. Defaults to parent's serializer. + + Returns: + TransformedTool with the specified transformations. + + Examples: + # Transform specific arguments only + Tool.from_tool(parent, transform_args={"old": "new"}) # Others unchanged + + # Custom function with partial transforms + async def custom(x: int, y: int) -> str: + result = await forward(x=x, y=y) + return f"Custom: {result}" + + Tool.from_tool(parent, transform_fn=custom, transform_args={"a": "x", "b": "y"}) + + # Using **kwargs (gets all args, transformed and untransformed) + async def flexible(**kwargs) -> str: + result = await forward(**kwargs) + return f"Got: {kwargs}" + + Tool.from_tool(parent, transform_fn=flexible, transform_args={"a": "x"}) + """ + transform_args = transform_args or {} + + # Validate transform_args + parent_params = set(tool.parameters.get("properties", {}).keys()) + unknown_args = set(transform_args.keys()) - parent_params + if unknown_args: + raise ValueError( + f"Unknown arguments in transform_args: {', '.join(sorted(unknown_args))}. " + f"Parent tool has: {', '.join(sorted(parent_params))}" + ) + + # Always create the forwarding transform + schema, forwarding_fn = cls._create_forwarding_transform(tool, transform_args) + + if transform_fn is None: + # User wants pure transformation - use forwarding_fn as the main function + final_fn = forwarding_fn + final_schema = schema + else: + # User provided custom function - merge schemas + parsed_fn = ParsedFunction.from_function(transform_fn, validate=False) + final_fn = transform_fn + + has_kwargs = cls._function_has_kwargs(transform_fn) + + # Validate function parameters against transformed schema + fn_params = set(parsed_fn.parameters.get("properties", {}).keys()) + transformed_params = set(schema.get("properties", {}).keys()) + + if not has_kwargs: + # Without **kwargs, function must declare all transformed params + # Check if function is missing any parameters required after transformation + missing_params = transformed_params - fn_params + if missing_params: + raise ValueError( + f"Function missing parameters required after transformation: " + f"{', '.join(sorted(missing_params))}. " + f"Function declares: {', '.join(sorted(fn_params))}" + ) + + # ArgTransform takes precedence over function signature + # Start with function schema as base, then override with transformed schema + final_schema = cls._merge_schema_with_precedence( + parsed_fn.parameters, schema + ) + else: + # With **kwargs, function can access all transformed params + # ArgTransform takes precedence over function signature + # No validation needed - kwargs makes everything accessible + + # Start with function schema as base, then override with transformed schema + final_schema = cls._merge_schema_with_precedence( + parsed_fn.parameters, schema + ) + + # Additional validation: check for naming conflicts after transformation + if transform_args: + new_names = [] + for old_name, transform in transform_args.items(): + if not transform.hide: + if transform.name is not NotSet: + new_names.append(transform.name) + else: + new_names.append(old_name) + + # Check for duplicate names after transformation + name_counts = {} + for arg_name in new_names: + name_counts[arg_name] = name_counts.get(arg_name, 0) + 1 + + duplicates = [ + arg_name for arg_name, count in name_counts.items() if count > 1 + ] + if duplicates: + raise ValueError( + f"Multiple arguments would be mapped to the same names: " + f"{', '.join(sorted(duplicates))}" + ) + + final_description = description if description is not None else tool.description + + transformed_tool = cls( + fn=final_fn, + forwarding_fn=forwarding_fn, + parent_tool=tool, + name=name or tool.name, + description=final_description, + parameters=final_schema, + tags=tags or tool.tags, + annotations=annotations or tool.annotations, + serializer=serializer or tool.serializer, + transform_args=transform_args, + ) + + return transformed_tool + + @classmethod + def _create_forwarding_transform( + cls, + parent_tool: Tool, + transform_args: dict[str, ArgTransform] | None, + ) -> tuple[dict[str, Any], Callable[..., Any]]: + """Create schema and forwarding function that encapsulates all transformation logic. + + This method builds a new JSON schema for the transformed tool and creates a + forwarding function that validates arguments against the new schema and maps + them back to the parent tool's expected arguments. + + Args: + parent_tool: The original tool to transform. + transform_args: Dictionary defining how to transform each argument. + + Returns: + A tuple containing: + - dict: The new JSON schema for the transformed tool + - Callable: Async function that validates and forwards calls to the parent tool + """ + + # Build transformed schema and mapping + parent_props = parent_tool.parameters.get("properties", {}).copy() + parent_required = set(parent_tool.parameters.get("required", [])) + + new_props = {} + new_required = set() + new_to_old = {} + hidden_defaults = {} # Track hidden parameters with constant values + + for old_name, old_schema in parent_props.items(): + # Check if parameter is in transform_args + if transform_args and old_name in transform_args: + transform = transform_args[old_name] + else: + # Default behavior - pass through (no transformation) + transform = ArgTransform() # Default ArgTransform with no changes + + # Handle hidden parameters with defaults + if transform.hide: + # Validate that hidden parameters without user defaults have parent defaults + has_user_default = ( + transform.default is not NotSet + or transform.default_factory is not NotSet + ) + if not has_user_default and old_name in parent_required: + raise ValueError( + f"Hidden parameter '{old_name}' has no default value in parent tool " + f"and no default or default_factory provided in ArgTransform. Either provide a default " + f"or default_factory in ArgTransform or don't hide required parameters." + ) + if has_user_default: + # Store info for later factory calling or direct value + hidden_defaults[old_name] = transform + # Skip adding to schema (not exposed to clients) + continue + + transform_result = cls._apply_single_transform( + old_name, + old_schema, + transform, + old_name in parent_required, + ) + + if transform_result: + new_name, new_schema, is_required = transform_result + new_props[new_name] = new_schema + new_to_old[new_name] = old_name + if is_required: + new_required.add(new_name) + + schema = { + "type": "object", + "properties": new_props, + "required": list(new_required), + } + + # Create forwarding function that closes over everything it needs + async def _forward(**kwargs): + # Validate arguments + valid_args = set(new_props.keys()) + provided_args = set(kwargs.keys()) + unknown_args = provided_args - valid_args + + if unknown_args: + raise TypeError( + f"Got unexpected keyword argument(s): {', '.join(sorted(unknown_args))}" + ) + + # Check required arguments + missing_args = new_required - provided_args + if missing_args: + raise TypeError( + f"Missing required argument(s): {', '.join(sorted(missing_args))}" + ) + + # Map arguments to parent names + parent_args = {} + for new_name, value in kwargs.items(): + old_name = new_to_old.get(new_name, new_name) + parent_args[old_name] = value + + # Add hidden defaults (constant values for hidden parameters) + for old_name, transform in hidden_defaults.items(): + if transform.default is not NotSet: + parent_args[old_name] = transform.default + elif transform.default_factory is not NotSet: + # Type check to ensure default_factory is callable + if callable(transform.default_factory): + parent_args[old_name] = transform.default_factory() + + return await parent_tool.run(parent_args) + + return schema, _forward + + @staticmethod + def _apply_single_transform( + old_name: str, + old_schema: dict[str, Any], + transform: ArgTransform, + is_required: bool, + ) -> tuple[str, dict[str, Any], bool] | None: + """Apply transformation to a single parameter. + + This method handles the transformation of a single argument according to + the specified transformation rules. + + Args: + old_name: Original name of the parameter. + old_schema: Original JSON schema for the parameter. + transform: ArgTransform object specifying how to transform the parameter. + is_required: Whether the original parameter was required. + + Returns: + Tuple of (new_name, new_schema, new_is_required) if parameter should be kept, + None if parameter should be dropped. + """ + if transform.hide: + return None + + # Handle name transformation - ensure we always have a string + if transform.name is not NotSet: + new_name = transform.name if transform.name is not None else old_name + else: + new_name = old_name + + # Ensure new_name is always a string + if not isinstance(new_name, str): + new_name = old_name + + new_schema = old_schema.copy() + + # Handle description transformation + if transform.description is not NotSet: + if transform.description is None: + new_schema.pop("description", None) # Remove description + else: + new_schema["description"] = transform.description + + # Handle required transformation first + if transform.required is not NotSet: + is_required = bool(transform.required) + if transform.required is True: + # Remove any existing default when making required + new_schema.pop("default", None) + + # Handle default value transformation (only if not making required) + if transform.default is not NotSet and transform.required is not True: + new_schema["default"] = transform.default + is_required = False + + # Handle type transformation + if transform.type is not NotSet: + # Use TypeAdapter to get proper JSON schema for the type + type_schema = get_cached_typeadapter(transform.type).json_schema() + # Update the schema with the type information from TypeAdapter + new_schema.update(type_schema) + + return new_name, new_schema, is_required + + @staticmethod + def _merge_schema_with_precedence( + base_schema: dict[str, Any], override_schema: dict[str, Any] + ) -> dict[str, Any]: + """Merge two schemas, with the override schema taking precedence. + + Args: + base_schema: Base schema to start with + override_schema: Schema that takes precedence for overlapping properties + + Returns: + Merged schema with override taking precedence + """ + merged_props = base_schema.get("properties", {}).copy() + merged_required = set(base_schema.get("required", [])) + + override_props = override_schema.get("properties", {}) + override_required = set(override_schema.get("required", [])) + + # Override properties + for param_name, param_schema in override_props.items(): + if param_name in merged_props: + # Merge the schemas, with override taking precedence + base_param = merged_props[param_name].copy() + base_param.update(param_schema) + merged_props[param_name] = base_param + else: + merged_props[param_name] = param_schema.copy() + + # Handle required parameters - override takes complete precedence + # Start with override's required set + final_required = override_required.copy() + + # For parameters not in override, inherit base requirement status + # but only if they don't have a default in the final merged properties + for param_name in merged_required: + if param_name not in override_props: + # Parameter not mentioned in override, keep base requirement status + final_required.add(param_name) + elif ( + param_name in override_props + and "default" not in merged_props[param_name] + ): + # Parameter in override but no default, keep required if it was required in base + if param_name not in override_required: + # Override doesn't specify it as required, and it has no default, + # so inherit from base + final_required.add(param_name) + + # Remove any parameters that have defaults (they become optional) + for param_name, param_schema in merged_props.items(): + if "default" in param_schema: + final_required.discard(param_name) + + return { + "type": "object", + "properties": merged_props, + "required": list(final_required), + } + + @staticmethod + def _function_has_kwargs(fn: Callable[..., Any]) -> bool: + """Check if function accepts **kwargs. + + This determines whether a custom function can accept arbitrary keyword arguments, + which affects how schemas are merged during tool transformation. + + Args: + fn: Function to inspect. + + Returns: + True if the function has a **kwargs parameter, False otherwise. + """ + sig = inspect.signature(fn) + return any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) diff --git a/tests/server/test_tool_exclude_args.py b/tests/server/test_tool_exclude_args.py index d91e012c..efe8349a 100644 --- a/tests/server/test_tool_exclude_args.py +++ b/tests/server/test_tool_exclude_args.py @@ -21,9 +21,7 @@ def echo(message: str, state: dict[str, Any] | None = None) -> str: tools = mcp._tool_manager.list_tools() assert len(tools) == 1 - assert tools[0].exclude_args is not None - for args in tools[0].exclude_args: - assert args not in tools[0].parameters + assert "state" not in echo.parameters["properties"] async def test_tool_exclude_args_without_default_value_raises_error(): @@ -64,10 +62,7 @@ def create_item( # Check internal tool objects directly tools = mcp._tool_manager.list_tools() assert len(tools) == 1 - assert tools[0].exclude_args is not None - assert tools[0].exclude_args == ["state"] - for args in tools[0].exclude_args: - assert args not in tools[0].parameters + assert "state" not in tools[0].parameters["properties"] async def test_tool_functionality_with_exclude_args(): diff --git a/tests/tools/test_tool_transform.py b/tests/tools/test_tool_transform.py new file mode 100644 index 00000000..595d981b --- /dev/null +++ b/tests/tools/test_tool_transform.py @@ -0,0 +1,944 @@ +import re +from dataclasses import dataclass +from typing import Annotated, Any + +import pytest +from dirty_equals import IsList +from pydantic import BaseModel, Field +from typing_extensions import TypedDict + +from fastmcp import FastMCP +from fastmcp.client.client import Client +from fastmcp.tools import Tool, forward, forward_raw +from fastmcp.tools.tool import FunctionTool +from fastmcp.tools.tool_transform import ArgTransform, TransformedTool + + +def get_property(tool: Tool, name: str) -> dict[str, Any]: + return tool.parameters["properties"][name] + + +@pytest.fixture +def add_tool() -> FunctionTool: + def add( + old_x: Annotated[int, Field(description="old_x description")], old_y: int = 10 + ) -> int: + print("running!") + return old_x + old_y + + return Tool.from_function(add) + + +def test_tool_from_tool_no_change(add_tool): + new_tool = Tool.from_tool(add_tool) + assert isinstance(new_tool, TransformedTool) + assert new_tool.parameters == add_tool.parameters + assert new_tool.name == add_tool.name + assert new_tool.description == add_tool.description + + +async def test_renamed_arg_description_is_maintained(add_tool): + new_tool = Tool.from_tool( + add_tool, transform_args={"old_x": ArgTransform(name="new_x")} + ) + assert ( + new_tool.parameters["properties"]["new_x"]["description"] == "old_x description" + ) + + +async def test_tool_defaults_are_maintained_on_unmapped_args(add_tool): + new_tool = Tool.from_tool( + add_tool, transform_args={"old_x": ArgTransform(name="new_x")} + ) + result = await new_tool.run(arguments={"new_x": 1}) + assert result[0].text == "11" # type: ignore + + +async def test_tool_defaults_are_maintained_on_mapped_args(add_tool): + new_tool = Tool.from_tool( + add_tool, transform_args={"old_y": ArgTransform(name="new_y")} + ) + result = await new_tool.run(arguments={"old_x": 1}) + assert result[0].text == "11" # type: ignore + + +def test_tool_change_arg_name(add_tool): + new_tool = Tool.from_tool( + add_tool, transform_args={"old_x": ArgTransform(name="new_x")} + ) + + assert sorted(new_tool.parameters["properties"]) == ["new_x", "old_y"] + assert get_property(new_tool, "new_x") == get_property(add_tool, "old_x") + assert get_property(new_tool, "old_y") == get_property(add_tool, "old_y") + assert new_tool.parameters["required"] == ["new_x"] + + +def test_tool_change_arg_description(add_tool): + new_tool = Tool.from_tool( + add_tool, transform_args={"old_x": ArgTransform(description="new description")} + ) + assert get_property(new_tool, "old_x")["description"] == "new description" + + +async def test_tool_drop_arg(add_tool): + new_tool = Tool.from_tool( + add_tool, transform_args={"old_y": ArgTransform(hide=True)} + ) + assert sorted(new_tool.parameters["properties"]) == ["old_x"] + result = await new_tool.run(arguments={"old_x": 1}) + assert result[0].text == "11" # type: ignore + + +async def test_dropped_args_error_if_provided(add_tool): + new_tool = Tool.from_tool( + add_tool, transform_args={"old_y": ArgTransform(hide=True)} + ) + with pytest.raises( + TypeError, match="Got unexpected keyword argument\\(s\\): old_y" + ): + await new_tool.run(arguments={"old_x": 1, "old_y": 2}) + + +async def test_hidden_arg_with_constant_default(add_tool): + """Test that hidden argument with default value passes constant to parent.""" + new_tool = Tool.from_tool( + add_tool, transform_args={"old_y": ArgTransform(hide=True, default=20)} + ) + # Only old_x should be exposed + assert sorted(new_tool.parameters["properties"]) == ["old_x"] + # Should pass old_x=5 and old_y=20 to parent + result = await new_tool.run(arguments={"old_x": 5}) + assert result[0].text == "25" # type: ignore + + +async def test_hidden_arg_without_default_uses_parent_default(add_tool): + """Test that hidden argument without default uses parent's default.""" + new_tool = Tool.from_tool( + add_tool, transform_args={"old_y": ArgTransform(hide=True)} + ) + # Only old_x should be exposed + assert sorted(new_tool.parameters["properties"]) == ["old_x"] + # Should pass old_x=3 and let parent use its default old_y=10 + result = await new_tool.run(arguments={"old_x": 3}) + assert result[0].text == "13" # type: ignore + + +async def test_mixed_hidden_args_with_custom_function(add_tool): + """Test custom function with both hidden constant and hidden default parameters.""" + + async def custom_fn(visible_x: int) -> int: + # This custom function should receive the transformed visible parameter + # and the hidden parameters should be automatically handled + result = await forward(visible_x=visible_x) + return result + + new_tool = Tool.from_tool( + add_tool, + transform_fn=custom_fn, + transform_args={ + "old_x": ArgTransform(name="visible_x"), # Rename and expose + "old_y": ArgTransform(hide=True, default=25), # Hidden with constant + }, + ) + + # Only visible_x should be exposed + assert sorted(new_tool.parameters["properties"]) == ["visible_x"] + # Should pass visible_x=7 as old_x=7 and old_y=25 to parent + result = await new_tool.run(arguments={"visible_x": 7}) + assert result[0].text == "32" # type: ignore + + +async def test_hide_required_param_without_default_raises_error(): + """Test that hiding a required parameter without providing default raises error.""" + + @Tool.from_function + def tool_with_required_param(required_param: int, optional_param: int = 10) -> int: + return required_param + optional_param + + # This should raise an error because required_param has no default and we're not providing one + with pytest.raises( + ValueError, + match=r"Hidden parameter 'required_param' has no default value in parent tool", + ): + Tool.from_tool( + tool_with_required_param, + transform_args={"required_param": ArgTransform(hide=True)}, + ) + + +async def test_hide_required_param_with_user_default_works(): + """Test that hiding a required parameter works when user provides a default.""" + + @Tool.from_function + def tool_with_required_param(required_param: int, optional_param: int = 10) -> int: + return required_param + optional_param + + # This should work because we're providing a default for the hidden required param + new_tool = Tool.from_tool( + tool_with_required_param, + transform_args={"required_param": ArgTransform(hide=True, default=5)}, + ) + + # Only optional_param should be exposed + assert sorted(new_tool.parameters["properties"]) == ["optional_param"] + # Should pass required_param=5 and optional_param=20 to parent + result = await new_tool.run(arguments={"optional_param": 20}) + assert result[0].text == "25" # type: ignore + + +async def test_forward_with_argument_mapping(add_tool): + """Test that forward() applies argument mapping correctly.""" + + async def custom_fn(new_x: int, new_y: int = 5) -> int: + return await forward(new_x=new_x, new_y=new_y) + + new_tool = Tool.from_tool( + add_tool, + transform_fn=custom_fn, + transform_args={ + "old_x": ArgTransform(name="new_x"), + "old_y": ArgTransform(name="new_y"), + }, + ) + + result = await new_tool.run(arguments={"new_x": 2, "new_y": 3}) + assert result[0].text == "5" # type: ignore + + +async def test_forward_with_incorrect_args_raises_error(add_tool): + async def custom_fn(new_x: int, new_y: int = 5) -> int: + # the forward should use the new args, not the old ones + return await forward(old_x=new_x, old_y=new_y) + + new_tool = Tool.from_tool( + add_tool, + transform_fn=custom_fn, + transform_args={ + "old_x": ArgTransform(name="new_x"), + "old_y": ArgTransform(name="new_y"), + }, + ) + with pytest.raises( + TypeError, match=re.escape("Got unexpected keyword argument(s): old_x, old_y") + ): + await new_tool.run(arguments={"new_x": 2, "new_y": 3}) + + +async def test_forward_raw_without_argument_mapping(add_tool): + """Test that forward_raw() calls parent directly without mapping.""" + + async def custom_fn(new_x: int, new_y: int = 5) -> int: + # Call parent directly with original argument names + result = await forward_raw(old_x=new_x, old_y=new_y) + return result + + new_tool = Tool.from_tool( + add_tool, + transform_fn=custom_fn, + transform_args={ + "old_x": ArgTransform(name="new_x"), + "old_y": ArgTransform(name="new_y"), + }, + ) + + result = await new_tool.run(arguments={"new_x": 2, "new_y": 3}) + assert result[0].text == "5" # type: ignore + + +async def test_custom_fn_with_kwargs_and_no_transform_args(add_tool): + async def custom_fn(extra: int, **kwargs) -> int: + sum = await forward(**kwargs) + return int(sum[0].text) + extra # type: ignore[attr-defined] + + new_tool = Tool.from_tool(add_tool, transform_fn=custom_fn) + result = await new_tool.run(arguments={"extra": 1, "old_x": 2, "old_y": 3}) + assert result[0].text == "6" # type: ignore + assert new_tool.parameters["required"] == IsList( + "extra", "old_x", check_order=False + ) + assert list(new_tool.parameters["properties"]) == IsList( + "extra", "old_x", "old_y", check_order=False + ) + + +async def test_fn_with_kwargs_passes_through_original_args(add_tool): + async def custom_fn(new_y: int = 5, **kwargs) -> int: + assert kwargs == {"old_y": 3} + result = await forward(old_x=new_y, **kwargs) + return result + + new_tool = Tool.from_tool(add_tool, transform_fn=custom_fn) + result = await new_tool.run(arguments={"new_y": 2, "old_y": 3}) + assert result[0].text == "5" # type: ignore + + +async def test_fn_with_kwargs_receives_transformed_arg_names(add_tool): + """Test that **kwargs receives arguments with their transformed names from transform_args.""" + + async def custom_fn(new_x: int, **kwargs) -> int: + # kwargs should contain 'old_y': 3 (transformed name), not 'old_y': 3 (original name) + assert kwargs == {"old_y": 3} + result = await forward(new_x=new_x, **kwargs) + return result + + new_tool = Tool.from_tool( + add_tool, + transform_fn=custom_fn, + transform_args={"old_x": ArgTransform(name="new_x")}, + ) + result = await new_tool.run(arguments={"new_x": 2, "old_y": 3}) + assert result[0].text == "5" # type: ignore + + +async def test_fn_with_kwargs_handles_partial_explicit_args(add_tool): + """Test that function can explicitly handle some transformed args while others pass through kwargs.""" + + async def custom_fn(new_x: int, some_other_param: str = "default", **kwargs) -> int: + # x is explicitly handled, y should come through kwargs with transformed name + assert kwargs == {"old_y": 7} + result = await forward(new_x=new_x, **kwargs) + return result + + new_tool = Tool.from_tool( + add_tool, + transform_fn=custom_fn, + transform_args={"old_x": ArgTransform(name="new_x")}, + ) + result = await new_tool.run( + arguments={"new_x": 3, "old_y": 7, "some_other_param": "test"} + ) + assert result[0].text == "10" # type: ignore + + +async def test_fn_with_kwargs_mixed_mapped_and_unmapped_args(add_tool): + """Test **kwargs behavior with mix of mapped and unmapped arguments.""" + + async def custom_fn(new_x: int, **kwargs) -> int: + # new_x is explicitly handled, old_y should pass through kwargs with original name (unmapped) + assert kwargs == {"old_y": 5} + result = await forward(new_x=new_x, **kwargs) + return result + + new_tool = Tool.from_tool( + add_tool, + transform_fn=custom_fn, + transform_args={"old_x": ArgTransform(name="new_x")}, + ) # only map 'a' + result = await new_tool.run(arguments={"new_x": 1, "old_y": 5}) + assert result[0].text == "6" # type: ignore + + +async def test_fn_with_kwargs_dropped_args_not_in_kwargs(add_tool): + """Test that dropped arguments don't appear in **kwargs.""" + + async def custom_fn(new_x: int, **kwargs) -> int: + # 'b' was dropped, so kwargs should be empty + assert kwargs == {} + # Can't use 'old_y' since it was dropped, so just use 'old_x' mapped to 'new_x' + result = await forward(new_x=new_x) + return result + + new_tool = Tool.from_tool( + add_tool, + transform_fn=custom_fn, + transform_args={ + "old_x": ArgTransform(name="new_x"), + "old_y": ArgTransform(hide=True), + }, + ) # drop 'old_y' + result = await new_tool.run(arguments={"new_x": 8}) + # 8 + 10 (default value of b in parent) + assert result[0].text == "18" # type: ignore[attr-defined] + + +async def test_forward_outside_context_raises_error(): + """Test that forward() raises RuntimeError when called outside a transformed tool.""" + with pytest.raises( + RuntimeError, + match=re.escape("forward() can only be called within a transformed tool"), + ): + await forward(new_x=1, old_y=2) + + +async def test_forward_raw_outside_context_raises_error(): + """Test that forward_raw() raises RuntimeError when called outside a transformed tool.""" + with pytest.raises( + RuntimeError, + match=re.escape("forward_raw() can only be called within a transformed tool"), + ): + await forward_raw(new_x=1, old_y=2) + + +def test_transform_args_validation_unknown_arg(add_tool): + """Test that transform_args with unknown arguments raises ValueError.""" + with pytest.raises( + ValueError, match="Unknown arguments in transform_args: unknown_param" + ): + Tool.from_tool( + add_tool, transform_args={"unknown_param": ArgTransform(name="new_name")} + ) + + +def test_transform_args_creates_duplicate_names(add_tool): + """Test that transform_args creating duplicate parameter names raises ValueError.""" + with pytest.raises( + ValueError, + match="Multiple arguments would be mapped to the same names: same_name", + ): + Tool.from_tool( + add_tool, + transform_args={ + "old_x": ArgTransform(name="same_name"), + "old_y": ArgTransform(name="same_name"), + }, + ) + + +def test_function_without_kwargs_missing_params(add_tool): + """Test that function missing required transformed parameters raises ValueError.""" + + def invalid_fn(new_x: int, non_existent: str) -> str: + return f"{new_x}_{non_existent}" + + with pytest.raises( + ValueError, + match="Function missing parameters required after transformation: new_y", + ): + Tool.from_tool( + add_tool, + transform_fn=invalid_fn, + transform_args={ + "old_x": ArgTransform(name="new_x"), + "old_y": ArgTransform(name="new_y"), + }, + ) + + +def test_function_without_kwargs_can_have_extra_params(add_tool): + """Test that function can have extra parameters not in parent tool.""" + + def valid_fn(new_x: int, new_y: int, extra_param: str = "default") -> str: + return f"{new_x}_{new_y}_{extra_param}" + + # Should work - extra_param is fine as long as it has a default + new_tool = Tool.from_tool( + add_tool, + transform_fn=valid_fn, + transform_args={ + "old_x": ArgTransform(name="new_x"), + "old_y": ArgTransform(name="new_y"), + }, + ) + + # The final schema should include all function parameters + assert "new_x" in new_tool.parameters["properties"] + assert "new_y" in new_tool.parameters["properties"] + assert "extra_param" in new_tool.parameters["properties"] + + +def test_function_with_kwargs_can_add_params(add_tool): + """Test that function with **kwargs can add new parameters.""" + + async def valid_fn(extra_param: str, **kwargs) -> str: + result = await forward(**kwargs) + return f"{extra_param}: {result}" + + # This should work fine - kwargs allows access to all transformed params + tool = Tool.from_tool( + add_tool, + transform_fn=valid_fn, + transform_args={ + "old_x": ArgTransform(name="new_x"), + "old_y": ArgTransform(name="new_y"), + }, + ) + + # extra_param is added, new_x and new_y are available + assert "extra_param" in tool.parameters["properties"] + assert "new_x" in tool.parameters["properties"] + assert "new_y" in tool.parameters["properties"] + + +async def test_tool_transform_chaining(add_tool): + """Test that transformed tools can be transformed again.""" + # First transformation: a -> x + tool1 = Tool.from_tool(add_tool, transform_args={"old_x": ArgTransform(name="x")}) + + # Second transformation: x -> final_x, using tool1 + tool2 = Tool.from_tool(tool1, transform_args={"x": ArgTransform(name="final_x")}) + + result = await tool2.run(arguments={"final_x": 5}) + assert result[0].text == "15" # type: ignore + + # Transform tool1 with custom function that handles all parameters + async def custom(final_x: int, **kwargs) -> str: + result = await forward(final_x=final_x, **kwargs) + return f"custom {result[0].text}" # Extract text from content + + tool3 = Tool.from_tool( + tool1, transform_fn=custom, transform_args={"x": ArgTransform(name="final_x")} + ) + result = await tool3.run(arguments={"final_x": 3, "old_y": 5}) + assert result[0].text == "custom 8" # type: ignore + + +class MyModel(BaseModel): + x: int + y: str + + +@dataclass +class MyDataclass: + x: int + y: str + + +class MyTypedDict(TypedDict): + x: int + y: str + + +@pytest.mark.parametrize( + "py_type, json_type", + [ + (int, "integer"), + (float, "number"), + (str, "string"), + (bool, "boolean"), + (list, "array"), + (list[int], "array"), + (dict, "object"), + (dict[str, int], "object"), + (MyModel, "object"), + (MyDataclass, "object"), + (MyTypedDict, "object"), + ], +) +def test_arg_transform_type_handling(add_tool, py_type, json_type): + """Test that ArgTransform type attribute gets applied to schema.""" + new_tool = Tool.from_tool( + add_tool, transform_args={"old_x": ArgTransform(type=py_type)} + ) + + # Check that the type was changed in the schema + x_prop = get_property(new_tool, "old_x") + assert x_prop["type"] == json_type + + +def test_arg_transform_annotated_types(add_tool): + """Test that ArgTransform works with annotated types and complex types.""" + from typing import Annotated + + from pydantic import Field + + # Test with Annotated types + tool = Tool.from_tool( + add_tool, + transform_args={ + "old_x": ArgTransform( + type=Annotated[int, Field(description="An annotated integer")] + ) + }, + ) + + x_prop = get_property(tool, "old_x") + assert x_prop["type"] == "integer" + # The ArgTransform description should override the annotation description + # (since we didn't set a description in ArgTransform, it should use the original) + + # Test with Annotated string that has constraints + tool2 = Tool.from_tool( + add_tool, + transform_args={ + "old_x": ArgTransform( + type=Annotated[str, Field(min_length=1, max_length=10)] + ) + }, + ) + + x_prop2 = get_property(tool2, "old_x") + assert x_prop2["type"] == "string" + assert x_prop2["minLength"] == 1 + assert x_prop2["maxLength"] == 10 + + +def test_arg_transform_precedence_over_function_without_kwargs(): + """Test that ArgTransform attributes take precedence over function signature (no **kwargs).""" + + @Tool.from_function + def base(x: int, y: str = "default") -> str: + return f"{x}: {y}" + + # Function signature says x: int with no default, y: str = "function_default" + # ArgTransform should override these + def custom_fn(x: str = "transform_default", y: int = 99) -> str: + return f"custom: {x}, {y}" + + tool = Tool.from_tool( + base, + transform_fn=custom_fn, + transform_args={ + "x": ArgTransform(type=str, default="transform_default"), + "y": ArgTransform(type=int, default=99), + }, + ) + + # ArgTransform should take precedence + x_prop = get_property(tool, "x") + y_prop = get_property(tool, "y") + + assert x_prop["type"] == "string" # ArgTransform type wins + assert x_prop["default"] == "transform_default" # ArgTransform default wins + assert y_prop["type"] == "integer" # ArgTransform type wins + assert y_prop["default"] == 99 # ArgTransform default wins + + # Neither parameter should be required due to ArgTransform defaults + assert "x" not in tool.parameters["required"] + assert "y" not in tool.parameters["required"] + + +async def test_arg_transform_precedence_over_function_with_kwargs(): + """Test that ArgTransform attributes take precedence over function signature (with **kwargs).""" + + @Tool.from_function + def base(x: int, y: str = "base_default") -> str: + return f"{x}: {y}" + + # Function signature has different types/defaults than ArgTransform + async def custom_fn(x: str = "function_default", **kwargs) -> str: + result = await forward(x=x, **kwargs) + return f"custom: {result}" + + tool = Tool.from_tool( + base, + transform_fn=custom_fn, + transform_args={ + "x": ArgTransform(type=int, default=42), # Different type and default + "y": ArgTransform(description="ArgTransform description"), + }, + ) + + # ArgTransform should take precedence + x_prop = get_property(tool, "x") + y_prop = get_property(tool, "y") + + assert x_prop["type"] == "integer" # ArgTransform type wins over function's str + assert x_prop["default"] == 42 # ArgTransform default wins over function's default + assert ( + y_prop["description"] == "ArgTransform description" + ) # ArgTransform description + + # x should not be required due to ArgTransform default + assert "x" not in tool.parameters["required"] + + # Test it works at runtime + result = await tool.run(arguments={"y": "test"}) + # Should use ArgTransform default of 42 + assert "42: test" in result[0].text # type: ignore + + +def test_arg_transform_combined_attributes(): + """Test that multiple ArgTransform attributes work together.""" + + @Tool.from_function + def base(param: int) -> str: + return str(param) + + tool = Tool.from_tool( + base, + transform_args={ + "param": ArgTransform( + name="renamed_param", + type=str, + description="New description", + default="default_value", + ) + }, + ) + + # Check all attributes were applied + assert "renamed_param" in tool.parameters["properties"] + assert "param" not in tool.parameters["properties"] + + prop = get_property(tool, "renamed_param") + assert prop["type"] == "string" + assert prop["description"] == "New description" + assert prop["default"] == "default_value" + assert "renamed_param" not in tool.parameters["required"] # Has default + + +async def test_arg_transform_type_precedence_runtime(): + """Test that ArgTransform type changes work correctly at runtime.""" + + @Tool.from_function + def base(x: int, y: int = 10) -> int: + return x + y + + # Transform x to string type but keep same logic + async def custom_fn(x: str, y: int = 10) -> str: + # Convert string back to int for the original function + result = await forward_raw(x=int(x), y=y) + # Extract the text from the result + result_text = result[0].text + return f"String input '{x}' converted to result: {result_text}" + + tool = Tool.from_tool( + base, transform_fn=custom_fn, transform_args={"x": ArgTransform(type=str)} + ) + + # Verify schema shows string type + assert get_property(tool, "x")["type"] == "string" + + # Test it works with string input + result = await tool.run(arguments={"x": "5", "y": 3}) + assert "String input '5'" in result[0].text # type: ignore + assert "result: 8" in result[0].text # type: ignore + + +class TestProxy: + @pytest.fixture + def mcp_server(self) -> FastMCP: + mcp = FastMCP() + + @mcp.tool + def add(old_x: int, old_y: int = 10) -> int: + return old_x + old_y + + return mcp + + @pytest.fixture + def proxy_server(self, mcp_server: FastMCP) -> FastMCP: + from fastmcp.client.transports import FastMCPTransport + + proxy = FastMCP.as_proxy(Client(transport=FastMCPTransport(mcp_server))) + return proxy + + async def test_transform_proxy(self, proxy_server: FastMCP): + # when adding transformed tools to proxy servers. Needs separate investigation. + + add_tool = await proxy_server.get_tool("add") + new_add_tool = Tool.from_tool( + add_tool, + name="add_transformed", + transform_args={"old_x": ArgTransform(name="new_x")}, + ) + proxy_server.add_tool(new_add_tool) + + async with Client(proxy_server) as client: + # The tool should be registered with its transformed name + result = await client.call_tool("add_transformed", {"new_x": 1, "old_y": 2}) + assert result[0].text == "3" # type: ignore + + +async def test_arg_transform_default_factory(): + """Test ArgTransform with default_factory for hidden parameters.""" + + @Tool.from_function + def base_tool(x: int, timestamp: float) -> str: + return f"{x}_{timestamp}" + + # Create a tool with default_factory for hidden timestamp + new_tool = Tool.from_tool( + base_tool, + transform_args={ + "timestamp": ArgTransform(hide=True, default_factory=lambda: 12345.0) + }, + ) + + # Only x should be visible since timestamp is hidden + assert sorted(new_tool.parameters["properties"]) == ["x"] + + # Should work without providing timestamp (gets value from factory) + result = await new_tool.run(arguments={"x": 42}) + assert result[0].text == "42_12345.0" # type: ignore + + +async def test_arg_transform_default_factory_called_each_time(): + """Test that default_factory is called for each execution.""" + call_count = 0 + + def counter_factory(): + nonlocal call_count + call_count += 1 + return call_count + + @Tool.from_function + def base_tool(x: int, counter: int = 0) -> str: + return f"{x}_{counter}" + + new_tool = Tool.from_tool( + base_tool, + transform_args={ + "counter": ArgTransform(hide=True, default_factory=counter_factory) + }, + ) + + # Only x should be visible since counter is hidden + assert sorted(new_tool.parameters["properties"]) == ["x"] + + # First call + result1 = await new_tool.run(arguments={"x": 1}) + assert result1[0].text == "1_1" # type: ignore + + # Second call should get a different value + result2 = await new_tool.run(arguments={"x": 2}) + assert result2[0].text == "2_2" # type: ignore + + +async def test_arg_transform_hidden_with_default_factory(): + """Test hidden parameter with default_factory.""" + + @Tool.from_function + def base_tool(x: int, request_id: str) -> str: + return f"{x}_{request_id}" + + def make_request_id(): + return "req_123" + + new_tool = Tool.from_tool( + base_tool, + transform_args={ + "request_id": ArgTransform(hide=True, default_factory=make_request_id) + }, + ) + + # Only x should be visible + assert sorted(new_tool.parameters["properties"]) == ["x"] + + # Should pass hidden request_id with factory value + result = await new_tool.run(arguments={"x": 42}) + assert result[0].text == "42_req_123" # type: ignore + + +async def test_arg_transform_default_and_factory_raises_error(): + """Test that providing both default and default_factory raises an error.""" + with pytest.raises( + ValueError, match="Cannot specify both 'default' and 'default_factory'" + ): + ArgTransform(default=42, default_factory=lambda: 24) + + +async def test_arg_transform_default_factory_requires_hide(): + """Test that default_factory requires hide=True.""" + with pytest.raises( + ValueError, match="default_factory can only be used with hide=True" + ): + ArgTransform(default_factory=lambda: 42) # hide=False by default + + +async def test_arg_transform_required_true(): + """Test that required=True makes an optional parameter required.""" + + @Tool.from_function + def base_tool(optional_param: int = 42) -> str: + return f"value: {optional_param}" + + # Make the optional parameter required + new_tool = Tool.from_tool( + base_tool, transform_args={"optional_param": ArgTransform(required=True)} + ) + + # Parameter should now be required (no default in schema) + assert "optional_param" in new_tool.parameters["required"] + assert "default" not in new_tool.parameters["properties"]["optional_param"] + + # Should work when parameter is provided + result = await new_tool.run(arguments={"optional_param": 100}) + assert result[0].text == "value: 100" # type: ignore + + # Should fail when parameter is not provided + with pytest.raises(TypeError, match="Missing required argument"): + await new_tool.run(arguments={}) + + +async def test_arg_transform_required_false(): + """Test that required=False makes a required parameter optional with default.""" + + @Tool.from_function + def base_tool(required_param: int) -> str: + return f"value: {required_param}" + + with pytest.raises( + ValueError, + match="Cannot specify 'required=False'. Set a default value instead.", + ): + Tool.from_tool( + base_tool, + transform_args={"required_param": ArgTransform(required=False, default=99)}, # type: ignore + ) + + +async def test_arg_transform_required_with_rename(): + """Test that required works correctly with argument renaming.""" + + @Tool.from_function + def base_tool(optional_param: int = 42) -> str: + return f"value: {optional_param}" + + # Rename and make required + new_tool = Tool.from_tool( + base_tool, + transform_args={ + "optional_param": ArgTransform(name="new_param", required=True) + }, + ) + + # New parameter name should be required + assert "new_param" in new_tool.parameters["required"] + assert "optional_param" not in new_tool.parameters["properties"] + assert "new_param" in new_tool.parameters["properties"] + assert "default" not in new_tool.parameters["properties"]["new_param"] + + # Should work with new name + result = await new_tool.run(arguments={"new_param": 200}) + assert result[0].text == "value: 200" # type: ignore + + +async def test_arg_transform_required_true_with_default_raises_error(): + """Test that required=True with default raises an error.""" + with pytest.raises( + ValueError, match="Cannot specify 'required=True' with 'default'" + ): + ArgTransform(required=True, default=42) + + +async def test_arg_transform_required_true_with_factory_raises_error(): + """Test that required=True with default_factory raises an error.""" + with pytest.raises( + ValueError, match="default_factory can only be used with hide=True" + ): + ArgTransform(required=True, default_factory=lambda: 42) + + +async def test_arg_transform_required_no_change(): + """Test that required=... (NotSet) leaves requirement status unchanged.""" + + @Tool.from_function + def base_tool(required_param: int, optional_param: int = 42) -> str: + return f"values: {required_param}, {optional_param}" + + # Transform without changing required status + new_tool = Tool.from_tool( + base_tool, + transform_args={ + "required_param": ArgTransform(name="req"), + "optional_param": ArgTransform(name="opt"), + }, + ) + + # Required status should be unchanged + assert "req" in new_tool.parameters["required"] + assert "opt" not in new_tool.parameters["required"] + assert new_tool.parameters["properties"]["opt"]["default"] == 42 + + # Should work as expected + result = await new_tool.run(arguments={"req": 1}) + assert result[0].text == "values: 1, 42" # type: ignore + + +async def test_arg_transform_hide_and_required_raises_error(): + """Test that hide=True and required=True together raises an error.""" + with pytest.raises( + ValueError, match="Cannot specify both 'hide=True' and 'required=True'" + ): + ArgTransform(hide=True, required=True) diff --git a/tests/utilities/test_types.py b/tests/utilities/test_types.py index de3c33f7..bb50589f 100644 --- a/tests/utilities/test_types.py +++ b/tests/utilities/test_types.py @@ -1,4 +1,5 @@ import base64 +from types import EllipsisType from typing import Annotated, Any import pytest @@ -308,6 +309,14 @@ def func(a: int, b: SENTINEL, c: str): # type: ignore assert find_kwarg_by_type(func, SENTINEL) is None # type: ignore + def test_ellipsis_annotation(self): + """Test finding parameter with an ellipsis annotation.""" + + def func(a: int, b: EllipsisType, c: str): # type: ignore # noqa: F821 + pass + + assert find_kwarg_by_type(func, EllipsisType) == "b" # type: ignore + def test_missing_type_annotation(self): """Test finding parameter with a missing type annotation."""