From 91662baf122df191785f2169c80f8537499dae8d Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Thu, 26 Jun 2025 13:09:24 +0000 Subject: [PATCH 01/24] Add hooks feature. --- docs/extensions.qmd | 53 ++++ src/inspect_ai/_eval/context.py | 2 - src/inspect_ai/_eval/eval.py | 10 +- src/inspect_ai/_eval/task/run.py | 101 ++++--- src/inspect_ai/_util/platform.py | 5 + src/inspect_ai/_util/registry.py | 6 + src/inspect_ai/hooks/__init__.py | 27 ++ src/inspect_ai/hooks/_hooks.py | 259 ++++++++++++++++++ .../{_util/hooks.py => hooks/_legacy.py} | 29 +- src/inspect_ai/hooks/_startup.py | 39 +++ src/inspect_ai/model/_model.py | 14 +- tests/hooks/test_hooks.py | 201 ++++++++++++++ tests/test_extensions.py | 11 + .../test_package/inspect_package/_registry.py | 8 + .../inspect_package/hooks/custom.py | 10 + 15 files changed, 712 insertions(+), 63 deletions(-) create mode 100644 src/inspect_ai/hooks/__init__.py create mode 100644 src/inspect_ai/hooks/_hooks.py rename src/inspect_ai/{_util/hooks.py => hooks/_legacy.py} (84%) create mode 100644 src/inspect_ai/hooks/_startup.py create mode 100644 tests/hooks/test_hooks.py create mode 100644 tests/test_package/inspect_package/hooks/custom.py diff --git a/docs/extensions.qmd b/docs/extensions.qmd index be15cf785a..f28029bfc3 100644 --- a/docs/extensions.qmd +++ b/docs/extensions.qmd @@ -14,6 +14,8 @@ There are several ways to extend Inspect to integrate with systems not directly 4. Storage Systems (for datasets, prompts, and evaluation logs) +5. Hooks (for logging and monitoring frameworks) + For each of these, you can create an extension within a Python package, and then use it without any special registration with Inspect (this is done via [setuptools entry points](https://setuptools.pypa.io/en/latest/userguide/entry_point.html)). ## Model APIs {#sec-model-api-extensions} @@ -425,3 +427,54 @@ myfs = "evaltools:MyFs" ::: Once this package is installed, you'll be able to use `myfs://` with Inspect without any further registration. + +## Hooks + +Hooks allow you to run arbitrary code during Inspect's lifecycle, for example when runs, tasks or samples start and end. + +Here is a hypothetical hook for integration with Weights & Biases. + +``` python +import wandb + +from inspect_ai.hooks import Hooks, RunEnd, RunStart, SampleEnd, hooks + +@hooks(name="w&b_hook") +class WBHook(Hooks): + async def on_run_start(self, data: RunStart) -> None: + wandb.init(name=data.run_id) + + async def on_run_end(self, data: RunEnd) -> None: + wandb.finish() + + async def on_sample_end(self, data: SampleEnd) -> None: + scores = {k: v.value for k, v in data.summary.scores.items()} + wandb.log({ + "sample_id": data.sample_id, + "scores": scores + }) +``` + +Implementations of the `Hooks` class can override an `enabled()` method. + +### API Key Override + +There is a hook to optionally override the value of model API key environment variables. This could be used to: + +* inject API keys at runtime (e.g. 
fetched from a secrets manager), to avoid having to store these in your environment or .env file +* use some custom model API authentication mechanism in conjunction with a custom reverse proxy for the model API to avoid Inspect ever having access to real API keys + +``` python +from inspect_ai.hooks import hooks, Hooks, ApiKeyOverride + +@hooks(name="api_key_fetcher") +class ApiKeyFetcher(Hooks): + def override_api_key(self, data: ApiKeyOverride) -> str | None: + original_env_var_value = data.value + if original_env_var_value.startswith("arn:aws:secretsmanager:"): + return fetch_aws_secret(original_env_var_value) + return None + +def fetch_aws_secret(aws_arn: str) -> str: + ... +``` diff --git a/src/inspect_ai/_eval/context.py b/src/inspect_ai/_eval/context.py index baf6d3fc1f..839e9c7732 100644 --- a/src/inspect_ai/_eval/context.py +++ b/src/inspect_ai/_eval/context.py @@ -2,7 +2,6 @@ from inspect_ai._util.dotenv import init_dotenv from inspect_ai._util.eval_task_group import init_eval_task_group -from inspect_ai._util.hooks import init_hooks from inspect_ai._util.logger import init_logger from inspect_ai.approval._apply import have_tool_approval, init_tool_approval from inspect_ai.approval._human.manager import init_human_approval_manager @@ -28,7 +27,6 @@ def init_eval_context( init_logger(log_level, log_level_transcript) init_concurrency() init_max_subprocesses(max_subprocesses) - init_hooks() init_active_samples() init_human_approval_manager() init_eval_task_group(task_group) diff --git a/src/inspect_ai/_eval/eval.py b/src/inspect_ai/_eval/eval.py index eaa51e61be..66f7d061dd 100644 --- a/src/inspect_ai/_eval/eval.py +++ b/src/inspect_ai/_eval/eval.py @@ -469,6 +469,8 @@ async def _eval_async_inner( score_display: bool | None = None, **kwargs: Unpack[GenerateConfigArgs], ) -> list[EvalLog]: + from inspect_ai.hooks._hooks import emit_run_end, emit_run_start + # only a single call to eval_async can be active at a time, this used # to be due to running tasks switching to the task's directory, however # that feature no longer exists so we may be able to revisit this @@ -488,6 +490,8 @@ async def _eval_async_inner( model_args = resolve_args(model_args) task_args = resolve_args(task_args) + run_id = uuid() + try: # intialise eval model = eval_init( @@ -609,7 +613,7 @@ async def _eval_async_inner( # run tasks - 2 codepaths, one for the traditional task at a time # (w/ optional multiple models) and the other for true multi-task # (which requires different scheduling and UI) - run_id = uuid() + await emit_run_start(run_id, resolved_tasks) task_definitions = len(resolved_tasks) // len(model) parallel = 1 if (task_definitions == 1 or max_tasks is None) else max_tasks @@ -668,6 +672,10 @@ async def _eval_async_inner( cleanup_sample_buffers(log_dir) finally: + try: + await emit_run_end(run_id, logs) + except UnboundLocalError: + await emit_run_end(run_id, EvalLogs([])) _eval_async_running = False # return logs diff --git a/src/inspect_ai/_eval/task/run.py b/src/inspect_ai/_eval/task/run.py index e3153fd568..f5263ac7bf 100644 --- a/src/inspect_ai/_eval/task/run.py +++ b/src/inspect_ai/_eval/task/run.py @@ -29,7 +29,6 @@ from inspect_ai._util.datetime import iso_now from inspect_ai._util.error import exception_message from inspect_ai._util.exception import TerminateSampleError -from inspect_ai._util.hooks import send_telemetry from inspect_ai._util.json import to_json_str_safe from inspect_ai._util.registry import ( is_registry_object, @@ -139,6 +138,9 @@ class TaskRunOptions: async def 
task_run(options: TaskRunOptions) -> EvalLog: + from inspect_ai.hooks._hooks import emit_task_end, emit_task_start + from inspect_ai.hooks._legacy import send_telemetry_legacy + # destructure options task = options.task model = options.model @@ -230,6 +232,8 @@ async def task_run(options: TaskRunOptions) -> EvalLog: log_location=log_location, ) + await emit_task_start(logger) + with display().task( profile, ) as td: @@ -383,6 +387,8 @@ async def run(tg: TaskGroup) -> None: # finish w/ success status eval_log = await logger.log_finish("success", stats, results, reductions) + await emit_task_end(logger, eval_log) + # display task summary td.complete( TaskSuccess( @@ -433,13 +439,14 @@ async def run(tg: TaskGroup) -> None: view_notify_eval(logger.location) try: + # TODO: Also send this to "new" hooks (including eval log location). if ( - await send_telemetry("eval_log_location", eval_log.location) + await send_telemetry_legacy("eval_log_location", eval_log.location) == "not_handled" ): # Converting the eval log to JSON is expensive. Only do so if # eval_log_location was not handled. - await send_telemetry("eval_log", eval_log_json_str(eval_log)) + await send_telemetry_legacy("eval_log", eval_log_json_str(eval_log)) except Exception as ex: py_logger.warning(f"Error occurred sending telemetry: {exception_message(ex)}") @@ -537,6 +544,12 @@ async def task_run_sample( working_limit: int | None, semaphore: anyio.Semaphore | None, ) -> dict[str, SampleScore] | None: + from inspect_ai.hooks._hooks import ( + emit_sample_abort, + emit_sample_end, + emit_sample_start, + ) + # if there is an existing sample then tick off its progress, log it, and return it if sample_source and sample.id is not None: previous_sample = sample_source(sample.id, state.epoch) @@ -679,16 +692,18 @@ def log_sample_error() -> None: # mark started active.started = datetime.now().timestamp() + sample_summary = EvalSampleSummary( + id=sample_id, + epoch=state.epoch, + input=sample.input, + target=sample.target, + metadata=sample.metadata or {}, + ) + if logger is not None: - await logger.start_sample( - EvalSampleSummary( - id=sample_id, - epoch=state.epoch, - input=sample.input, - target=sample.target, - metadata=sample.metadata or {}, - ) - ) + await logger.start_sample(sample_summary) + # TODO: Where can I get run_id and eval_id? 
+ await emit_sample_start("", "", sample_id, sample_summary) # set progress for plan then run it async with span("solvers"): @@ -828,29 +843,32 @@ def log_sample_error() -> None: if not error or (retry_on_error == 0): progress(SAMPLE_TOTAL_PROGRESS_UNITS) - # log it - if logger is not None: - # if we are logging images then be sure to base64 images injected by solvers - if log_images: - state = (await states_with_base64_content([state]))[0] - - # otherwise ensure there are no base64 images in sample or messages - else: - sample = sample_without_base64_content(sample) - state = state_without_base64_content(state) - - # log the sample - await log_sample( - start_time=start_time, - logger=logger, - sample=sample, - state=state, - scores=results, - error=error, - limit=limit, - error_retries=error_retries, - log_images=log_images, - ) + # if we are logging images then be sure to base64 images injected by solvers + if log_images: + state = (await states_with_base64_content([state]))[0] + + # otherwise ensure there are no base64 images in sample or messages + else: + sample = sample_without_base64_content(sample) + state = state_without_base64_content(state) + + # log the sample + eval_sample = await log_sample( + start_time=start_time, + logger=logger, + sample=sample, + state=state, + scores=results, + error=error, + limit=limit, + error_retries=error_retries, + log_images=log_images, + ) + + # TODO: Where can I get run_id and eval_id? + # TODO: Do we only want to emit sample end if there was no error? + if not error: + await emit_sample_end("", "", sample_id, eval_sample.summary()) # error that should be retried (we do this outside of the above scope so that we can # retry outside of the original semaphore -- our retry will therefore go to the back @@ -899,16 +917,18 @@ def log_sample_error() -> None: # we have an error and should raise it elif raise_error is not None: + await emit_sample_abort("", "", sample_id, error) raise raise_error # we have an error and should not raise it + # TODO: Do I need to emit a sample abort for this? else: return None async def log_sample( start_time: float | None, - logger: TaskLogger, + logger: TaskLogger | None, sample: Sample, state: TaskState, scores: dict[str, SampleScore], @@ -916,7 +936,7 @@ async def log_sample( limit: EvalSampleLimit | None, error_retries: list[EvalError], log_images: bool, -) -> None: +) -> EvalSample: # sample must have id to be logged id = sample.id if id is None: @@ -955,7 +975,12 @@ async def log_sample( limit=limit, ) - await logger.complete_sample(condense_sample(eval_sample, log_images), flush=True) + if logger is not None: + await logger.complete_sample( + condense_sample(eval_sample, log_images), flush=True + ) + + return eval_sample async def resolve_dataset( diff --git a/src/inspect_ai/_util/platform.py b/src/inspect_ai/_util/platform.py index 74a8dc8fa5..0b02444735 100644 --- a/src/inspect_ai/_util/platform.py +++ b/src/inspect_ai/_util/platform.py @@ -20,6 +20,11 @@ def running_in_notebook() -> bool: def platform_init() -> None: + from inspect_ai.hooks._startup import init_hooks + + # TODO: init hooks here. Ensure idempotent. 
+ init_hooks() + # set exception hook if we haven't already set_exception_hook() diff --git a/src/inspect_ai/_util/registry.py b/src/inspect_ai/_util/registry.py index 949f864f3e..92966f18ca 100644 --- a/src/inspect_ai/_util/registry.py +++ b/src/inspect_ai/_util/registry.py @@ -26,6 +26,7 @@ from inspect_ai import Task from inspect_ai.agent import Agent from inspect_ai.approval import Approver + from inspect_ai.hooks._hooks import Hooks from inspect_ai.model import ModelAPI from inspect_ai.scorer import Metric, Scorer, ScoreReducer from inspect_ai.solver import Plan, Solver @@ -37,6 +38,7 @@ RegistryType = Literal[ "agent", "approver", + "hooks", "metric", "modelapi", "plan", @@ -238,6 +240,10 @@ def registry_create( ) -> Approver: ... +@overload +def registry_create(type: Literal["hooks"], name: str, **kwargs: Any) -> Hooks: ... + + @overload def registry_create(type: Literal["metric"], name: str, **kwargs: Any) -> Metric: ... diff --git a/src/inspect_ai/hooks/__init__.py b/src/inspect_ai/hooks/__init__.py new file mode 100644 index 0000000000..d8d74f768e --- /dev/null +++ b/src/inspect_ai/hooks/__init__.py @@ -0,0 +1,27 @@ +from inspect_ai.hooks._hooks import ( + ApiKeyOverride, + Hooks, + ModelUsageData, + RunEnd, + RunStart, + SampleAbort, + SampleEnd, + SampleStart, + TaskEnd, + TaskStart, + hooks, +) + +__all__ = [ + "ApiKeyOverride", + "Hooks", + "ModelUsageData", + "RunEnd", + "RunStart", + "SampleAbort", + "SampleEnd", + "SampleStart", + "TaskEnd", + "TaskStart", + "hooks", +] diff --git a/src/inspect_ai/hooks/_hooks.py b/src/inspect_ai/hooks/_hooks.py new file mode 100644 index 0000000000..3918b2248b --- /dev/null +++ b/src/inspect_ai/hooks/_hooks.py @@ -0,0 +1,259 @@ +from dataclasses import dataclass +from logging import getLogger +from typing import Awaitable, Callable, Type, TypeVar, cast + +from inspect_ai._eval.eval import EvalLogs +from inspect_ai._eval.task.log import TaskLogger +from inspect_ai._eval.task.resolved import ResolvedTask +from inspect_ai._util.error import EvalError +from inspect_ai._util.registry import ( + RegistryInfo, + registry_add, + registry_find, + registry_name, +) +from inspect_ai.hooks._legacy import override_api_key_legacy +from inspect_ai.log._log import EvalLog, EvalSampleSummary, EvalSpec +from inspect_ai.model._model_output import ModelUsage + +logger = getLogger(__name__) + + +@dataclass(frozen=True) +class RunStart: + run_id: str + task_names: list[str] + + +@dataclass(frozen=True) +class RunEnd: + run_id: str + logs: EvalLogs + + +@dataclass(frozen=True) +class TaskStart: + run_id: str + eval_id: str + spec: EvalSpec + + +@dataclass(frozen=True) +class TaskEnd: + run_id: str + eval_id: str + log: EvalLog + + +@dataclass(frozen=True) +class SampleStart: + run_id: str + eval_id: str + sample_id: int | str + summary: EvalSampleSummary + + +@dataclass(frozen=True) +class SampleEnd: + run_id: str + eval_id: str + sample_id: int | str + summary: EvalSampleSummary + + +@dataclass(frozen=True) +class SampleAbort: + run_id: str + eval_id: str + sample_id: int | str + error: EvalError + + +@dataclass(frozen=True) +class ModelUsageData: + model_name: str + usage: ModelUsage + call_duration: float | None = None + + +@dataclass(frozen=True) +class ApiKeyOverride: + env_var_name: str + """The name of the environment var containing the API key (e.g. OPENAI_API_KEY).""" + value: str + """The original value of the environment variable.""" + + +class Hooks: + """Base class for hooks. 
+ + Note that whenever hooks are called, they are wrapped in a try/except block to + catch any exceptions that may occur. This is to ensure that a hook failure does not + affect the overall execution of the eval. If a hook fails, a warning will be logged. + """ + + def enabled(self) -> bool: + """Check if the hook should be enabled. + + Default implementation returns True. + + Hooks may wish to override this to e.g. check the presence of an environment + variable or a configuration setting. + + Will be called frequently, so consider caching the result if the computation is + expensive. + """ + return True + + async def on_run_start(self, data: RunStart) -> None: + pass + + async def on_run_end(self, data: RunEnd) -> None: + pass + + async def on_task_start(self, data: TaskStart) -> None: + pass + + async def on_task_end(self, data: TaskEnd) -> None: + pass + + async def on_sample_start(self, data: SampleStart) -> None: + pass + + async def on_sample_end(self, data: SampleEnd) -> None: + pass + + async def on_sample_abort(self, data: SampleAbort) -> None: + """A sample has been aborted due to an error, and will not be retried.""" + pass + + async def on_model_usage(self, data: ModelUsageData) -> None: + pass + + def override_api_key(self, data: ApiKeyOverride) -> str | None: + """Optionally override an API key. + + When overridden, this method may return a new API key value which will be used + in place of the original one during the eval. + + Returns: + str | None: The new API key value to use, or None to use the original value. + """ + return None + + +T = TypeVar("T", bound=Hooks) + + +def hooks(name: str) -> Callable[..., Type[T]]: + """Decorator for registering a hook subscriber. + + Args: + name (str): Name of the subscriber. + """ + + def wrapper(hook_type: Type[T] | Callable[..., Type[T]]) -> Type[T]: + # Resolve the hook type it's a function. + if not isinstance(hook_type, type): + hook_type = hook_type() + if not issubclass(hook_type, Hooks): + raise TypeError(f"Hook must be a subclass of Hooks, got {hook_type}") + + # Instantiate an instance of the hook class. 
+ hook_instance = hook_type() + hook_name = registry_name(hook_instance, name) + registry_add( + hook_instance, + RegistryInfo(type="hooks", name=hook_name), + ) + return cast(Type[T], hook_instance) + + return wrapper + + +async def emit_run_start(run_id: str, tasks: list[ResolvedTask]) -> None: + data = RunStart(run_id=run_id, task_names=[task.task.name for task in tasks]) + await _emit_to_all(lambda hook: hook.on_run_start(data)) + + +async def emit_run_end(run_id: str, logs: EvalLogs) -> None: + data = RunEnd(run_id=run_id, logs=logs) + await _emit_to_all(lambda hook: hook.on_run_end(data)) + + +async def emit_task_start(logger: TaskLogger) -> None: + data = TaskStart( + run_id=logger.eval.run_id, eval_id=logger.eval.eval_id, spec=logger.eval + ) + await _emit_to_all(lambda hook: hook.on_task_start(data)) + + +async def emit_task_end(logger: TaskLogger, log: EvalLog) -> None: + data = TaskEnd(run_id=logger.eval.run_id, eval_id=logger.eval.eval_id, log=log) + await _emit_to_all(lambda hook: hook.on_task_end(data)) + + +async def emit_sample_start( + run_id: str, eval_id: str, sample_id: int | str, summary: EvalSampleSummary +) -> None: + data = SampleStart( + run_id=run_id, eval_id=eval_id, sample_id=sample_id, summary=summary + ) + await _emit_to_all(lambda hook: hook.on_sample_start(data)) + + +async def emit_sample_end( + run_id: str, eval_id: str, sample_id: int | str, summary: EvalSampleSummary +) -> None: + data = SampleEnd( + run_id=run_id, eval_id=eval_id, sample_id=sample_id, summary=summary + ) + await _emit_to_all(lambda hook: hook.on_sample_end(data)) + + +async def emit_sample_abort( + run_id: str, eval_id: str, sample_id: int | str, error: EvalError +) -> None: + data = SampleAbort(run_id=run_id, eval_id=eval_id, sample_id=sample_id, error=error) + await _emit_to_all(lambda hook: hook.on_sample_abort(data)) + + +async def emit_model_usage( + model_name: str, usage: ModelUsage, call_duration: float | None +) -> None: + data = ModelUsageData( + model_name=model_name, usage=usage, call_duration=call_duration + ) + await _emit_to_all(lambda hook: hook.on_model_usage(data)) + + +def override_api_key(env_var_name: str, value: str) -> str | None: + data = ApiKeyOverride(env_var_name=env_var_name, value=value) + for hook in get_all_hooks(): + if not hook.enabled(): + continue + try: + overridden = hook.override_api_key(data) + if overridden is not None: + return overridden + except Exception as ex: + logger.warning( + f"Exception calling override_api_key on hook '{hook.__class__.__name__}': {ex}" + ) + return override_api_key_legacy(env_var_name, value) + + +def get_all_hooks() -> list[Hooks]: + """Get all registered hooks.""" + results = registry_find(lambda info: info.type == "hooks") + return cast(list[Hooks], results) + + +async def _emit_to_all(callable: Callable[[Hooks], Awaitable[None]]) -> None: + for hook in get_all_hooks(): + if not hook.enabled(): + continue + try: + await callable(hook) + except Exception as ex: + logger.warning(f"Exception calling hook '{hook.__class__.__name__}': {ex}") diff --git a/src/inspect_ai/_util/hooks.py b/src/inspect_ai/hooks/_legacy.py similarity index 84% rename from src/inspect_ai/_util/hooks.py rename to src/inspect_ai/hooks/_legacy.py index b56eaba372..73933c1cc1 100644 --- a/src/inspect_ai/_util/hooks.py +++ b/src/inspect_ai/hooks/_legacy.py @@ -1,11 +1,10 @@ +"""Legacy hooks for telemetry and API key overrides.""" + import importlib import os from typing import Any, Awaitable, Callable, Literal, cast -from rich import print - -from 
.constants import PKG_NAME -from .error import PrerequisiteError +from inspect_ai._util.error import PrerequisiteError # Hooks are functions inside packages that are installed with an # environment variable (e.g. INSPECT_TELEMETRY='mypackage.send_telemetry') @@ -31,7 +30,7 @@ TelemetrySend = Callable[[str, str], Awaitable[bool]] -async def send_telemetry( +async def send_telemetry_legacy( type: Literal["model_usage", "eval_log", "eval_log_location"], json: str ) -> Literal["handled", "not_handled", "no_subscribers"]: global _send_telemetry @@ -57,7 +56,7 @@ async def send_telemetry( ApiKeyOverride = Callable[[str, str], str | None] -def override_api_key(var: str, value: str) -> str | None: +def override_api_key_legacy(var: str, value: str) -> str | None: global _override_api_key if _override_api_key: return _override_api_key(var, value) @@ -68,14 +67,12 @@ def override_api_key(var: str, value: str) -> str | None: _override_api_key: ApiKeyOverride | None = None -def init_hooks() -> None: - # messages we'll print for hooks if we have them +def init_legacy_hooks() -> list[str]: messages: list[str] = [] - # telemetry global _send_telemetry if not _send_telemetry: - result = init_hook( + result = init_legacy_hook( "telemetry", "INSPECT_TELEMETRY", "(eval logs and token usage will be recorded by the provider)", @@ -87,7 +84,7 @@ def init_hooks() -> None: # api key override global _override_api_key if not _override_api_key: - result = init_hook( + result = init_legacy_hook( "api key override", "INSPECT_API_KEY_OVERRIDE", "(api keys will be read and modified by the provider)", @@ -96,16 +93,10 @@ def init_hooks() -> None: _override_api_key, message = result messages.append(message) - # if any hooks are enabled, let the user know - if len(messages) > 0: - version = importlib.metadata.version(PKG_NAME) - all_messages = "\n".join([f"- {message}" for message in messages]) - print( - f"[blue][bold]inspect_ai v{version}[/bold][/blue]\n[bright_black]{all_messages}[/bright_black]\n" - ) + return messages -def init_hook( +def init_legacy_hook( name: str, env: str, message: str ) -> tuple[Callable[..., Any], str] | None: hook = os.environ.get(env, "") diff --git a/src/inspect_ai/hooks/_startup.py b/src/inspect_ai/hooks/_startup.py new file mode 100644 index 0000000000..78632e9e11 --- /dev/null +++ b/src/inspect_ai/hooks/_startup.py @@ -0,0 +1,39 @@ +import importlib + +from rich import print + +from inspect_ai._util.constants import PKG_NAME +from inspect_ai._util.registry import registry_info +from inspect_ai.hooks._legacy import init_legacy_hooks + +_registry_hooks_loaded: bool = False + + +def init_hooks() -> None: + # messages we'll print for hooks if we have them + messages: list[str] = [] + + messages.extend(init_legacy_hooks()) + + from inspect_ai.hooks._hooks import get_all_hooks + + global _registry_hooks_loaded + if not _registry_hooks_loaded: + # Note that hooks loaded by virtue of load_file_tasks() -> load_module() (e.g. + # if the user defines an @hook alongside their task) won't be loaded by now. 
+ hooks = get_all_hooks() + _registry_hooks_loaded = True + if hooks: + hook_names = [f" {registry_info(hook).name}" for hook in hooks] + messages.append( + f"[bold]hooks enabled: {len(hooks)}[/bold]\n{'\n'.join(hook_names)}" + ) + + # if any hooks are enabled, let the user know + if len(messages) > 0: + version = importlib.metadata.version(PKG_NAME) + all_messages = "\n".join([f"- {message}" for message in messages]) + print( + f"[blue][bold]inspect_ai v{version}[/bold][/blue]\n" + f"[bright_black]{all_messages}[/bright_black]\n" + ) diff --git a/src/inspect_ai/model/_model.py b/src/inspect_ai/model/_model.py index f66b393ed0..a09255d997 100644 --- a/src/inspect_ai/model/_model.py +++ b/src/inspect_ai/model/_model.py @@ -39,7 +39,6 @@ ContentReasoning, ContentText, ) -from inspect_ai._util.hooks import init_hooks, override_api_key, send_telemetry from inspect_ai._util.interrupt import check_sample_interrupt from inspect_ai._util.logger import warn_once from inspect_ai._util.notgiven import NOT_GIVEN, NotGiven @@ -129,6 +128,8 @@ def __init__( self.base_url = base_url self.config = config + from inspect_ai.hooks._hooks import override_api_key + # apply api key override for key in api_key_vars: # if there is an explicit api_key passed then it @@ -490,6 +491,8 @@ async def _generate( config: GenerateConfig, cache: bool | CachePolicy = False, ) -> tuple[ModelOutput, BaseModel]: + from inspect_ai.hooks._hooks import emit_model_usage + from inspect_ai.hooks._legacy import send_telemetry_legacy from inspect_ai.log._samples import track_active_model_event from inspect_ai.log._transcript import ModelEvent @@ -682,8 +685,11 @@ async def generate() -> tuple[ModelOutput, BaseModel]: # record usage record_and_check_model_usage(f"{self}", output.usage) - # send telemetry if its hooked up - await send_telemetry( + # send telemetry to hooks + await emit_model_usage( + model_name=str(self), usage=output.usage, call_duration=output.time + ) + await send_telemetry_legacy( "model_usage", json.dumps(dict(model=str(self), usage=output.usage.model_dump())), ) @@ -922,6 +928,8 @@ def get_model( Model instance. 
""" + from inspect_ai.hooks._startup import init_hooks + # start with seeing if a model was passed if isinstance(model, Model): return model diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py new file mode 100644 index 0000000000..59c1ce86a4 --- /dev/null +++ b/tests/hooks/test_hooks.py @@ -0,0 +1,201 @@ +from typing import Generator, Type, TypeVar +from unittest.mock import patch + +import pytest + +from inspect_ai import eval +from inspect_ai._eval.task.task import Task +from inspect_ai._util.environ import environ_var +from inspect_ai._util.registry import _registry, registry_lookup +from inspect_ai.dataset._dataset import Sample +from inspect_ai.hooks._hooks import ( + ApiKeyOverride, + Hooks, + ModelUsageData, + RunEnd, + RunStart, + SampleAbort, + SampleEnd, + SampleStart, + TaskEnd, + TaskStart, + hooks, + override_api_key, +) + + +class MockHook(Hooks): + def __init__(self) -> None: + self.should_enable = True + self.run_start_events: list[RunStart] = [] + self.run_end_events: list[RunEnd] = [] + self.task_start_events: list[TaskStart] = [] + self.task_end_events: list[TaskEnd] = [] + self.sample_start_events: list[SampleStart] = [] + self.sample_end_events: list[SampleEnd] = [] + self.sample_abort_events: list[SampleAbort] = [] + self.model_usage_events: list[ModelUsageData] = [] + + def assert_no_events(self) -> None: + assert not self.run_start_events + assert not self.run_end_events + assert not self.task_start_events + assert not self.task_end_events + assert not self.sample_start_events + assert not self.sample_end_events + assert not self.sample_abort_events + assert not self.model_usage_events + + def enabled(self) -> bool: + return self.should_enable + + async def on_run_start(self, data: RunStart) -> None: + self.run_start_events.append(data) + + async def on_run_end(self, data: RunEnd) -> None: + self.run_end_events.append(data) + + async def on_task_start(self, data: TaskStart) -> None: + self.task_start_events.append(data) + + async def on_task_end(self, data: TaskEnd) -> None: + self.task_end_events.append(data) + + async def on_sample_start(self, data: SampleStart) -> None: + self.sample_start_events.append(data) + + async def on_sample_end(self, data: SampleEnd) -> None: + self.sample_end_events.append(data) + + async def on_sample_abort(self, data: SampleAbort) -> None: + self.sample_abort_events.append(data) + + async def on_model_usage(self, data: ModelUsageData) -> None: + self.model_usage_events.append(data) + + def override_api_key(self, data: ApiKeyOverride) -> str | None: + return f"mocked-{data.env_var_name}-{data.value}" + + +class MockMinimalHook(Hooks): + def __init__(self) -> None: + self.run_start_events: list[RunStart] = [] + + async def on_run_start(self, data: RunStart) -> None: + self.run_start_events.append(data) + + +@pytest.fixture +def mock_hook() -> Generator[MockHook, None, None]: + yield from _create_mock_hook("test_hook", MockHook) + + +@pytest.fixture +def hook_2() -> Generator[MockHook, None, None]: + yield from _create_mock_hook("test_hook_2", MockHook) + + +@pytest.fixture +def hook_minimal() -> Generator[MockMinimalHook, None, None]: + yield from _create_mock_hook("test_hook_minimal", MockMinimalHook) + + +def test_can_run_eval_with_no_hooks() -> None: + eval(Task(dataset=[Sample("hello"), Sample("bye")], model="mockllm/model")) + + +def test_respects_enabled(mock_hook: MockHook) -> None: + mock_hook.assert_no_events() + + mock_hook.should_enable = False + eval(Task(dataset=[Sample("hello"), Sample("bye")], 
model="mockllm/model")) + + mock_hook.assert_no_events() + + mock_hook.should_enable = True + eval(Task(dataset=[Sample("hello"), Sample("bye")], model="mockllm/model")) + + assert len(mock_hook.run_start_events) == 1 + + +def test_can_subscribe_to_events(mock_hook: MockHook) -> None: + mock_hook.assert_no_events() + + eval(Task(dataset=[Sample("hello"), Sample("bye")], model="mockllm/model")) + + assert len(mock_hook.run_start_events) == 1 + assert mock_hook.run_start_events[0].run_id is not None + assert len(mock_hook.run_end_events) == 1 + assert len(mock_hook.task_start_events) == 1 + assert len(mock_hook.task_end_events) == 1 + assert len(mock_hook.sample_start_events) == 2 + assert len(mock_hook.sample_end_events) == 2 + assert len(mock_hook.sample_abort_events) == 0 + assert len(mock_hook.model_usage_events) == 0 + + +def test_can_subscribe_to_events_with_multiple_hooks( + mock_hook: MockHook, hook_2: MockHook +) -> None: + mock_hook.assert_no_events() + hook_2.assert_no_events() + + eval(Task(dataset=[Sample("hello"), Sample("bye")], model="mockllm/model")) + + for h in (mock_hook, hook_2): + assert len(h.run_start_events) == 1 + assert h.run_start_events[0].run_id is not None + assert len(h.run_end_events) == 1 + assert len(h.task_start_events) == 1 + assert len(h.task_end_events) == 1 + assert len(h.sample_start_events) == 2 + assert len(h.sample_end_events) == 2 + assert len(h.sample_abort_events) == 0 + assert len(h.model_usage_events) == 0 + + +def test_hook_does_not_need_to_subscribe_to_all_events( + hook_minimal: MockMinimalHook, +) -> None: + eval(Task(dataset=[Sample("hello"), Sample("bye")], model="mockllm/model")) + + assert len(hook_minimal.run_start_events) == 1 + + +def test_api_key_override(mock_hook: MockHook) -> None: + overridden = override_api_key("TEST_VAR", "test_value") + + assert overridden == "mocked-TEST_VAR-test_value" + + +def test_api_key_override_falls_back_to_legacy(mock_hook: MockHook) -> None: + mock_hook.should_enable = False + + with environ_var("INSPECT_API_KEY_OVERRIDE", "._legacy_hook_override"): + with patch( + "inspect_ai.hooks._hooks.override_api_key_legacy", _legacy_hook_override + ): + overridden = override_api_key("TEST_VAR", "test_value") + + assert overridden == "legacy-TEST_VAR-test_value" + + +def _legacy_hook_override(var: str, value: str) -> str | None: + return f"legacy-{var}-{value}" + + +T = TypeVar("T", bound=Hooks) + + +def _create_mock_hook(name: str, hook_class: Type[T]) -> Generator[T, None, None]: + @hooks(name) + def get_hook_class() -> type[T]: + return hook_class + + hook = registry_lookup("hooks", name) + assert isinstance(hook, hook_class) + try: + yield hook + finally: + # Remove the hook from the registry to avoid conflicts in other tests. 
+ del _registry[f"hooks:{name}"] diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 005e64a315..4ee0841116 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -7,6 +7,7 @@ from inspect_ai import Task, eval_async from inspect_ai.dataset import Sample +from inspect_ai.hooks._hooks import emit_run_start from inspect_ai.model import get_model from inspect_ai.scorer import includes from inspect_ai.solver import generate, use_tools @@ -107,3 +108,13 @@ def test_supports_str_config(): assert recreated == spec assert recreated.config == spec.config assert isinstance(recreated.config, str) + + +async def test_hooks(): + ensure_test_package_installed() + module = importlib.import_module("inspect_package.hooks.custom") + module.run_ids = [] + + await emit_run_start(run_id="42", tasks=[]) + + assert module.run_ids == ["42"] diff --git a/tests/test_package/inspect_package/_registry.py b/tests/test_package/inspect_package/_registry.py index d690ccbff7..f632da1577 100644 --- a/tests/test_package/inspect_package/_registry.py +++ b/tests/test_package/inspect_package/_registry.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 +from inspect_ai.hooks import hooks from inspect_ai.model import modelapi from inspect_ai.util import sandboxenv @@ -25,3 +26,10 @@ def podman(): from .sandboxenv.podman import PodmanSandboxEnvironment return PodmanSandboxEnvironment + + +@hooks(name="custom_hook") +def custom_hook(): + from .hooks.custom import CustomHooks + + return CustomHooks diff --git a/tests/test_package/inspect_package/hooks/custom.py b/tests/test_package/inspect_package/hooks/custom.py new file mode 100644 index 0000000000..f20642cc43 --- /dev/null +++ b/tests/test_package/inspect_package/hooks/custom.py @@ -0,0 +1,10 @@ +from inspect_ai.hooks import Hooks, RunStart + + +class CustomHooks(Hooks): + async def on_run_start(self, event: RunStart) -> None: + global run_ids + run_ids.append(event.run_id) + + +run_ids: list[str] = [] From 88f2184830d5ea0071c611a03cdb7aefad559a4b Mon Sep 17 00:00:00 2001 From: "J.J. 
Allaire" Date: Thu, 26 Jun 2025 10:37:50 -0400 Subject: [PATCH 02/24] doc bootstrap --- docs/reference/_sidebar.yml | 25 ++++++++++ docs/reference/filter/parse.py | 2 + docs/reference/filter/sidebar.py | 3 +- docs/reference/inspect_ai.hooks.qmd | 19 ++++++++ docs/scripts/post-render.sh | 2 +- src/inspect_ai/hooks/_hooks.py | 46 ++++++++++++++++++- src/inspect_ai/model/_providers/util/hooks.py | 4 +- 7 files changed, 96 insertions(+), 5 deletions(-) create mode 100644 docs/reference/inspect_ai.hooks.qmd diff --git a/docs/reference/_sidebar.yml b/docs/reference/_sidebar.yml index 214e4e67b7..3c14dc9c48 100644 --- a/docs/reference/_sidebar.yml +++ b/docs/reference/_sidebar.yml @@ -615,6 +615,31 @@ website: href: reference/inspect_ai.util.qmd#jsonschema - text: json_schema href: reference/inspect_ai.util.qmd#json_schema + - section: inspect_ai.hooks + href: reference/inspect_ai.hooks.qmd + contents: + - text: ApiKeyOverride + href: reference/inspect_ai.hooks.qmd#apikeyoverride + - text: Hooks + href: reference/inspect_ai.hooks.qmd#hooks + - text: ModelUsageData + href: reference/inspect_ai.hooks.qmd#modelusagedata + - text: RunEnd + href: reference/inspect_ai.hooks.qmd#runend + - text: RunStart + href: reference/inspect_ai.hooks.qmd#runstart + - text: SampleAbort + href: reference/inspect_ai.hooks.qmd#sampleabort + - text: SampleEnd + href: reference/inspect_ai.hooks.qmd#sampleend + - text: SampleStart + href: reference/inspect_ai.hooks.qmd#samplestart + - text: TaskEnd + href: reference/inspect_ai.hooks.qmd#taskend + - text: TaskStart + href: reference/inspect_ai.hooks.qmd#taskstart + - text: 'hooks ' + href: 'reference/inspect_ai.hooks.qmd#hooks ' - section: Inspect CLI href: reference/inspect_eval.qmd contents: diff --git a/docs/reference/filter/parse.py b/docs/reference/filter/parse.py index be899727cd..83fe3a2784 100644 --- a/docs/reference/filter/parse.py +++ b/docs/reference/filter/parse.py @@ -1,3 +1,4 @@ +import sys from dataclasses import dataclass from itertools import islice from pathlib import Path @@ -268,6 +269,7 @@ def read_source( object: Object, options: DocParseOptions ) -> tuple[str, str, list[DocstringSection]]: # assert preconditions + sys.stderr.write(object.name + "\n") assert isinstance(object.filepath, Path) assert object.lineno is not None assert object.docstring is not None diff --git a/docs/reference/filter/sidebar.py b/docs/reference/filter/sidebar.py index feed03b221..59259a7410 100644 --- a/docs/reference/filter/sidebar.py +++ b/docs/reference/filter/sidebar.py @@ -1,6 +1,6 @@ import json import os -import yaml +import yaml # type: ignore # only execute if a reference doc is in the inputs @@ -22,6 +22,7 @@ "log.qmd", "analysis.qmd", "util.qmd", + "hooks.qmd" ] ] diff --git a/docs/reference/inspect_ai.hooks.qmd b/docs/reference/inspect_ai.hooks.qmd new file mode 100644 index 0000000000..442826db2d --- /dev/null +++ b/docs/reference/inspect_ai.hooks.qmd @@ -0,0 +1,19 @@ +--- +title: "inspect_ai.hooks" +--- + +### ApiKeyOverride +### Hooks + + + diff --git a/docs/scripts/post-render.sh b/docs/scripts/post-render.sh index 95945430bc..eef1fbf5e4 100755 --- a/docs/scripts/post-render.sh +++ b/docs/scripts/post-render.sh @@ -1,6 +1,6 @@ #!/bin/bash -files=("index" "tutorial" "options" "log-viewer" "vscode" "tasks" "datasets" "solvers" "scorers" "models" "providers" "caching" "multimodal" "reasoning" "structured" "tools" "tools-standard" "tools-mcp" "tools-custom" "sandboxing" "approval" "agents" "react-agent" "agent-custom" "agent-bridge" "human-agent" "eval-logs" 
"dataframe" "eval-sets" "errors-and-limits" "typing" "tracing" "parallelism" "interactivity" "extensions" "reference/inspect_ai" "reference/inspect_ai.solver" "reference/inspect_ai.tool" "reference/inspect_ai.agent" "reference/inspect_ai.scorer" "reference/inspect_ai.model" "reference/inspect_ai.agent" "reference/inspect_ai.dataset" "reference/inspect_ai.approval" "reference/inspect_ai.log" "reference/inspect_ai.analysis" "reference/inspect_ai.util" "reference/inspect_eval" "reference/inspect_eval-set" "reference/inspect_eval-retry" "reference/inspect_score" "reference/inspect_view" "reference/inspect_log" "reference/inspect_trace" "reference/inspect_sandbox" "reference/inspect_cache" "reference/inspect_list" "reference/inspect_info") +files=("index" "tutorial" "options" "log-viewer" "vscode" "tasks" "datasets" "solvers" "scorers" "models" "providers" "caching" "multimodal" "reasoning" "structured" "tools" "tools-standard" "tools-mcp" "tools-custom" "sandboxing" "approval" "agents" "react-agent" "agent-custom" "agent-bridge" "human-agent" "eval-logs" "dataframe" "eval-sets" "errors-and-limits" "typing" "tracing" "parallelism" "interactivity" "extensions" "reference/inspect_ai" "reference/inspect_ai.solver" "reference/inspect_ai.tool" "reference/inspect_ai.agent" "reference/inspect_ai.scorer" "reference/inspect_ai.model" "reference/inspect_ai.agent" "reference/inspect_ai.dataset" "reference/inspect_ai.approval" "reference/inspect_ai.log" "reference/inspect_ai.analysis" "reference/inspect_ai.util" "reference/inspect_ai.hooks" "reference/inspect_eval" "reference/inspect_eval-set" "reference/inspect_eval-retry" "reference/inspect_score" "reference/inspect_view" "reference/inspect_log" "reference/inspect_trace" "reference/inspect_sandbox" "reference/inspect_cache" "reference/inspect_list" "reference/inspect_info") if [ "$QUARTO_PROJECT_RENDER_ALL" = "1" ]; then diff --git a/src/inspect_ai/hooks/_hooks.py b/src/inspect_ai/hooks/_hooks.py index 3918b2248b..240af28a36 100644 --- a/src/inspect_ai/hooks/_hooks.py +++ b/src/inspect_ai/hooks/_hooks.py @@ -78,6 +78,8 @@ class ModelUsageData: @dataclass(frozen=True) class ApiKeyOverride: + """Api key override info.""" + env_var_name: str """The name of the environment var containing the API key (e.g. OPENAI_API_KEY).""" value: str @@ -106,28 +108,67 @@ def enabled(self) -> bool: return True async def on_run_start(self, data: RunStart) -> None: + """On run start. + + Args: + data: Run start data. + """ pass async def on_run_end(self, data: RunEnd) -> None: + """On run end. + + Args: + data: Run end data. + """ pass async def on_task_start(self, data: TaskStart) -> None: + """On task start. + + Args: + data: Task start data. + """ pass async def on_task_end(self, data: TaskEnd) -> None: + """On task end. + + Args: + data: Task end data. + """ pass async def on_sample_start(self, data: SampleStart) -> None: + """On sample start. + + Args: + data: Sample start data. + """ pass async def on_sample_end(self, data: SampleEnd) -> None: + """On sample end. + + Args: + data: Sample end data. + """ pass async def on_sample_abort(self, data: SampleAbort) -> None: - """A sample has been aborted due to an error, and will not be retried.""" + """A sample has been aborted due to an error, and will not be retried. + + Args: + data: Sample end data. + """ pass async def on_model_usage(self, data: ModelUsageData) -> None: + """On model usage. + + Args: + data: Model usage data. 
+ """ pass def override_api_key(self, data: ApiKeyOverride) -> str | None: @@ -136,6 +177,9 @@ def override_api_key(self, data: ApiKeyOverride) -> str | None: When overridden, this method may return a new API key value which will be used in place of the original one during the eval. + Args: + data: Api key override data. + Returns: str | None: The new API key value to use, or None to use the original value. """ diff --git a/src/inspect_ai/model/_providers/util/hooks.py b/src/inspect_ai/model/_providers/util/hooks.py index 3326e2b525..2732455add 100644 --- a/src/inspect_ai/model/_providers/util/hooks.py +++ b/src/inspect_ai/model/_providers/util/hooks.py @@ -70,7 +70,7 @@ def update_request_time(self, request_id: str) -> None: class ConverseHooks(HttpHooks): def __init__(self, session: Any) -> None: - from aiobotocore.session import AioSession + from aiobotocore.session import AioSession # type: ignore super().__init__() @@ -91,7 +91,7 @@ def converse_before_send(self, **kwargs: Any) -> None: self.update_request_time(request_id) def converse_after_call(self, http_response: Any, **kwargs: Any) -> None: - from botocore.awsrequest import AWSResponse + from botocore.awsrequest import AWSResponse # type: ignore response = cast(AWSResponse, http_response) logger.log(HTTP, f"POST {response.url} - {response.status_code}") From 2ee598870fb2e5858d5dd8e194aeec68e295a8d0 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Thu, 26 Jun 2025 13:09:24 +0000 Subject: [PATCH 03/24] Improve comments and docs. --- docs/extensions.qmd | 2 +- src/inspect_ai/hooks/_legacy.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/extensions.qmd b/docs/extensions.qmd index f28029bfc3..dc01cfd59f 100644 --- a/docs/extensions.qmd +++ b/docs/extensions.qmd @@ -455,7 +455,7 @@ class WBHook(Hooks): }) ``` -Implementations of the `Hooks` class can override an `enabled()` method. +See the `Hooks` class for more documentation and the full list of available hooks. ### API Key Override diff --git a/src/inspect_ai/hooks/_legacy.py b/src/inspect_ai/hooks/_legacy.py index 73933c1cc1..489c3f3178 100644 --- a/src/inspect_ai/hooks/_legacy.py +++ b/src/inspect_ai/hooks/_legacy.py @@ -1,4 +1,8 @@ -"""Legacy hooks for telemetry and API key overrides.""" +"""Legacy hooks for telemetry and API key overrides. + +These are deprecated and will be removed in a future release. Please use the new hooks +defined in `inspect_ai.hooks` instead. +""" import importlib import os From 60bdccd66780d2d4e23134d7268e0e834bd95cb0 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Thu, 26 Jun 2025 16:24:16 +0000 Subject: [PATCH 04/24] Improve docs. 
--- docs/reference/_sidebar.yml | 8 ++++---- docs/reference/inspect_ai.hooks.qmd | 9 +++++---- src/inspect_ai/hooks/_hooks.py | 22 +++++++++++++++++++++- 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/docs/reference/_sidebar.yml b/docs/reference/_sidebar.yml index 3c14dc9c48..1f78a5f261 100644 --- a/docs/reference/_sidebar.yml +++ b/docs/reference/_sidebar.yml @@ -618,10 +618,12 @@ website: - section: inspect_ai.hooks href: reference/inspect_ai.hooks.qmd contents: - - text: ApiKeyOverride - href: reference/inspect_ai.hooks.qmd#apikeyoverride - text: Hooks href: reference/inspect_ai.hooks.qmd#hooks + - text: hooks + href: reference/inspect_ai.hooks.qmd#hooks + - text: ApiKeyOverride + href: reference/inspect_ai.hooks.qmd#apikeyoverride - text: ModelUsageData href: reference/inspect_ai.hooks.qmd#modelusagedata - text: RunEnd @@ -638,8 +640,6 @@ website: href: reference/inspect_ai.hooks.qmd#taskend - text: TaskStart href: reference/inspect_ai.hooks.qmd#taskstart - - text: 'hooks ' - href: 'reference/inspect_ai.hooks.qmd#hooks ' - section: Inspect CLI href: reference/inspect_eval.qmd contents: diff --git a/docs/reference/inspect_ai.hooks.qmd b/docs/reference/inspect_ai.hooks.qmd index 442826db2d..ce94b93c7c 100644 --- a/docs/reference/inspect_ai.hooks.qmd +++ b/docs/reference/inspect_ai.hooks.qmd @@ -2,11 +2,14 @@ title: "inspect_ai.hooks" --- -### ApiKeyOverride +## Registration + ### Hooks +### hooks +## Hook Data - diff --git a/src/inspect_ai/hooks/_hooks.py b/src/inspect_ai/hooks/_hooks.py index 240af28a36..4a3bb6759a 100644 --- a/src/inspect_ai/hooks/_hooks.py +++ b/src/inspect_ai/hooks/_hooks.py @@ -21,18 +21,24 @@ @dataclass(frozen=True) class RunStart: + """Run start hook event data.""" + run_id: str task_names: list[str] @dataclass(frozen=True) class RunEnd: + """Run end hook event data.""" + run_id: str logs: EvalLogs @dataclass(frozen=True) class TaskStart: + """Task start hook event data.""" + run_id: str eval_id: str spec: EvalSpec @@ -40,6 +46,8 @@ class TaskStart: @dataclass(frozen=True) class TaskEnd: + """Task end hook event data.""" + run_id: str eval_id: str log: EvalLog @@ -47,6 +55,8 @@ class TaskEnd: @dataclass(frozen=True) class SampleStart: + """Sample start hook event data.""" + run_id: str eval_id: str sample_id: int | str @@ -55,6 +65,8 @@ class SampleStart: @dataclass(frozen=True) class SampleEnd: + """Sample end hook event data.""" + run_id: str eval_id: str sample_id: int | str @@ -63,6 +75,8 @@ class SampleEnd: @dataclass(frozen=True) class SampleAbort: + """Sample abort hook event data.""" + run_id: str eval_id: str sample_id: int | str @@ -71,6 +85,8 @@ class SampleAbort: @dataclass(frozen=True) class ModelUsageData: + """Model usage hook event data.""" + model_name: str usage: ModelUsage call_duration: float | None = None @@ -78,7 +94,7 @@ class ModelUsageData: @dataclass(frozen=True) class ApiKeyOverride: - """Api key override info.""" + """Api key override hook event data.""" env_var_name: str """The name of the environment var containing the API key (e.g. OPENAI_API_KEY).""" @@ -192,6 +208,10 @@ def override_api_key(self, data: ApiKeyOverride) -> str | None: def hooks(name: str) -> Callable[..., Type[T]]: """Decorator for registering a hook subscriber. + Either decorate a subclass of `Hooks`, or a function which returns the type + of a subclass of `Hooks`. This decorator will instantiate the hook class + and store it in the registry. + Args: name (str): Name of the subscriber. 
""" From 3273d39c64e8186358ddbe8c80ac4171be2bb4c8 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Thu, 26 Jun 2025 16:53:54 +0000 Subject: [PATCH 05/24] Pass run_id and task_id, refactor creation of EvalSample. --- src/inspect_ai/_eval/task/run.py | 46 +++++++++++++++++--------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/src/inspect_ai/_eval/task/run.py b/src/inspect_ai/_eval/task/run.py index f5263ac7bf..bbc24e0b74 100644 --- a/src/inspect_ai/_eval/task/run.py +++ b/src/inspect_ai/_eval/task/run.py @@ -346,6 +346,8 @@ async def run(tg: TaskGroup) -> None: time_limit=config.time_limit, working_limit=config.working_limit, semaphore=sample_semaphore, + run_id=logger.eval.run_id, + task_id=logger.eval.eval_id, ) finally: tg.cancel_scope.cancel() @@ -439,7 +441,7 @@ async def run(tg: TaskGroup) -> None: view_notify_eval(logger.location) try: - # TODO: Also send this to "new" hooks (including eval log location). + # Log file locations are emitted to the "new" hooks via the "task end" event, if ( await send_telemetry_legacy("eval_log_location", eval_log.location) == "not_handled" @@ -543,6 +545,8 @@ async def task_run_sample( time_limit: int | None, working_limit: int | None, semaphore: anyio.Semaphore | None, + run_id: str, + task_id: str, ) -> dict[str, SampleScore] | None: from inspect_ai.hooks._hooks import ( emit_sample_abort, @@ -692,6 +696,7 @@ def log_sample_error() -> None: # mark started active.started = datetime.now().timestamp() + # emit/log sample start sample_summary = EvalSampleSummary( id=sample_id, epoch=state.epoch, @@ -699,11 +704,11 @@ def log_sample_error() -> None: target=sample.target, metadata=sample.metadata or {}, ) - if logger is not None: await logger.start_sample(sample_summary) - # TODO: Where can I get run_id and eval_id? - await emit_sample_start("", "", sample_id, sample_summary) + await emit_sample_start( + run_id, task_id, sample_id, sample_summary + ) # set progress for plan then run it async with span("solvers"): @@ -852,23 +857,23 @@ def log_sample_error() -> None: sample = sample_without_base64_content(sample) state = state_without_base64_content(state) - # log the sample - eval_sample = await log_sample( + # emit/log sample end + eval_sample = create_eval_sample( start_time=start_time, - logger=logger, sample=sample, state=state, scores=results, error=error, limit=limit, error_retries=error_retries, - log_images=log_images, ) - - # TODO: Where can I get run_id and eval_id? + if logger: + await log_sample( + eval_sample=eval_sample, logger=logger, log_images=log_images + ) # TODO: Do we only want to emit sample end if there was no error? 
if not error: - await emit_sample_end("", "", sample_id, eval_sample.summary()) + await emit_sample_end(run_id, task_id, sample_id, eval_sample.summary()) # error that should be retried (we do this outside of the above scope so that we can # retry outside of the original semaphore -- our retry will therefore go to the back @@ -906,6 +911,8 @@ def log_sample_error() -> None: working_limit=working_limit, semaphore=semaphore, tg=tg, + run_id=run_id, + task_id=task_id, ) # no error @@ -917,7 +924,7 @@ def log_sample_error() -> None: # we have an error and should raise it elif raise_error is not None: - await emit_sample_abort("", "", sample_id, error) + await emit_sample_abort(run_id, task_id, sample_id, error) raise raise_error # we have an error and should not raise it @@ -926,16 +933,14 @@ def log_sample_error() -> None: return None -async def log_sample( +def create_eval_sample( start_time: float | None, - logger: TaskLogger | None, sample: Sample, state: TaskState, scores: dict[str, SampleScore], error: EvalError | None, limit: EvalSampleLimit | None, error_retries: list[EvalError], - log_images: bool, ) -> EvalSample: # sample must have id to be logged id = sample.id @@ -949,7 +954,7 @@ async def log_sample( # compute total time if we can total_time = time.monotonic() - start_time if start_time is not None else None - eval_sample = EvalSample( + return EvalSample( id=id, epoch=state.epoch, input=sample.input, @@ -975,12 +980,11 @@ async def log_sample( limit=limit, ) - if logger is not None: - await logger.complete_sample( - condense_sample(eval_sample, log_images), flush=True - ) - return eval_sample +async def log_sample( + eval_sample: EvalSample, logger: TaskLogger, log_images: bool +) -> None: + await logger.complete_sample(condense_sample(eval_sample, log_images), flush=True) async def resolve_dataset( From 631c476105c366db6ca97279bb026be14d78c544 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Thu, 26 Jun 2025 17:00:56 +0000 Subject: [PATCH 06/24] Remove unused type ignores. 
--- docs/reference/filter/sidebar.py | 2 +- src/inspect_ai/model/_providers/util/hooks.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/reference/filter/sidebar.py b/docs/reference/filter/sidebar.py index 59259a7410..2bfedec075 100644 --- a/docs/reference/filter/sidebar.py +++ b/docs/reference/filter/sidebar.py @@ -1,6 +1,6 @@ import json import os -import yaml # type: ignore +import yaml # only execute if a reference doc is in the inputs diff --git a/src/inspect_ai/model/_providers/util/hooks.py b/src/inspect_ai/model/_providers/util/hooks.py index 2732455add..3326e2b525 100644 --- a/src/inspect_ai/model/_providers/util/hooks.py +++ b/src/inspect_ai/model/_providers/util/hooks.py @@ -70,7 +70,7 @@ def update_request_time(self, request_id: str) -> None: class ConverseHooks(HttpHooks): def __init__(self, session: Any) -> None: - from aiobotocore.session import AioSession # type: ignore + from aiobotocore.session import AioSession super().__init__() @@ -91,7 +91,7 @@ def converse_before_send(self, **kwargs: Any) -> None: self.update_request_time(request_id) def converse_after_call(self, http_response: Any, **kwargs: Any) -> None: - from botocore.awsrequest import AWSResponse # type: ignore + from botocore.awsrequest import AWSResponse response = cast(AWSResponse, http_response) logger.log(HTTP, f"POST {response.url} - {response.status_code}") From dafbcea40338237aadc8eb5303f3af4512db7056 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Thu, 26 Jun 2025 17:01:16 +0000 Subject: [PATCH 07/24] Improve commentary/ordering. --- src/inspect_ai/_eval/eval.py | 3 ++- src/inspect_ai/hooks/_hooks.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/inspect_ai/_eval/eval.py b/src/inspect_ai/_eval/eval.py index 66f7d061dd..d4457094b4 100644 --- a/src/inspect_ai/_eval/eval.py +++ b/src/inspect_ai/_eval/eval.py @@ -613,10 +613,11 @@ async def _eval_async_inner( # run tasks - 2 codepaths, one for the traditional task at a time # (w/ optional multiple models) and the other for true multi-task # (which requires different scheduling and UI) - await emit_run_start(run_id, resolved_tasks) task_definitions = len(resolved_tasks) // len(model) parallel = 1 if (task_definitions == 1 or max_tasks is None) else max_tasks + await emit_run_start(run_id, resolved_tasks) + # single task definition (could be multi-model) or max_tasks capped to 1 if parallel == 1: results: list[EvalLog] = [] diff --git a/src/inspect_ai/hooks/_hooks.py b/src/inspect_ai/hooks/_hooks.py index 4a3bb6759a..1c8dc7c229 100644 --- a/src/inspect_ai/hooks/_hooks.py +++ b/src/inspect_ai/hooks/_hooks.py @@ -304,6 +304,7 @@ def override_api_key(env_var_name: str, value: str) -> str | None: logger.warning( f"Exception calling override_api_key on hook '{hook.__class__.__name__}': {ex}" ) + # If none have been overridden, fall back to legacy behaviour. return override_api_key_legacy(env_var_name, value) From 8a6c7f116cf1652d4ff28fcbc1e9e83552ec6a0d Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Thu, 26 Jun 2025 17:18:58 +0000 Subject: [PATCH 08/24] Improve tests, remove now redundant comment. 
--- src/inspect_ai/_util/platform.py | 1 - tests/hooks/test_hooks.py | 17 ++++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/inspect_ai/_util/platform.py b/src/inspect_ai/_util/platform.py index 0b02444735..9baf99e49e 100644 --- a/src/inspect_ai/_util/platform.py +++ b/src/inspect_ai/_util/platform.py @@ -22,7 +22,6 @@ def running_in_notebook() -> bool: def platform_init() -> None: from inspect_ai.hooks._startup import init_hooks - # TODO: init hooks here. Ensure idempotent. init_hooks() # set exception hook if we haven't already diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py index 59c1ce86a4..7d0713b187 100644 --- a/tests/hooks/test_hooks.py +++ b/tests/hooks/test_hooks.py @@ -169,19 +169,30 @@ def test_api_key_override(mock_hook: MockHook) -> None: def test_api_key_override_falls_back_to_legacy(mock_hook: MockHook) -> None: + def legacy_hook_override(var: str, value: str) -> str | None: + return f"legacy-{var}-{value}" + mock_hook.should_enable = False with environ_var("INSPECT_API_KEY_OVERRIDE", "._legacy_hook_override"): with patch( - "inspect_ai.hooks._hooks.override_api_key_legacy", _legacy_hook_override + "inspect_ai.hooks._hooks.override_api_key_legacy", legacy_hook_override ): overridden = override_api_key("TEST_VAR", "test_value") assert overridden == "legacy-TEST_VAR-test_value" -def _legacy_hook_override(var: str, value: str) -> str | None: - return f"legacy-{var}-{value}" +def test_init_hooks_can_be_called_multiple_times(mock_hook: MockHook) -> None: + from inspect_ai.hooks._startup import init_hooks + + # Ensure that init_hooks can be called multiple times without issues. + init_hooks() + init_hooks() + + eval(Task(dataset=[Sample("hello"), Sample("bye")], model="mockllm/model")) + + assert len(mock_hook.run_start_events) == 1 T = TypeVar("T", bound=Hooks) From 0fbee1e90c58979635c976ef54b933abf065829d Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Fri, 27 Jun 2025 09:10:44 +0000 Subject: [PATCH 09/24] Split up f string expression to avoid "Cannot use an escape sequence (backslash) in f-strings on Python 3.10 (syntax was added in Python 3.12)". --- src/inspect_ai/hooks/_startup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/inspect_ai/hooks/_startup.py b/src/inspect_ai/hooks/_startup.py index 78632e9e11..c9c6a50d8f 100644 --- a/src/inspect_ai/hooks/_startup.py +++ b/src/inspect_ai/hooks/_startup.py @@ -25,8 +25,9 @@ def init_hooks() -> None: _registry_hooks_loaded = True if hooks: hook_names = [f" {registry_info(hook).name}" for hook in hooks] + hook_names_joined = "\n".join(hook_names) messages.append( - f"[bold]hooks enabled: {len(hooks)}[/bold]\n{'\n'.join(hook_names)}" + f"[bold]hooks enabled: {len(hooks)}[/bold]\n{hook_names_joined}" ) # if any hooks are enabled, let the user know From 1a9051f4080c22601b7d3aa804890afb5fd52fed Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Fri, 27 Jun 2025 14:17:03 +0000 Subject: [PATCH 10/24] Improve docstrings for hooks. --- src/inspect_ai/hooks/_hooks.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/inspect_ai/hooks/_hooks.py b/src/inspect_ai/hooks/_hooks.py index 1c8dc7c229..5a04c13391 100644 --- a/src/inspect_ai/hooks/_hooks.py +++ b/src/inspect_ai/hooks/_hooks.py @@ -126,6 +126,10 @@ def enabled(self) -> bool: async def on_run_start(self, data: RunStart) -> None: """On run start. 
+ A "run" is a single invocation of `eval()` or `eval_retry()` which may contain + many Tasks, each with many Samples and many epochs. Note that `eval_retry()` + can be invoked multiple times within an `eval_set()`. + Args: data: Run start data. """ @@ -158,6 +162,10 @@ async def on_task_end(self, data: TaskEnd) -> None: async def on_sample_start(self, data: SampleStart) -> None: """On sample start. + If a sample is run for multiple epochs, this will be called once per epoch. + + If a sample is retried, this will be called again for each new attempt. + Args: data: Sample start data. """ @@ -166,6 +174,10 @@ async def on_sample_start(self, data: SampleStart) -> None: async def on_sample_end(self, data: SampleEnd) -> None: """On sample end. + This will be called when a sample has completed without error. If there are + multiple epochs for a sample, this will be called once per successfully + completed epoch. + Args: data: Sample end data. """ @@ -174,13 +186,16 @@ async def on_sample_end(self, data: SampleEnd) -> None: async def on_sample_abort(self, data: SampleAbort) -> None: """A sample has been aborted due to an error, and will not be retried. + If there are multiple epochs for a sample, this will be called once per + aborted epoch of the sample. + Args: data: Sample end data. """ pass async def on_model_usage(self, data: ModelUsageData) -> None: - """On model usage. + """Called when a call to a model's generate() method completes successfully. Args: data: Model usage data. From 5f23ec1219f3e9f3683f47eeea7baed90aaf862e Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Fri, 27 Jun 2025 14:18:36 +0000 Subject: [PATCH 11/24] Improve comment. --- src/inspect_ai/hooks/_hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/inspect_ai/hooks/_hooks.py b/src/inspect_ai/hooks/_hooks.py index 5a04c13391..3d3999ea5c 100644 --- a/src/inspect_ai/hooks/_hooks.py +++ b/src/inspect_ai/hooks/_hooks.py @@ -238,7 +238,7 @@ def wrapper(hook_type: Type[T] | Callable[..., Type[T]]) -> Type[T]: if not issubclass(hook_type, Hooks): raise TypeError(f"Hook must be a subclass of Hooks, got {hook_type}") - # Instantiate an instance of the hook class. + # Instantiate an instance of the Hooks class. hook_instance = hook_type() hook_name = registry_name(hook_instance, name) registry_add( From f32d44bdbe4f0c59c34da15eca8ddfa6c83c4756 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Fri, 27 Jun 2025 16:24:47 +0000 Subject: [PATCH 12/24] Clarify when model usage will be called re caching. --- src/inspect_ai/hooks/_hooks.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/inspect_ai/hooks/_hooks.py b/src/inspect_ai/hooks/_hooks.py index 3d3999ea5c..72faf8b9c0 100644 --- a/src/inspect_ai/hooks/_hooks.py +++ b/src/inspect_ai/hooks/_hooks.py @@ -197,6 +197,10 @@ async def on_sample_abort(self, data: SampleAbort) -> None: async def on_model_usage(self, data: ModelUsageData) -> None: """Called when a call to a model's generate() method completes successfully. + Note that this is not called when Inspect's local cache is used and is a cache + hit (i.e. if no external API call was made). Provider-side caching will result + in this being called. + Args: data: Model usage data. """ @@ -232,7 +236,7 @@ def hooks(name: str) -> Callable[..., Type[T]]: """ def wrapper(hook_type: Type[T] | Callable[..., Type[T]]) -> Type[T]: - # Resolve the hook type it's a function. + # Resolve the hook type if it's a function. 
if not isinstance(hook_type, type): hook_type = hook_type() if not issubclass(hook_type, Hooks): From b9c5fbe71b6b58b24862765cff2e5de8f37ef4e1 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Mon, 30 Jun 2025 09:50:29 +0000 Subject: [PATCH 13/24] Add hook descriptions for printing to console. --- docs/extensions.qmd | 4 ++-- src/inspect_ai/hooks/_hooks.py | 10 +++++++--- src/inspect_ai/hooks/_startup.py | 9 ++++++++- tests/hooks/test_hooks.py | 11 +++++++++-- tests/test_package/inspect_package/_registry.py | 2 +- 5 files changed, 27 insertions(+), 9 deletions(-) diff --git a/docs/extensions.qmd b/docs/extensions.qmd index dc01cfd59f..1d807f9e0e 100644 --- a/docs/extensions.qmd +++ b/docs/extensions.qmd @@ -439,7 +439,7 @@ import wandb from inspect_ai.hooks import Hooks, RunEnd, RunStart, SampleEnd, hooks -@hooks(name="w&b_hook") +@hooks(name="w&b_hook", description="Weights & Biases integration") class WBHook(Hooks): async def on_run_start(self, data: RunStart) -> None: wandb.init(name=data.run_id) @@ -467,7 +467,7 @@ There is a hook to optionally override the value of model API key environment va ``` python from inspect_ai.hooks import hooks, Hooks, ApiKeyOverride -@hooks(name="api_key_fetcher") +@hooks(name="api_key_fetcher", description="Fetches API key from secrets manager") class ApiKeyFetcher(Hooks): def override_api_key(self, data: ApiKeyOverride) -> str | None: original_env_var_value = data.value diff --git a/src/inspect_ai/hooks/_hooks.py b/src/inspect_ai/hooks/_hooks.py index 72faf8b9c0..0ccc2abbe8 100644 --- a/src/inspect_ai/hooks/_hooks.py +++ b/src/inspect_ai/hooks/_hooks.py @@ -224,7 +224,7 @@ def override_api_key(self, data: ApiKeyOverride) -> str | None: T = TypeVar("T", bound=Hooks) -def hooks(name: str) -> Callable[..., Type[T]]: +def hooks(name: str, description: str) -> Callable[..., Type[T]]: """Decorator for registering a hook subscriber. Either decorate a subclass of `Hooks`, or a function which returns the type @@ -232,7 +232,9 @@ def hooks(name: str) -> Callable[..., Type[T]]: and store it in the registry. Args: - name (str): Name of the subscriber. + name (str): Name of the subscriber (e.g. "audit logging"). + description (str): Short description of the hook (e.g. "Copies eval files to + S3 bucket for auditing."). 
""" def wrapper(hook_type: Type[T] | Callable[..., Type[T]]) -> Type[T]: @@ -247,7 +249,9 @@ def wrapper(hook_type: Type[T] | Callable[..., Type[T]]) -> Type[T]: hook_name = registry_name(hook_instance, name) registry_add( hook_instance, - RegistryInfo(type="hooks", name=hook_name), + RegistryInfo( + type="hooks", name=hook_name, metadata={"description": description} + ), ) return cast(Type[T], hook_instance) diff --git a/src/inspect_ai/hooks/_startup.py b/src/inspect_ai/hooks/_startup.py index c9c6a50d8f..0bb9e4a4a8 100644 --- a/src/inspect_ai/hooks/_startup.py +++ b/src/inspect_ai/hooks/_startup.py @@ -4,6 +4,7 @@ from inspect_ai._util.constants import PKG_NAME from inspect_ai._util.registry import registry_info +from inspect_ai.hooks._hooks import Hooks from inspect_ai.hooks._legacy import init_legacy_hooks _registry_hooks_loaded: bool = False @@ -24,7 +25,7 @@ def init_hooks() -> None: hooks = get_all_hooks() _registry_hooks_loaded = True if hooks: - hook_names = [f" {registry_info(hook).name}" for hook in hooks] + hook_names = [f" {_format_hook_for_printing(hook)}" for hook in hooks] hook_names_joined = "\n".join(hook_names) messages.append( f"[bold]hooks enabled: {len(hooks)}[/bold]\n{hook_names_joined}" @@ -38,3 +39,9 @@ def init_hooks() -> None: f"[blue][bold]inspect_ai v{version}[/bold][/blue]\n" f"[bright_black]{all_messages}[/bright_black]\n" ) + + +def _format_hook_for_printing(hook: Hooks) -> str: + info = registry_info(hook) + description = info.metadata["description"] + return f"[bold]{info.name}[/bold]: {description}" diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py index 7d0713b187..889ac5456a 100644 --- a/tests/hooks/test_hooks.py +++ b/tests/hooks/test_hooks.py @@ -6,7 +6,7 @@ from inspect_ai import eval from inspect_ai._eval.task.task import Task from inspect_ai._util.environ import environ_var -from inspect_ai._util.registry import _registry, registry_lookup +from inspect_ai._util.registry import _registry, registry_info, registry_lookup from inspect_ai.dataset._dataset import Sample from inspect_ai.hooks._hooks import ( ApiKeyOverride, @@ -195,11 +195,18 @@ def test_init_hooks_can_be_called_multiple_times(mock_hook: MockHook) -> None: assert len(mock_hook.run_start_events) == 1 +def test_hook_name_and_description(mock_hook: MockHook) -> None: + info = registry_info(mock_hook) + + assert info.name == "test_hook" + assert info.metadata["description"] == "test_hook-description" + + T = TypeVar("T", bound=Hooks) def _create_mock_hook(name: str, hook_class: Type[T]) -> Generator[T, None, None]: - @hooks(name) + @hooks(name, description=f"{name}-description") def get_hook_class() -> type[T]: return hook_class diff --git a/tests/test_package/inspect_package/_registry.py b/tests/test_package/inspect_package/_registry.py index f632da1577..2102e0db59 100644 --- a/tests/test_package/inspect_package/_registry.py +++ b/tests/test_package/inspect_package/_registry.py @@ -28,7 +28,7 @@ def podman(): return PodmanSandboxEnvironment -@hooks(name="custom_hook") +@hooks(name="custom_hook", description="Custom hooks") def custom_hook(): from .hooks.custom import CustomHooks From 9507ccf981f309d1792560803d840585a13a04e6 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Mon, 30 Jun 2025 10:46:03 +0000 Subject: [PATCH 14/24] Only emit on_sample_start when sample is first tried - not on retries. 
--- src/inspect_ai/_eval/task/run.py | 8 ++++-- src/inspect_ai/hooks/_hooks.py | 2 +- tests/hooks/test_hooks.py | 47 ++++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 4 deletions(-) diff --git a/src/inspect_ai/_eval/task/run.py b/src/inspect_ai/_eval/task/run.py index bbc24e0b74..a82a369363 100644 --- a/src/inspect_ai/_eval/task/run.py +++ b/src/inspect_ai/_eval/task/run.py @@ -706,9 +706,11 @@ def log_sample_error() -> None: ) if logger is not None: await logger.start_sample(sample_summary) - await emit_sample_start( - run_id, task_id, sample_id, sample_summary - ) + # only emit the sample start once: not on retries + if not error_retries: + await emit_sample_start( + run_id, task_id, sample_id, sample_summary + ) # set progress for plan then run it async with span("solvers"): diff --git a/src/inspect_ai/hooks/_hooks.py b/src/inspect_ai/hooks/_hooks.py index 0ccc2abbe8..6e1e25aa0f 100644 --- a/src/inspect_ai/hooks/_hooks.py +++ b/src/inspect_ai/hooks/_hooks.py @@ -164,7 +164,7 @@ async def on_sample_start(self, data: SampleStart) -> None: If a sample is run for multiple epochs, this will be called once per epoch. - If a sample is retried, this will be called again for each new attempt. + This is not called again on sample retries. Args: data: Sample start data. diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py index 889ac5456a..6eb8213096 100644 --- a/tests/hooks/test_hooks.py +++ b/tests/hooks/test_hooks.py @@ -22,6 +22,8 @@ hooks, override_api_key, ) +from inspect_ai.solver._solver import Generate, Solver, solver +from inspect_ai.solver._task_state import TaskState class MockHook(Hooks): @@ -154,6 +156,36 @@ def test_can_subscribe_to_events_with_multiple_hooks( assert len(h.model_usage_events) == 0 +def test_hooks_on_sample_retries(mock_hook: MockHook) -> None: + eval( + Task( + dataset=[Sample("hello")], + model="mockllm/model", + solver=_fail_n_times_solver(2), + ), + retry_on_error=10, + ) + + assert len(mock_hook.sample_start_events) == 1 + assert len(mock_hook.sample_end_events) == 1 + assert len(mock_hook.sample_abort_events) == 0 + + +def test_hooks_on_sample_abort(mock_hook: MockHook) -> None: + eval( + Task( + dataset=[Sample("hello")], + model="mockllm/model", + solver=_fail_n_times_solver(10), + ), + retry_on_error=0, + ) + + assert len(mock_hook.sample_start_events) == 1 + assert len(mock_hook.sample_end_events) == 0 + assert len(mock_hook.sample_abort_events) == 1 + + def test_hook_does_not_need_to_subscribe_to_all_events( hook_minimal: MockMinimalHook, ) -> None: @@ -217,3 +249,18 @@ def get_hook_class() -> type[T]: finally: # Remove the hook from the registry to avoid conflicts in other tests. del _registry[f"hooks:{name}"] + + +@solver +def _fail_n_times_solver(target_failures: int) -> Solver: + """Fails N times, then succeeds.""" + attempts = 0 + + async def solve(state: TaskState, generate: Generate) -> TaskState: + nonlocal attempts + attempts += 1 + if attempts < target_failures: + raise RuntimeError(f"Simulated failure {attempts}") + return state + + return solve From b1be49eb1cf6770cfe68c9a7f2fe9a44d4941678 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Mon, 30 Jun 2025 11:21:41 +0000 Subject: [PATCH 15/24] Reorganise hook tests. 
--- tests/hooks/test_hooks.py | 83 +++++++++++++++++++++++++++++---------- 1 file changed, 62 insertions(+), 21 deletions(-) diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py index 6eb8213096..d81e10d1dc 100644 --- a/tests/hooks/test_hooks.py +++ b/tests/hooks/test_hooks.py @@ -103,19 +103,19 @@ def hook_minimal() -> Generator[MockMinimalHook, None, None]: def test_can_run_eval_with_no_hooks() -> None: - eval(Task(dataset=[Sample("hello"), Sample("bye")], model="mockllm/model")) + eval(Task(dataset=[Sample("sample_1")]), model="mockllm/model") def test_respects_enabled(mock_hook: MockHook) -> None: mock_hook.assert_no_events() mock_hook.should_enable = False - eval(Task(dataset=[Sample("hello"), Sample("bye")], model="mockllm/model")) + eval(Task(dataset=[Sample("sample_1")]), model="mockllm/model") mock_hook.assert_no_events() mock_hook.should_enable = True - eval(Task(dataset=[Sample("hello"), Sample("bye")], model="mockllm/model")) + eval(Task(dataset=[Sample("sample_1")]), model="mockllm/model") assert len(mock_hook.run_start_events) == 1 @@ -123,15 +123,15 @@ def test_respects_enabled(mock_hook: MockHook) -> None: def test_can_subscribe_to_events(mock_hook: MockHook) -> None: mock_hook.assert_no_events() - eval(Task(dataset=[Sample("hello"), Sample("bye")], model="mockllm/model")) + eval(Task(dataset=[Sample("sample_1")]), model="mockllm/model") assert len(mock_hook.run_start_events) == 1 assert mock_hook.run_start_events[0].run_id is not None assert len(mock_hook.run_end_events) == 1 assert len(mock_hook.task_start_events) == 1 assert len(mock_hook.task_end_events) == 1 - assert len(mock_hook.sample_start_events) == 2 - assert len(mock_hook.sample_end_events) == 2 + assert len(mock_hook.sample_start_events) == 1 + assert len(mock_hook.sample_end_events) == 1 assert len(mock_hook.sample_abort_events) == 0 assert len(mock_hook.model_usage_events) == 0 @@ -142,7 +142,7 @@ def test_can_subscribe_to_events_with_multiple_hooks( mock_hook.assert_no_events() hook_2.assert_no_events() - eval(Task(dataset=[Sample("hello"), Sample("bye")], model="mockllm/model")) + eval(Task(dataset=[Sample("sample_1")]), model="mockllm/model") for h in (mock_hook, hook_2): assert len(h.run_start_events) == 1 @@ -150,19 +150,63 @@ def test_can_subscribe_to_events_with_multiple_hooks( assert len(h.run_end_events) == 1 assert len(h.task_start_events) == 1 assert len(h.task_end_events) == 1 - assert len(h.sample_start_events) == 2 - assert len(h.sample_end_events) == 2 + assert len(h.sample_start_events) == 1 + assert len(h.sample_end_events) == 1 assert len(h.sample_abort_events) == 0 assert len(h.model_usage_events) == 0 +def test_hooks_on_multiple_tasks(mock_hook: MockHook) -> None: + eval( + [ + Task(dataset=[Sample("task_1_sample_1")]), + Task(dataset=[Sample("task_2_sample_1")]), + ], + model="mockllm/model", + ) + + assert len(mock_hook.run_start_events) == 1 + assert len(mock_hook.run_end_events) == 1 + assert len(mock_hook.task_start_events) == 2 + assert len(mock_hook.task_end_events) == 2 + assert len(mock_hook.sample_start_events) == 2 + assert len(mock_hook.sample_end_events) == 2 + assert len(mock_hook.sample_abort_events) == 0 + + +def test_hooks_on_multiple_samples(mock_hook: MockHook) -> None: + eval( + [ + Task(dataset=[Sample("sample_1"), Sample("sample_2")]), + ], + model="mockllm/model", + ) + + assert len(mock_hook.run_start_events) == 1 + assert len(mock_hook.run_end_events) == 1 + assert len(mock_hook.task_start_events) == 1 + assert len(mock_hook.task_end_events) == 1 + 
assert len(mock_hook.sample_start_events) == 2 + assert len(mock_hook.sample_end_events) == 2 + assert len(mock_hook.sample_abort_events) == 0 + + +def test_hooks_on_multiple_epochs(mock_hook: MockHook) -> None: + eval( + Task(dataset=[Sample("sample_1")]), + model="mockllm/model", + epochs=3, + ) + + assert len(mock_hook.sample_start_events) == 3 + assert len(mock_hook.sample_end_events) == 3 + assert len(mock_hook.sample_abort_events) == 0 + + def test_hooks_on_sample_retries(mock_hook: MockHook) -> None: eval( - Task( - dataset=[Sample("hello")], - model="mockllm/model", - solver=_fail_n_times_solver(2), - ), + Task(dataset=[Sample("sample_1")], solver=_fail_n_times_solver(2)), + model="mockllm/model", retry_on_error=10, ) @@ -173,11 +217,8 @@ def test_hooks_on_sample_retries(mock_hook: MockHook) -> None: def test_hooks_on_sample_abort(mock_hook: MockHook) -> None: eval( - Task( - dataset=[Sample("hello")], - model="mockllm/model", - solver=_fail_n_times_solver(10), - ), + Task(dataset=[Sample("sample_1")], solver=_fail_n_times_solver(10)), + model="mockllm/model", retry_on_error=0, ) @@ -189,7 +230,7 @@ def test_hooks_on_sample_abort(mock_hook: MockHook) -> None: def test_hook_does_not_need_to_subscribe_to_all_events( hook_minimal: MockMinimalHook, ) -> None: - eval(Task(dataset=[Sample("hello"), Sample("bye")], model="mockllm/model")) + eval(Task(dataset=[Sample("sample_1")]), model="mockllm/model") assert len(hook_minimal.run_start_events) == 1 @@ -222,7 +263,7 @@ def test_init_hooks_can_be_called_multiple_times(mock_hook: MockHook) -> None: init_hooks() init_hooks() - eval(Task(dataset=[Sample("hello"), Sample("bye")], model="mockllm/model")) + eval(Task(dataset=[Sample("sample_1")]), model="mockllm/model") assert len(mock_hook.run_start_events) == 1 From 15e0c93814e7a38e777da753846246e750054bf8 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Mon, 30 Jun 2025 11:34:16 +0000 Subject: [PATCH 16/24] Add docstrings to properties of models. Remove type union with None for call duration. 
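As a sketch of how a subscriber might consume these now-documented fields (illustrative only, not part of this change): the example assumes that `ModelUsage` exposes a `total_tokens` field, which is not defined by this patch.

``` python
from collections import defaultdict

from inspect_ai.hooks import Hooks, ModelUsageData, hooks

@hooks(name="usage_tracker", description="Aggregates per-model usage")
class UsageTracker(Hooks):
    def __init__(self) -> None:
        self.tokens: dict[str, int] = defaultdict(int)
        self.seconds: dict[str, float] = defaultdict(float)

    async def on_model_usage(self, data: ModelUsageData) -> None:
        # call_duration is now always a float (seconds spent in the
        # successful call, excluding retry backoff time), never None
        self.tokens[data.model_name] += data.usage.total_tokens  # assumed field
        self.seconds[data.model_name] += data.call_duration
```

Note that, per the docstring added earlier in this series, `on_model_usage` is only invoked when a generate call completes successfully and is not invoked on local cache hits.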
--- src/inspect_ai/hooks/_hooks.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/inspect_ai/hooks/_hooks.py b/src/inspect_ai/hooks/_hooks.py index 6e1e25aa0f..4c9dd436f6 100644 --- a/src/inspect_ai/hooks/_hooks.py +++ b/src/inspect_ai/hooks/_hooks.py @@ -24,7 +24,9 @@ class RunStart: """Run start hook event data.""" run_id: str + """The unique identifier for the run.""" task_names: list[str] + """The names of the tasks which will be used in the run.""" @dataclass(frozen=True) @@ -32,7 +34,9 @@ class RunEnd: """Run end hook event data.""" run_id: str + """The unique identifier for the run.""" logs: EvalLogs + """All eval logs generated during the run.""" @dataclass(frozen=True) @@ -40,8 +44,11 @@ class TaskStart: """Task start hook event data.""" run_id: str + """The unique identifier for the run.""" eval_id: str + """The unique identifier for this task within the run.""" spec: EvalSpec + """Specification of the task.""" @dataclass(frozen=True) @@ -49,8 +56,11 @@ class TaskEnd: """Task end hook event data.""" run_id: str + """The unique identifier for the run.""" eval_id: str + """The unique identifier for this task within the run.""" log: EvalLog + """The log generated for this task.""" @dataclass(frozen=True) @@ -58,9 +68,12 @@ class SampleStart: """Sample start hook event data.""" run_id: str + """The unique identifier for the run.""" eval_id: str + """The unique identifier for the sample's task within the run.""" sample_id: int | str summary: EvalSampleSummary + """Summary of the sample to be run.""" @dataclass(frozen=True) @@ -68,9 +81,13 @@ class SampleEnd: """Sample end hook event data.""" run_id: str + """The unique identifier for the run.""" eval_id: str + """The unique identifier for the sample's task within the run.""" sample_id: int | str + # TODO: Are these different to the user-supplied Sample IDs? summary: EvalSampleSummary + """Summary of the sample that has run.""" @dataclass(frozen=True) @@ -78,9 +95,14 @@ class SampleAbort: """Sample abort hook event data.""" run_id: str + """The unique identifier for the run.""" eval_id: str + """The unique identifier for the sample's task within the run.""" sample_id: int | str + # TODO: Document sample id. error: EvalError + """The error that caused the sample to be aborted. If the sample has been retried, + this is the last error.""" @dataclass(frozen=True) @@ -88,8 +110,13 @@ class ModelUsageData: """Model usage hook event data.""" model_name: str + """The name of the model that was used.""" usage: ModelUsage - call_duration: float | None = None + """The model usage metrics.""" + call_duration: float + """The duration of the model call in seconds. If HTTP retries were made, this is the + time taken for the successful call. This excludes retry waiting (e.g. exponential + backoff) time.""" @dataclass(frozen=True) @@ -306,7 +333,7 @@ async def emit_sample_abort( async def emit_model_usage( - model_name: str, usage: ModelUsage, call_duration: float | None + model_name: str, usage: ModelUsage, call_duration: float ) -> None: data = ModelUsageData( model_name=model_name, usage=usage, call_duration=call_duration From 31adb4f7884e865653b407a70f8d5b2b3c694d09 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Mon, 30 Jun 2025 14:33:55 +0000 Subject: [PATCH 17/24] Code review feedback: remove on_sample_abort, use sample UUID, add todo. 
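To illustrate what these changes mean for subscribers (a sketch only, not part of this patch, with a hypothetical hook name): `sample_id` is now a globally unique sample execution id, so it can be used directly as a dictionary key across tasks and epochs, and `on_sample_end` fires whether the sample completed successfully or errored with no retries remaining.

``` python
import time

from inspect_ai.hooks import Hooks, SampleEnd, SampleStart, hooks

@hooks(name="sample_timer", description="Times each sample execution")
class SampleTimer(Hooks):
    def __init__(self) -> None:
        self._started: dict[str, float] = {}

    async def on_sample_start(self, data: SampleStart) -> None:
        # sample_id is the globally unique sample execution id
        self._started[data.sample_id] = time.monotonic()

    async def on_sample_end(self, data: SampleEnd) -> None:
        # fires on success and when a sample has exhausted its retries
        start = self._started.pop(data.sample_id, None)
        if start is not None:
            print(f"sample {data.sample_id}: {time.monotonic() - start:.1f}s")
```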
--- src/inspect_ai/_eval/task/run.py | 14 ++---- src/inspect_ai/hooks/__init__.py | 2 - src/inspect_ai/hooks/_hooks.py | 83 +++++++++++--------------------- tests/hooks/test_hooks.py | 25 +++------- 4 files changed, 38 insertions(+), 86 deletions(-) diff --git a/src/inspect_ai/_eval/task/run.py b/src/inspect_ai/_eval/task/run.py index a82a369363..b224b928a6 100644 --- a/src/inspect_ai/_eval/task/run.py +++ b/src/inspect_ai/_eval/task/run.py @@ -548,11 +548,7 @@ async def task_run_sample( run_id: str, task_id: str, ) -> dict[str, SampleScore] | None: - from inspect_ai.hooks._hooks import ( - emit_sample_abort, - emit_sample_end, - emit_sample_start, - ) + from inspect_ai.hooks._hooks import emit_sample_end, emit_sample_start # if there is an existing sample then tick off its progress, log it, and return it if sample_source and sample.id is not None: @@ -709,7 +705,7 @@ def log_sample_error() -> None: # only emit the sample start once: not on retries if not error_retries: await emit_sample_start( - run_id, task_id, sample_id, sample_summary + run_id, task_id, state.uuid, sample_summary ) # set progress for plan then run it @@ -873,9 +869,7 @@ def log_sample_error() -> None: await log_sample( eval_sample=eval_sample, logger=logger, log_images=log_images ) - # TODO: Do we only want to emit sample end if there was no error? - if not error: - await emit_sample_end(run_id, task_id, sample_id, eval_sample.summary()) + await emit_sample_end(run_id, task_id, state.uuid, eval_sample.summary()) # error that should be retried (we do this outside of the above scope so that we can # retry outside of the original semaphore -- our retry will therefore go to the back @@ -926,11 +920,9 @@ def log_sample_error() -> None: # we have an error and should raise it elif raise_error is not None: - await emit_sample_abort(run_id, task_id, sample_id, error) raise raise_error # we have an error and should not raise it - # TODO: Do I need to emit a sample abort for this? else: return None diff --git a/src/inspect_ai/hooks/__init__.py b/src/inspect_ai/hooks/__init__.py index d8d74f768e..70c16c9a22 100644 --- a/src/inspect_ai/hooks/__init__.py +++ b/src/inspect_ai/hooks/__init__.py @@ -4,7 +4,6 @@ ModelUsageData, RunEnd, RunStart, - SampleAbort, SampleEnd, SampleStart, TaskEnd, @@ -18,7 +17,6 @@ "ModelUsageData", "RunEnd", "RunStart", - "SampleAbort", "SampleEnd", "SampleStart", "TaskEnd", diff --git a/src/inspect_ai/hooks/_hooks.py b/src/inspect_ai/hooks/_hooks.py index 4c9dd436f6..ba17da635f 100644 --- a/src/inspect_ai/hooks/_hooks.py +++ b/src/inspect_ai/hooks/_hooks.py @@ -5,7 +5,6 @@ from inspect_ai._eval.eval import EvalLogs from inspect_ai._eval.task.log import TaskLogger from inspect_ai._eval.task.resolved import ResolvedTask -from inspect_ai._util.error import EvalError from inspect_ai._util.registry import ( RegistryInfo, registry_add, @@ -24,7 +23,7 @@ class RunStart: """Run start hook event data.""" run_id: str - """The unique identifier for the run.""" + """The globally unique identifier for the run.""" task_names: list[str] """The names of the tasks which will be used in the run.""" @@ -34,9 +33,10 @@ class RunEnd: """Run end hook event data.""" run_id: str - """The unique identifier for the run.""" + """The globally unique identifier for the run.""" logs: EvalLogs - """All eval logs generated during the run.""" + """All eval logs generated during the run. 
Can be headers only if the run was an + `eval_set()`.""" @dataclass(frozen=True) @@ -44,9 +44,9 @@ class TaskStart: """Task start hook event data.""" run_id: str - """The unique identifier for the run.""" + """The globally unique identifier for the run.""" eval_id: str - """The unique identifier for this task within the run.""" + """The globally unique identifier for this task execution.""" spec: EvalSpec """Specification of the task.""" @@ -56,11 +56,12 @@ class TaskEnd: """Task end hook event data.""" run_id: str - """The unique identifier for the run.""" + """The globally unique identifier for the run.""" eval_id: str - """The unique identifier for this task within the run.""" + """The globally unique identifier for the task execution.""" log: EvalLog - """The log generated for this task.""" + """The log generated for the task. Can be header only if the run was an + `eval_set()`""" @dataclass(frozen=True) @@ -68,10 +69,11 @@ class SampleStart: """Sample start hook event data.""" run_id: str - """The unique identifier for the run.""" + """The globally unique identifier for the run.""" eval_id: str - """The unique identifier for the sample's task within the run.""" - sample_id: int | str + """The globally unique identifier for the task execution.""" + sample_id: str + """The globally unique identifier for the sample execution.""" summary: EvalSampleSummary """Summary of the sample to be run.""" @@ -81,30 +83,15 @@ class SampleEnd: """Sample end hook event data.""" run_id: str - """The unique identifier for the run.""" + """The globally unique identifier for the run.""" eval_id: str - """The unique identifier for the sample's task within the run.""" - sample_id: int | str - # TODO: Are these different to the user-supplied Sample IDs? + """The globally unique identifier for the task execution.""" + sample_id: str + """The globally unique identifier for the sample execution.""" summary: EvalSampleSummary """Summary of the sample that has run.""" -@dataclass(frozen=True) -class SampleAbort: - """Sample abort hook event data.""" - - run_id: str - """The unique identifier for the run.""" - eval_id: str - """The unique identifier for the sample's task within the run.""" - sample_id: int | str - # TODO: Document sample id. - error: EvalError - """The error that caused the sample to be aborted. If the sample has been retried, - this is the last error.""" - - @dataclass(frozen=True) class ModelUsageData: """Model usage hook event data.""" @@ -137,6 +124,8 @@ class Hooks: affect the overall execution of the eval. If a hook fails, a warning will be logged. """ + # TODO: Add name and description properties. + def enabled(self) -> bool: """Check if the hook should be enabled. @@ -189,9 +178,10 @@ async def on_task_end(self, data: TaskEnd) -> None: async def on_sample_start(self, data: SampleStart) -> None: """On sample start. - If a sample is run for multiple epochs, this will be called once per epoch. + Called when a sample is about to be start. If the sample errors and retries, + this will not be called again. - This is not called again on sample retries. + If a sample is run for multiple epochs, this will be called once per epoch. Args: data: Sample start data. @@ -201,20 +191,10 @@ async def on_sample_start(self, data: SampleStart) -> None: async def on_sample_end(self, data: SampleEnd) -> None: """On sample end. - This will be called when a sample has completed without error. If there are - multiple epochs for a sample, this will be called once per successfully - completed epoch. 
+ Called when a sample has either completed successfully, or when a sample has + errored and has no retries remaining. - Args: - data: Sample end data. - """ - pass - - async def on_sample_abort(self, data: SampleAbort) -> None: - """A sample has been aborted due to an error, and will not be retried. - - If there are multiple epochs for a sample, this will be called once per - aborted epoch of the sample. + If a sample is run for multiple epochs, this will be called once per epoch. Args: data: Sample end data. @@ -308,7 +288,7 @@ async def emit_task_end(logger: TaskLogger, log: EvalLog) -> None: async def emit_sample_start( - run_id: str, eval_id: str, sample_id: int | str, summary: EvalSampleSummary + run_id: str, eval_id: str, sample_id: str, summary: EvalSampleSummary ) -> None: data = SampleStart( run_id=run_id, eval_id=eval_id, sample_id=sample_id, summary=summary @@ -317,7 +297,7 @@ async def emit_sample_start( async def emit_sample_end( - run_id: str, eval_id: str, sample_id: int | str, summary: EvalSampleSummary + run_id: str, eval_id: str, sample_id: str, summary: EvalSampleSummary ) -> None: data = SampleEnd( run_id=run_id, eval_id=eval_id, sample_id=sample_id, summary=summary @@ -325,13 +305,6 @@ async def emit_sample_end( await _emit_to_all(lambda hook: hook.on_sample_end(data)) -async def emit_sample_abort( - run_id: str, eval_id: str, sample_id: int | str, error: EvalError -) -> None: - data = SampleAbort(run_id=run_id, eval_id=eval_id, sample_id=sample_id, error=error) - await _emit_to_all(lambda hook: hook.on_sample_abort(data)) - - async def emit_model_usage( model_name: str, usage: ModelUsage, call_duration: float ) -> None: diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py index d81e10d1dc..ad8115cdc2 100644 --- a/tests/hooks/test_hooks.py +++ b/tests/hooks/test_hooks.py @@ -14,7 +14,6 @@ ModelUsageData, RunEnd, RunStart, - SampleAbort, SampleEnd, SampleStart, TaskEnd, @@ -35,7 +34,6 @@ def __init__(self) -> None: self.task_end_events: list[TaskEnd] = [] self.sample_start_events: list[SampleStart] = [] self.sample_end_events: list[SampleEnd] = [] - self.sample_abort_events: list[SampleAbort] = [] self.model_usage_events: list[ModelUsageData] = [] def assert_no_events(self) -> None: @@ -45,7 +43,6 @@ def assert_no_events(self) -> None: assert not self.task_end_events assert not self.sample_start_events assert not self.sample_end_events - assert not self.sample_abort_events assert not self.model_usage_events def enabled(self) -> bool: @@ -69,9 +66,6 @@ async def on_sample_start(self, data: SampleStart) -> None: async def on_sample_end(self, data: SampleEnd) -> None: self.sample_end_events.append(data) - async def on_sample_abort(self, data: SampleAbort) -> None: - self.sample_abort_events.append(data) - async def on_model_usage(self, data: ModelUsageData) -> None: self.model_usage_events.append(data) @@ -132,7 +126,6 @@ def test_can_subscribe_to_events(mock_hook: MockHook) -> None: assert len(mock_hook.task_end_events) == 1 assert len(mock_hook.sample_start_events) == 1 assert len(mock_hook.sample_end_events) == 1 - assert len(mock_hook.sample_abort_events) == 0 assert len(mock_hook.model_usage_events) == 0 @@ -152,7 +145,6 @@ def test_can_subscribe_to_events_with_multiple_hooks( assert len(h.task_end_events) == 1 assert len(h.sample_start_events) == 1 assert len(h.sample_end_events) == 1 - assert len(h.sample_abort_events) == 0 assert len(h.model_usage_events) == 0 @@ -171,10 +163,9 @@ def test_hooks_on_multiple_tasks(mock_hook: MockHook) -> None: assert 
len(mock_hook.task_end_events) == 2 assert len(mock_hook.sample_start_events) == 2 assert len(mock_hook.sample_end_events) == 2 - assert len(mock_hook.sample_abort_events) == 0 -def test_hooks_on_multiple_samples(mock_hook: MockHook) -> None: +def test_hooks_with_multiple_samples(mock_hook: MockHook) -> None: eval( [ Task(dataset=[Sample("sample_1"), Sample("sample_2")]), @@ -188,10 +179,9 @@ def test_hooks_on_multiple_samples(mock_hook: MockHook) -> None: assert len(mock_hook.task_end_events) == 1 assert len(mock_hook.sample_start_events) == 2 assert len(mock_hook.sample_end_events) == 2 - assert len(mock_hook.sample_abort_events) == 0 -def test_hooks_on_multiple_epochs(mock_hook: MockHook) -> None: +def test_hooks_with_multiple_epochs(mock_hook: MockHook) -> None: eval( Task(dataset=[Sample("sample_1")]), model="mockllm/model", @@ -200,31 +190,30 @@ def test_hooks_on_multiple_epochs(mock_hook: MockHook) -> None: assert len(mock_hook.sample_start_events) == 3 assert len(mock_hook.sample_end_events) == 3 - assert len(mock_hook.sample_abort_events) == 0 -def test_hooks_on_sample_retries(mock_hook: MockHook) -> None: +def test_hooks_with_sample_retries(mock_hook: MockHook) -> None: eval( Task(dataset=[Sample("sample_1")], solver=_fail_n_times_solver(2)), model="mockllm/model", retry_on_error=10, ) + # Will succeed on 3rd attempt, but just 1 sample start and end event. assert len(mock_hook.sample_start_events) == 1 assert len(mock_hook.sample_end_events) == 1 - assert len(mock_hook.sample_abort_events) == 0 -def test_hooks_on_sample_abort(mock_hook: MockHook) -> None: +def test_hooks_with_error_and_no_retries(mock_hook: MockHook) -> None: eval( Task(dataset=[Sample("sample_1")], solver=_fail_n_times_solver(10)), model="mockllm/model", retry_on_error=0, ) + # Will fail on first attempt without any retries. assert len(mock_hook.sample_start_events) == 1 - assert len(mock_hook.sample_end_events) == 0 - assert len(mock_hook.sample_abort_events) == 1 + assert len(mock_hook.sample_end_events) == 1 def test_hook_does_not_need_to_subscribe_to_all_events( From 835d4a6a620f8810042e5eadee71e01b7022b863 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Mon, 30 Jun 2025 15:21:03 +0000 Subject: [PATCH 18/24] Replace name & description parameters to @hooks() with name & description properties on class. 
--- docs/extensions.qmd | 19 +++++++++-- src/inspect_ai/hooks/_hooks.py | 33 ++++++++++++------- src/inspect_ai/hooks/_startup.py | 10 ++---- tests/hooks/test_hooks.py | 33 +++++++++++++++++-- .../test_package/inspect_package/_registry.py | 2 +- .../inspect_package/hooks/custom.py | 8 +++++ 6 files changed, 81 insertions(+), 24 deletions(-) diff --git a/docs/extensions.qmd b/docs/extensions.qmd index 1d807f9e0e..4dbafe87bd 100644 --- a/docs/extensions.qmd +++ b/docs/extensions.qmd @@ -439,8 +439,16 @@ import wandb from inspect_ai.hooks import Hooks, RunEnd, RunStart, SampleEnd, hooks -@hooks(name="w&b_hook", description="Weights & Biases integration") +@hooks() class WBHook(Hooks): + @property + def name(self) -> str: + return "w&b_hook" + + @property + def description(self) -> str: + return "Uploads run info to Weights & Biases" + async def on_run_start(self, data: RunStart) -> None: wandb.init(name=data.run_id) @@ -467,8 +475,15 @@ There is a hook to optionally override the value of model API key environment va ``` python from inspect_ai.hooks import hooks, Hooks, ApiKeyOverride -@hooks(name="api_key_fetcher", description="Fetches API key from secrets manager") class ApiKeyFetcher(Hooks): + @property + def name(self) -> str: + return "api_key_fetcher" + + @property + def description(self) -> str: + return "Fetches API key from secrets manager" + def override_api_key(self, data: ApiKeyOverride) -> str | None: original_env_var_value = data.value if original_env_var_value.startswith("arn:aws:secretsmanager:"): diff --git a/src/inspect_ai/hooks/_hooks.py b/src/inspect_ai/hooks/_hooks.py index ba17da635f..19e87d4cde 100644 --- a/src/inspect_ai/hooks/_hooks.py +++ b/src/inspect_ai/hooks/_hooks.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from dataclasses import dataclass from logging import getLogger from typing import Awaitable, Callable, Type, TypeVar, cast @@ -124,7 +125,24 @@ class Hooks: affect the overall execution of the eval. If a hook fails, a warning will be logged. """ - # TODO: Add name and description properties. + @property + @abstractmethod + def name(self) -> str: + """Get the name of the hook. + + This is used for logging and display purposes. + """ + raise NotImplementedError + + @property + @abstractmethod + def description(self) -> str: + """Get the description of the hook. + + This is used for logging and debugging purposes. It should provide a brief + overview of what the hook does. + """ + raise NotImplementedError def enabled(self) -> bool: """Check if the hook should be enabled. @@ -231,17 +249,12 @@ def override_api_key(self, data: ApiKeyOverride) -> str | None: T = TypeVar("T", bound=Hooks) -def hooks(name: str, description: str) -> Callable[..., Type[T]]: +def hooks() -> Callable[..., Type[T]]: """Decorator for registering a hook subscriber. Either decorate a subclass of `Hooks`, or a function which returns the type of a subclass of `Hooks`. This decorator will instantiate the hook class and store it in the registry. - - Args: - name (str): Name of the subscriber (e.g. "audit logging"). - description (str): Short description of the hook (e.g. "Copies eval files to - S3 bucket for auditing."). """ def wrapper(hook_type: Type[T] | Callable[..., Type[T]]) -> Type[T]: @@ -253,12 +266,10 @@ def wrapper(hook_type: Type[T] | Callable[..., Type[T]]) -> Type[T]: # Instantiate an instance of the Hooks class. 
hook_instance = hook_type() - hook_name = registry_name(hook_instance, name) + namespaced_hook_name = registry_name(hook_instance, hook_instance.name) registry_add( hook_instance, - RegistryInfo( - type="hooks", name=hook_name, metadata={"description": description} - ), + RegistryInfo(type="hooks", name=namespaced_hook_name), ) return cast(Type[T], hook_instance) diff --git a/src/inspect_ai/hooks/_startup.py b/src/inspect_ai/hooks/_startup.py index 0bb9e4a4a8..5241f03636 100644 --- a/src/inspect_ai/hooks/_startup.py +++ b/src/inspect_ai/hooks/_startup.py @@ -25,7 +25,9 @@ def init_hooks() -> None: hooks = get_all_hooks() _registry_hooks_loaded = True if hooks: - hook_names = [f" {_format_hook_for_printing(hook)}" for hook in hooks] + hook_names = [ + f" {f'[bold]{hook.name}[/bold]: {hook.description}'}" for hook in hooks + ] hook_names_joined = "\n".join(hook_names) messages.append( f"[bold]hooks enabled: {len(hooks)}[/bold]\n{hook_names_joined}" @@ -39,9 +41,3 @@ def init_hooks() -> None: f"[blue][bold]inspect_ai v{version}[/bold][/blue]\n" f"[bright_black]{all_messages}[/bright_black]\n" ) - - -def _format_hook_for_printing(hook: Hooks) -> str: - info = registry_info(hook) - description = info.metadata["description"] - return f"[bold]{info.name}[/bold]: {description}" diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py index ad8115cdc2..65b18c20dc 100644 --- a/tests/hooks/test_hooks.py +++ b/tests/hooks/test_hooks.py @@ -45,6 +45,14 @@ def assert_no_events(self) -> None: assert not self.sample_end_events assert not self.model_usage_events + @property + def name(self) -> str: + return "test_hook" + + @property + def description(self) -> str: + return "test_hook_description" + def enabled(self) -> bool: return self.should_enable @@ -73,6 +81,16 @@ def override_api_key(self, data: ApiKeyOverride) -> str | None: return f"mocked-{data.env_var_name}-{data.value}" +class MockHook2(MockHook): + @property + def name(self) -> str: + return "test_hook_2" + + @property + def description(self) -> str: + return "test_hook_2_description" + + class MockMinimalHook(Hooks): def __init__(self) -> None: self.run_start_events: list[RunStart] = [] @@ -80,6 +98,14 @@ def __init__(self) -> None: async def on_run_start(self, data: RunStart) -> None: self.run_start_events.append(data) + @property + def name(self) -> str: + return "test_hook_minimal" + + @property + def description(self) -> str: + return "test_hook_minimal_description" + @pytest.fixture def mock_hook() -> Generator[MockHook, None, None]: @@ -88,7 +114,7 @@ def mock_hook() -> Generator[MockHook, None, None]: @pytest.fixture def hook_2() -> Generator[MockHook, None, None]: - yield from _create_mock_hook("test_hook_2", MockHook) + yield from _create_mock_hook("test_hook_2", MockHook2) @pytest.fixture @@ -261,14 +287,15 @@ def test_hook_name_and_description(mock_hook: MockHook) -> None: info = registry_info(mock_hook) assert info.name == "test_hook" - assert info.metadata["description"] == "test_hook-description" + assert mock_hook.name == "test_hook" + assert mock_hook.description == "test_hook_description" T = TypeVar("T", bound=Hooks) def _create_mock_hook(name: str, hook_class: Type[T]) -> Generator[T, None, None]: - @hooks(name, description=f"{name}-description") + @hooks() def get_hook_class() -> type[T]: return hook_class diff --git a/tests/test_package/inspect_package/_registry.py b/tests/test_package/inspect_package/_registry.py index 2102e0db59..e311349d80 100644 --- a/tests/test_package/inspect_package/_registry.py +++ 
b/tests/test_package/inspect_package/_registry.py @@ -28,7 +28,7 @@ def podman(): return PodmanSandboxEnvironment -@hooks(name="custom_hook", description="Custom hooks") +@hooks() def custom_hook(): from .hooks.custom import CustomHooks diff --git a/tests/test_package/inspect_package/hooks/custom.py b/tests/test_package/inspect_package/hooks/custom.py index f20642cc43..639d2035a1 100644 --- a/tests/test_package/inspect_package/hooks/custom.py +++ b/tests/test_package/inspect_package/hooks/custom.py @@ -2,6 +2,14 @@ class CustomHooks(Hooks): + @property + def name(self) -> str: + return "custom_hook" + + @property + def description(self) -> str: + return "A custom hook for testing purposes" + async def on_run_start(self, event: RunStart) -> None: global run_ids run_ids.append(event.run_id) From 17045abe6e777bf44a7d5d743fc23888c19e1658 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Mon, 30 Jun 2025 15:21:11 +0000 Subject: [PATCH 19/24] Revert "Replace name & description parameters to @hooks() with name & description properties on class." Did not add much value, made examples less appealing, felt at odds with how we define other custom extensions like sandbox environments. This reverts commit 835d4a6a620f8810042e5eadee71e01b7022b863. --- docs/extensions.qmd | 19 ++--------- src/inspect_ai/hooks/_hooks.py | 33 +++++++------------ src/inspect_ai/hooks/_startup.py | 10 ++++-- tests/hooks/test_hooks.py | 33 ++----------------- .../test_package/inspect_package/_registry.py | 2 +- .../inspect_package/hooks/custom.py | 8 ----- 6 files changed, 24 insertions(+), 81 deletions(-) diff --git a/docs/extensions.qmd b/docs/extensions.qmd index 4dbafe87bd..1d807f9e0e 100644 --- a/docs/extensions.qmd +++ b/docs/extensions.qmd @@ -439,16 +439,8 @@ import wandb from inspect_ai.hooks import Hooks, RunEnd, RunStart, SampleEnd, hooks -@hooks() +@hooks(name="w&b_hook", description="Weights & Biases integration") class WBHook(Hooks): - @property - def name(self) -> str: - return "w&b_hook" - - @property - def description(self) -> str: - return "Uploads run info to Weights & Biases" - async def on_run_start(self, data: RunStart) -> None: wandb.init(name=data.run_id) @@ -475,15 +467,8 @@ There is a hook to optionally override the value of model API key environment va ``` python from inspect_ai.hooks import hooks, Hooks, ApiKeyOverride +@hooks(name="api_key_fetcher", description="Fetches API key from secrets manager") class ApiKeyFetcher(Hooks): - @property - def name(self) -> str: - return "api_key_fetcher" - - @property - def description(self) -> str: - return "Fetches API key from secrets manager" - def override_api_key(self, data: ApiKeyOverride) -> str | None: original_env_var_value = data.value if original_env_var_value.startswith("arn:aws:secretsmanager:"): diff --git a/src/inspect_ai/hooks/_hooks.py b/src/inspect_ai/hooks/_hooks.py index 19e87d4cde..ba17da635f 100644 --- a/src/inspect_ai/hooks/_hooks.py +++ b/src/inspect_ai/hooks/_hooks.py @@ -1,4 +1,3 @@ -from abc import abstractmethod from dataclasses import dataclass from logging import getLogger from typing import Awaitable, Callable, Type, TypeVar, cast @@ -125,24 +124,7 @@ class Hooks: affect the overall execution of the eval. If a hook fails, a warning will be logged. """ - @property - @abstractmethod - def name(self) -> str: - """Get the name of the hook. - - This is used for logging and display purposes. - """ - raise NotImplementedError - - @property - @abstractmethod - def description(self) -> str: - """Get the description of the hook. 
- - This is used for logging and debugging purposes. It should provide a brief - overview of what the hook does. - """ - raise NotImplementedError + # TODO: Add name and description properties. def enabled(self) -> bool: """Check if the hook should be enabled. @@ -249,12 +231,17 @@ def override_api_key(self, data: ApiKeyOverride) -> str | None: T = TypeVar("T", bound=Hooks) -def hooks() -> Callable[..., Type[T]]: +def hooks(name: str, description: str) -> Callable[..., Type[T]]: """Decorator for registering a hook subscriber. Either decorate a subclass of `Hooks`, or a function which returns the type of a subclass of `Hooks`. This decorator will instantiate the hook class and store it in the registry. + + Args: + name (str): Name of the subscriber (e.g. "audit logging"). + description (str): Short description of the hook (e.g. "Copies eval files to + S3 bucket for auditing."). """ def wrapper(hook_type: Type[T] | Callable[..., Type[T]]) -> Type[T]: @@ -266,10 +253,12 @@ def wrapper(hook_type: Type[T] | Callable[..., Type[T]]) -> Type[T]: # Instantiate an instance of the Hooks class. hook_instance = hook_type() - namespaced_hook_name = registry_name(hook_instance, hook_instance.name) + hook_name = registry_name(hook_instance, name) registry_add( hook_instance, - RegistryInfo(type="hooks", name=namespaced_hook_name), + RegistryInfo( + type="hooks", name=hook_name, metadata={"description": description} + ), ) return cast(Type[T], hook_instance) diff --git a/src/inspect_ai/hooks/_startup.py b/src/inspect_ai/hooks/_startup.py index 5241f03636..0bb9e4a4a8 100644 --- a/src/inspect_ai/hooks/_startup.py +++ b/src/inspect_ai/hooks/_startup.py @@ -25,9 +25,7 @@ def init_hooks() -> None: hooks = get_all_hooks() _registry_hooks_loaded = True if hooks: - hook_names = [ - f" {f'[bold]{hook.name}[/bold]: {hook.description}'}" for hook in hooks - ] + hook_names = [f" {_format_hook_for_printing(hook)}" for hook in hooks] hook_names_joined = "\n".join(hook_names) messages.append( f"[bold]hooks enabled: {len(hooks)}[/bold]\n{hook_names_joined}" @@ -41,3 +39,9 @@ def init_hooks() -> None: f"[blue][bold]inspect_ai v{version}[/bold][/blue]\n" f"[bright_black]{all_messages}[/bright_black]\n" ) + + +def _format_hook_for_printing(hook: Hooks) -> str: + info = registry_info(hook) + description = info.metadata["description"] + return f"[bold]{info.name}[/bold]: {description}" diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py index 65b18c20dc..ad8115cdc2 100644 --- a/tests/hooks/test_hooks.py +++ b/tests/hooks/test_hooks.py @@ -45,14 +45,6 @@ def assert_no_events(self) -> None: assert not self.sample_end_events assert not self.model_usage_events - @property - def name(self) -> str: - return "test_hook" - - @property - def description(self) -> str: - return "test_hook_description" - def enabled(self) -> bool: return self.should_enable @@ -81,16 +73,6 @@ def override_api_key(self, data: ApiKeyOverride) -> str | None: return f"mocked-{data.env_var_name}-{data.value}" -class MockHook2(MockHook): - @property - def name(self) -> str: - return "test_hook_2" - - @property - def description(self) -> str: - return "test_hook_2_description" - - class MockMinimalHook(Hooks): def __init__(self) -> None: self.run_start_events: list[RunStart] = [] @@ -98,14 +80,6 @@ def __init__(self) -> None: async def on_run_start(self, data: RunStart) -> None: self.run_start_events.append(data) - @property - def name(self) -> str: - return "test_hook_minimal" - - @property - def description(self) -> str: - return 
"test_hook_minimal_description" - @pytest.fixture def mock_hook() -> Generator[MockHook, None, None]: @@ -114,7 +88,7 @@ def mock_hook() -> Generator[MockHook, None, None]: @pytest.fixture def hook_2() -> Generator[MockHook, None, None]: - yield from _create_mock_hook("test_hook_2", MockHook2) + yield from _create_mock_hook("test_hook_2", MockHook) @pytest.fixture @@ -287,15 +261,14 @@ def test_hook_name_and_description(mock_hook: MockHook) -> None: info = registry_info(mock_hook) assert info.name == "test_hook" - assert mock_hook.name == "test_hook" - assert mock_hook.description == "test_hook_description" + assert info.metadata["description"] == "test_hook-description" T = TypeVar("T", bound=Hooks) def _create_mock_hook(name: str, hook_class: Type[T]) -> Generator[T, None, None]: - @hooks() + @hooks(name, description=f"{name}-description") def get_hook_class() -> type[T]: return hook_class diff --git a/tests/test_package/inspect_package/_registry.py b/tests/test_package/inspect_package/_registry.py index e311349d80..2102e0db59 100644 --- a/tests/test_package/inspect_package/_registry.py +++ b/tests/test_package/inspect_package/_registry.py @@ -28,7 +28,7 @@ def podman(): return PodmanSandboxEnvironment -@hooks() +@hooks(name="custom_hook", description="Custom hooks") def custom_hook(): from .hooks.custom import CustomHooks diff --git a/tests/test_package/inspect_package/hooks/custom.py b/tests/test_package/inspect_package/hooks/custom.py index 639d2035a1..f20642cc43 100644 --- a/tests/test_package/inspect_package/hooks/custom.py +++ b/tests/test_package/inspect_package/hooks/custom.py @@ -2,14 +2,6 @@ class CustomHooks(Hooks): - @property - def name(self) -> str: - return "custom_hook" - - @property - def description(self) -> str: - return "A custom hook for testing purposes" - async def on_run_start(self, event: RunStart) -> None: global run_ids run_ids.append(event.run_id) From 2abd49a774a0c24c09531abaf19dea3af75779e5 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Mon, 30 Jun 2025 15:22:20 +0000 Subject: [PATCH 20/24] Remove todo (see reverted commit). --- src/inspect_ai/hooks/_hooks.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/inspect_ai/hooks/_hooks.py b/src/inspect_ai/hooks/_hooks.py index ba17da635f..8f51e1aec8 100644 --- a/src/inspect_ai/hooks/_hooks.py +++ b/src/inspect_ai/hooks/_hooks.py @@ -124,8 +124,6 @@ class Hooks: affect the overall execution of the eval. If a hook fails, a warning will be logged. """ - # TODO: Add name and description properties. - def enabled(self) -> bool: """Check if the hook should be enabled. From a15111b10d9583567b1789774a22844163dca006 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Mon, 30 Jun 2025 15:54:25 +0000 Subject: [PATCH 21/24] Remove SampleAbort from docs. 
--- docs/reference/_sidebar.yml | 2 -- docs/reference/inspect_ai.hooks.qmd | 1 - 2 files changed, 3 deletions(-) diff --git a/docs/reference/_sidebar.yml b/docs/reference/_sidebar.yml index b2919fa409..057c754e71 100644 --- a/docs/reference/_sidebar.yml +++ b/docs/reference/_sidebar.yml @@ -632,8 +632,6 @@ website: href: reference/inspect_ai.hooks.qmd#runend - text: RunStart href: reference/inspect_ai.hooks.qmd#runstart - - text: SampleAbort - href: reference/inspect_ai.hooks.qmd#sampleabort - text: SampleEnd href: reference/inspect_ai.hooks.qmd#sampleend - text: SampleStart diff --git a/docs/reference/inspect_ai.hooks.qmd b/docs/reference/inspect_ai.hooks.qmd index ce94b93c7c..e57a648161 100644 --- a/docs/reference/inspect_ai.hooks.qmd +++ b/docs/reference/inspect_ai.hooks.qmd @@ -13,7 +13,6 @@ title: "inspect_ai.hooks" ### ModelUsageData ### RunEnd ### RunStart -### SampleAbort ### SampleEnd ### SampleStart ### TaskEnd From 142ac65158efb09c0f838b88ae27b85159080ef6 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Mon, 30 Jun 2025 15:54:40 +0000 Subject: [PATCH 22/24] Improve docs for hooks. --- docs/extensions.qmd | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/docs/extensions.qmd b/docs/extensions.qmd index 1d807f9e0e..db70765a79 100644 --- a/docs/extensions.qmd +++ b/docs/extensions.qmd @@ -430,17 +430,17 @@ Once this package is installed, you'll be able to use `myfs://` with Inspect wit ## Hooks -Hooks allow you to run arbitrary code during Inspect's lifecycle, for example when runs, tasks or samples start and end. +Hooks allow you to run arbitrary code during certain events of Inspect's lifecycle, for example when runs, tasks or samples start and end. -Here is a hypothetical hook for integration with Weights & Biases. +Here is a hypothetical integration with Weights & Biases. ``` python import wandb from inspect_ai.hooks import Hooks, RunEnd, RunStart, SampleEnd, hooks -@hooks(name="w&b_hook", description="Weights & Biases integration") -class WBHook(Hooks): +@hooks(name="w&b_hooks", description="Weights & Biases integration") +class WBHooks(Hooks): async def on_run_start(self, data: RunStart) -> None: wandb.init(name=data.run_id) @@ -455,11 +455,23 @@ class WBHook(Hooks): }) ``` -See the `Hooks` class for more documentation and the full list of available hooks. +See the `Hooks` class for more documentation and the full list of available hook events. + +Each set of hooks (i.e. each `@hooks`-decorated class) can register for any events (even if they're overlapping). + +Alternatively, you may decorate a function which returns the type of a `Hooks` subclass to create a layer of indirection so that you can separate the registration of hooks from the importing of libraries they require (important for limiting dependencies). + +``` {.python filename="providers.py"} +@hooks(name="w&b_hooks", description="Weights & Biases integration") +def wandb_hooks(): + from .wb_hooks import WBHooks + + return WBHooks +``` ### API Key Override -There is a hook to optionally override the value of model API key environment variables. This could be used to: +There is a hook event to optionally override the value of model API key environment variables. This could be used to: * inject API keys at runtime (e.g. 
fetched from a secrets manager), to avoid having to store these in your environment or .env file * use some custom model API authentication mechanism in conjunction with a custom reverse proxy for the model API to avoid Inspect ever having access to real API keys From b30d99b340c0794ae795a11a5741df6fe990fa55 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Mon, 30 Jun 2025 15:55:12 +0000 Subject: [PATCH 23/24] Make terminology consistent: a class deriving from Hooks should be named as plural too. --- tests/hooks/test_hooks.py | 134 +++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py index ad8115cdc2..b7dddc6cd5 100644 --- a/tests/hooks/test_hooks.py +++ b/tests/hooks/test_hooks.py @@ -25,7 +25,7 @@ from inspect_ai.solver._task_state import TaskState -class MockHook(Hooks): +class MockHooks(Hooks): def __init__(self) -> None: self.should_enable = True self.run_start_events: list[RunStart] = [] @@ -73,7 +73,7 @@ def override_api_key(self, data: ApiKeyOverride) -> str | None: return f"mocked-{data.env_var_name}-{data.value}" -class MockMinimalHook(Hooks): +class MockMinimalHooks(Hooks): def __init__(self) -> None: self.run_start_events: list[RunStart] = [] @@ -82,62 +82,62 @@ async def on_run_start(self, data: RunStart) -> None: @pytest.fixture -def mock_hook() -> Generator[MockHook, None, None]: - yield from _create_mock_hook("test_hook", MockHook) +def mock_hooks() -> Generator[MockHooks, None, None]: + yield from _create_mock_hooks("test_hooks", MockHooks) @pytest.fixture -def hook_2() -> Generator[MockHook, None, None]: - yield from _create_mock_hook("test_hook_2", MockHook) +def hooks_2() -> Generator[MockHooks, None, None]: + yield from _create_mock_hooks("test_hooks_2", MockHooks) @pytest.fixture -def hook_minimal() -> Generator[MockMinimalHook, None, None]: - yield from _create_mock_hook("test_hook_minimal", MockMinimalHook) +def hooks_minimal() -> Generator[MockMinimalHooks, None, None]: + yield from _create_mock_hooks("test_hooks_minimal", MockMinimalHooks) def test_can_run_eval_with_no_hooks() -> None: eval(Task(dataset=[Sample("sample_1")]), model="mockllm/model") -def test_respects_enabled(mock_hook: MockHook) -> None: - mock_hook.assert_no_events() +def test_respects_enabled(mock_hooks: MockHooks) -> None: + mock_hooks.assert_no_events() - mock_hook.should_enable = False + mock_hooks.should_enable = False eval(Task(dataset=[Sample("sample_1")]), model="mockllm/model") - mock_hook.assert_no_events() + mock_hooks.assert_no_events() - mock_hook.should_enable = True + mock_hooks.should_enable = True eval(Task(dataset=[Sample("sample_1")]), model="mockllm/model") - assert len(mock_hook.run_start_events) == 1 + assert len(mock_hooks.run_start_events) == 1 -def test_can_subscribe_to_events(mock_hook: MockHook) -> None: - mock_hook.assert_no_events() +def test_can_subscribe_to_events(mock_hooks: MockHooks) -> None: + mock_hooks.assert_no_events() eval(Task(dataset=[Sample("sample_1")]), model="mockllm/model") - assert len(mock_hook.run_start_events) == 1 - assert mock_hook.run_start_events[0].run_id is not None - assert len(mock_hook.run_end_events) == 1 - assert len(mock_hook.task_start_events) == 1 - assert len(mock_hook.task_end_events) == 1 - assert len(mock_hook.sample_start_events) == 1 - assert len(mock_hook.sample_end_events) == 1 - assert len(mock_hook.model_usage_events) == 0 + assert len(mock_hooks.run_start_events) == 1 + assert mock_hooks.run_start_events[0].run_id is 
not None + assert len(mock_hooks.run_end_events) == 1 + assert len(mock_hooks.task_start_events) == 1 + assert len(mock_hooks.task_end_events) == 1 + assert len(mock_hooks.sample_start_events) == 1 + assert len(mock_hooks.sample_end_events) == 1 + assert len(mock_hooks.model_usage_events) == 0 def test_can_subscribe_to_events_with_multiple_hooks( - mock_hook: MockHook, hook_2: MockHook + mock_hooks: MockHooks, hooks_2: MockHooks ) -> None: - mock_hook.assert_no_events() - hook_2.assert_no_events() + mock_hooks.assert_no_events() + hooks_2.assert_no_events() eval(Task(dataset=[Sample("sample_1")]), model="mockllm/model") - for h in (mock_hook, hook_2): + for h in (mock_hooks, hooks_2): assert len(h.run_start_events) == 1 assert h.run_start_events[0].run_id is not None assert len(h.run_end_events) == 1 @@ -148,7 +148,7 @@ def test_can_subscribe_to_events_with_multiple_hooks( assert len(h.model_usage_events) == 0 -def test_hooks_on_multiple_tasks(mock_hook: MockHook) -> None: +def test_hooks_on_multiple_tasks(mock_hooks: MockHooks) -> None: eval( [ Task(dataset=[Sample("task_1_sample_1")]), @@ -157,15 +157,15 @@ def test_hooks_on_multiple_tasks(mock_hook: MockHook) -> None: model="mockllm/model", ) - assert len(mock_hook.run_start_events) == 1 - assert len(mock_hook.run_end_events) == 1 - assert len(mock_hook.task_start_events) == 2 - assert len(mock_hook.task_end_events) == 2 - assert len(mock_hook.sample_start_events) == 2 - assert len(mock_hook.sample_end_events) == 2 + assert len(mock_hooks.run_start_events) == 1 + assert len(mock_hooks.run_end_events) == 1 + assert len(mock_hooks.task_start_events) == 2 + assert len(mock_hooks.task_end_events) == 2 + assert len(mock_hooks.sample_start_events) == 2 + assert len(mock_hooks.sample_end_events) == 2 -def test_hooks_with_multiple_samples(mock_hook: MockHook) -> None: +def test_hooks_with_multiple_samples(mock_hooks: MockHooks) -> None: eval( [ Task(dataset=[Sample("sample_1"), Sample("sample_2")]), @@ -173,26 +173,26 @@ def test_hooks_with_multiple_samples(mock_hook: MockHook) -> None: model="mockllm/model", ) - assert len(mock_hook.run_start_events) == 1 - assert len(mock_hook.run_end_events) == 1 - assert len(mock_hook.task_start_events) == 1 - assert len(mock_hook.task_end_events) == 1 - assert len(mock_hook.sample_start_events) == 2 - assert len(mock_hook.sample_end_events) == 2 + assert len(mock_hooks.run_start_events) == 1 + assert len(mock_hooks.run_end_events) == 1 + assert len(mock_hooks.task_start_events) == 1 + assert len(mock_hooks.task_end_events) == 1 + assert len(mock_hooks.sample_start_events) == 2 + assert len(mock_hooks.sample_end_events) == 2 -def test_hooks_with_multiple_epochs(mock_hook: MockHook) -> None: +def test_hooks_with_multiple_epochs(mock_hooks: MockHooks) -> None: eval( Task(dataset=[Sample("sample_1")]), model="mockllm/model", epochs=3, ) - assert len(mock_hook.sample_start_events) == 3 - assert len(mock_hook.sample_end_events) == 3 + assert len(mock_hooks.sample_start_events) == 3 + assert len(mock_hooks.sample_end_events) == 3 -def test_hooks_with_sample_retries(mock_hook: MockHook) -> None: +def test_hooks_with_sample_retries(mock_hooks: MockHooks) -> None: eval( Task(dataset=[Sample("sample_1")], solver=_fail_n_times_solver(2)), model="mockllm/model", @@ -200,11 +200,11 @@ def test_hooks_with_sample_retries(mock_hook: MockHook) -> None: ) # Will succeed on 3rd attempt, but just 1 sample start and end event. 
- assert len(mock_hook.sample_start_events) == 1 - assert len(mock_hook.sample_end_events) == 1 + assert len(mock_hooks.sample_start_events) == 1 + assert len(mock_hooks.sample_end_events) == 1 -def test_hooks_with_error_and_no_retries(mock_hook: MockHook) -> None: +def test_hooks_with_error_and_no_retries(mock_hooks: MockHooks) -> None: eval( Task(dataset=[Sample("sample_1")], solver=_fail_n_times_solver(10)), model="mockllm/model", @@ -212,29 +212,29 @@ def test_hooks_with_error_and_no_retries(mock_hook: MockHook) -> None: ) # Will fail on first attempt without any retries. - assert len(mock_hook.sample_start_events) == 1 - assert len(mock_hook.sample_end_events) == 1 + assert len(mock_hooks.sample_start_events) == 1 + assert len(mock_hooks.sample_end_events) == 1 -def test_hook_does_not_need_to_subscribe_to_all_events( - hook_minimal: MockMinimalHook, +def test_hooks_do_not_need_to_subscribe_to_all_events( + hooks_minimal: MockMinimalHooks, ) -> None: eval(Task(dataset=[Sample("sample_1")]), model="mockllm/model") - assert len(hook_minimal.run_start_events) == 1 + assert len(hooks_minimal.run_start_events) == 1 -def test_api_key_override(mock_hook: MockHook) -> None: +def test_api_key_override(mock_hooks: MockHooks) -> None: overridden = override_api_key("TEST_VAR", "test_value") assert overridden == "mocked-TEST_VAR-test_value" -def test_api_key_override_falls_back_to_legacy(mock_hook: MockHook) -> None: +def test_api_key_override_falls_back_to_legacy(mock_hooks: MockHooks) -> None: def legacy_hook_override(var: str, value: str) -> str | None: return f"legacy-{var}-{value}" - mock_hook.should_enable = False + mock_hooks.should_enable = False with environ_var("INSPECT_API_KEY_OVERRIDE", "._legacy_hook_override"): with patch( @@ -245,7 +245,7 @@ def legacy_hook_override(var: str, value: str) -> str | None: assert overridden == "legacy-TEST_VAR-test_value" -def test_init_hooks_can_be_called_multiple_times(mock_hook: MockHook) -> None: +def test_init_hooks_can_be_called_multiple_times(mock_hooks: MockHooks) -> None: from inspect_ai.hooks._startup import init_hooks # Ensure that init_hooks can be called multiple times without issues. @@ -254,26 +254,26 @@ def test_init_hooks_can_be_called_multiple_times(mock_hook: MockHook) -> None: eval(Task(dataset=[Sample("sample_1")]), model="mockllm/model") - assert len(mock_hook.run_start_events) == 1 + assert len(mock_hooks.run_start_events) == 1 -def test_hook_name_and_description(mock_hook: MockHook) -> None: - info = registry_info(mock_hook) +def test_hooks_name_and_description(mock_hooks: MockHooks) -> None: + info = registry_info(mock_hooks) - assert info.name == "test_hook" - assert info.metadata["description"] == "test_hook-description" + assert info.name == "test_hooks" + assert info.metadata["description"] == "test_hooks-description" T = TypeVar("T", bound=Hooks) -def _create_mock_hook(name: str, hook_class: Type[T]) -> Generator[T, None, None]: +def _create_mock_hooks(name: str, hooks_class: Type[T]) -> Generator[T, None, None]: @hooks(name, description=f"{name}-description") - def get_hook_class() -> type[T]: - return hook_class + def get_hooks_class() -> type[T]: + return hooks_class hook = registry_lookup("hooks", name) - assert isinstance(hook, hook_class) + assert isinstance(hook, hooks_class) try: yield hook finally: From 7490bd9fe94467017ffa149a0369e7c5b6734ac8 Mon Sep 17 00:00:00 2001 From: "Craig.Walton" Date: Mon, 30 Jun 2025 16:13:28 +0000 Subject: [PATCH 24/24] Add release note. 
---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 114d00364f..2c1b41d3c1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@
 - Analysis: Don't show dataframe import progress by default in notebooks (leaves empty cell output artifact).
 - Analysis: Include `order` field in `messages_df()` and `events_df()`.
 - Logging: Improvements to `--display=log` (improved task info formatting, ability to disable rich logging)
+- [Hooks](https://inspect.aisi.org.uk/extensions.html#hooks): Generic lifecycle hooks for Inspect extensions.
 ## 0.3.111 (29 June 2025)
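
The test fixtures renamed in PATCH 23 (`_create_mock_hooks`, `test_respects_enabled`) exercise a registration pattern that is easy to miss in the rename noise: the `@hooks` decorator can wrap a factory function that returns a `Hooks` subclass, and a hooks class can gate itself via `enabled()`. Below is a minimal sketch of that pattern; it is illustrative only and not part of the patches — the class name `QuietHooks`, the handler body, and the exact `enabled()` signature are assumptions.

``` python
from inspect_ai.hooks import Hooks, RunStart, hooks


class QuietHooks(Hooks):
    """Illustrative hooks class (an assumption, not Inspect code)."""

    def enabled(self) -> bool:
        # When this returns False, no event handlers fire — the behaviour
        # that test_respects_enabled toggles via should_enable.
        return True

    async def on_run_start(self, data: RunStart) -> None:
        # RunStart exposes the run_id that the tests assert is not None.
        print(f"run started: {data.run_id}")


# Mirrors _create_mock_hooks(): register a factory that returns the class.
@hooks("quiet_hooks", description="Logs run starts to stdout")
def quiet_hooks() -> type[QuietHooks]:
    return QuietHooks
```

In the tests, `registry_lookup("hooks", "test_hooks")` returns an instance of the class registered this way, which is how the fixtures obtain the mock object whose recorded events they assert against.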