Coverage for human_requests / autotest.py: 80%
535 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-07 17:38 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-07 17:38 +0000
1from __future__ import annotations
3import heapq
4import inspect
5import types
6import warnings
7from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
8from dataclasses import dataclass, field
9from typing import (
10 Annotated,
11 Any,
12 ContextManager,
13 Literal,
14 Protocol,
15 TypeVar,
16 Union,
17 cast,
18 get_args,
19 get_origin,
20)
22AutotestFunction = Callable[..., Any]
23AutotestHook = Callable[[Any, Any, "AutotestContext"], Any]
24AutotestParamProvider = Callable[["AutotestCallContext"], Any]
25AutotestDataProvider = Callable[["AutotestDataContext"], Any]
26AutotestTypecheckMode = Literal["off", "warn", "strict"]
27SnapshotName = str | int | Callable[..., Any] | list[str | int | Callable[..., Any]]
28HookKey = tuple[type[object] | None, AutotestFunction]
29DependencyMarker = Callable[..., Any]
31_AUTOTEST_ATTR = "__autotest__"
32_DEPENDS_ON_ATTR = "__autotest_depends_on__"
33_HOOKS: dict[HookKey, AutotestHook] = {}
34_PARAM_PROVIDERS: dict[HookKey, AutotestParamProvider] = {}
35_CASE_POLICIES: dict[HookKey, "AutotestCasePolicy"] = {}
36_DATA_CASES: list["AutotestDataCase"] = []
37_VALID_TYPECHECK_MODES: frozenset[str] = frozenset({"off", "warn", "strict"})
39_PrimitiveTypes = (str, bytes, bytearray, bool, int, float, complex, range, memoryview)
41F = TypeVar("F", bound=Callable[..., Any])
@dataclass(frozen=True)
class AutotestContext:
    """Immutable context passed to an autotest hook after the method call.

    Built by ``execute_autotest_case`` and handed to the hook as the third
    argument of ``hook(response, data, ctx)``.
    """

    # Root API object that discovery started from.
    api: object
    # Object the autotest bound method was read from.
    owner: object
    # Resource object through which ``owner`` was reached; None for the root api.
    parent: object | None
    # Bound method that was invoked.
    method: Callable[..., Awaitable[Any]]
    # Unwrapped function behind ``method`` (registry / snapshot key).
    func: AutotestFunction
    # Snapshot fixture; must provide ``assert_json_match(data, name)``.
    schemashot: Any
    # Mutable run state; shared across cases when the runner passes one dict in.
    state: dict[str, Any]
@dataclass(frozen=True)
class AutotestMethodCase:
    """One discovered ``@autotest`` method, as produced by ``discover_autotest_methods``."""

    # Object the bound method was read from.
    owner: object
    # Resource object that contained ``owner``; None when ``owner`` is the root api.
    parent: object | None
    # Bound method to invoke; it must return an awaitable.
    method: Callable[..., Awaitable[Any]]
    # Unwrapped function behind ``method``; used as registry and dependency key.
    func: AutotestFunction
    # Parameters without defaults (``*args``/``**kwargs`` are never included).
    required_parameters: tuple[str, ...]
    # Functions that must have completed before this case is executed.
    depends_on: tuple[AutotestFunction, ...]
@dataclass(frozen=True)
class AutotestCallContext:
    """Context passed to an ``@autotest_params`` provider when resolving call arguments."""

    # Root API object that discovery started from.
    api: object
    # Object the autotest bound method was read from.
    owner: object
    # Resource object through which ``owner`` was reached; None for the root api.
    parent: object | None
    # Bound method that is about to be invoked.
    method: Callable[..., Awaitable[Any]]
    # Unwrapped function behind ``method``.
    func: AutotestFunction
    # Snapshot fixture; must provide ``assert_json_match(data, name)``.
    schemashot: Any
    # Mutable run state shared with hooks and other providers.
    state: dict[str, Any]
@dataclass(frozen=True)
class AutotestDataContext:
    """Context passed to every ``@autotest_data`` provider."""

    # Root API object the run was started with.
    api: object
    # Snapshot fixture; must provide ``assert_json_match(data, name)``.
    schemashot: Any
    # Mutable run state (the same dict the method cases used when run together).
    state: dict[str, Any]
@dataclass(frozen=True)
class AutotestInvocation:
    """Positional and keyword arguments used to call an autotest method."""

    # Positional arguments, in order.
    args: tuple[Any, ...] = ()
    # default_factory gives each instance its own dict instead of a shared default.
    kwargs: dict[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class AutotestDataCase:
    """A registered ``@autotest_data`` provider plus the snapshot name for its payload."""

    # Name passed to ``schemashot.assert_json_match`` for this payload.
    name: SnapshotName
    # Called with an ``AutotestDataContext``; may return a value or an awaitable.
    provider: AutotestDataProvider
@dataclass(frozen=True)
class AutotestCasePolicy:
    """Per-method execution policy registered via ``autotest_policy``/``autotest_params``."""

    # Functions that must have completed before the target method runs.
    depends_on: tuple[AutotestFunction, ...] = ()
class AutotestSubtests(Protocol):
    """Structural type for a subtests fixture used by ``execute_autotests_with_subtests``."""

    def test(self, msg: str | None = None, **kwargs: Any) -> ContextManager[None]:
        """Open a subtest scope labelled by *msg*/*kwargs* (used as ``with subtests.test(**labels):``)."""
        ...
def autotest(func: F) -> F:
    """Mark *func* as an autotest method and return it unchanged.

    Discovery checks for the marker on the *innermost* function: it unwraps the
    bound method and then follows ``__wrapped__`` chains with ``inspect.unwrap``.
    ``functools.wraps`` copies the wrapped function's ``__dict__`` onto the
    wrapper, but not the other way round. So when ``@autotest`` is applied
    *outside* another decorator, the marker would otherwise live only on the
    wrapper and the method would never be discovered. The marker is therefore
    set on both the given object and its innermost function.
    """
    setattr(func, _AUTOTEST_ATTR, True)
    try:
        innermost = inspect.unwrap(func)
    except ValueError:
        # Cyclic __wrapped__ chain: keep the original behaviour of marking func only.
        return func
    if innermost is not func:
        try:
            setattr(innermost, _AUTOTEST_ATTR, True)
        except (AttributeError, TypeError):
            # Innermost object does not accept attributes (e.g. a builtin);
            # such a target could not be discovered anyway.
            pass
    return func
def autotest_depends_on(
    target: Callable[..., Any],
) -> Callable[[DependencyMarker], DependencyMarker]:
    """Declare that a hook or params provider needs *target* to have completed first.

    The dependency is appended to a tuple attribute on the decorated callback.
    Declaring the same dependency twice has no effect. The callback itself is
    returned unchanged.
    """
    required = _as_function(target)

    def decorator(callback: DependencyMarker) -> DependencyMarker:
        already_declared = _get_callback_dependencies(callback)
        if required not in already_declared:
            setattr(callback, _DEPENDS_ON_ATTR, (*already_declared, required))
        return callback

    return decorator
def autotest_hook(
    *,
    target: Callable[..., Any],
    parent: type[object] | None = None,
) -> Callable[[AutotestHook], AutotestHook]:
    """Register a post-call hook for *target*, optionally scoped to a parent class.

    A later registration for the same ``(parent, target)`` replaces the earlier one.

    Raises:
        TypeError: if *parent* is neither a class nor None, or *target* is not callable.
    """
    if not (parent is None or inspect.isclass(parent)):
        raise TypeError("autotest_hook parent must be a class or None.")

    registry_key = (parent, _as_function(target))

    def decorator(hook: AutotestHook) -> AutotestHook:
        _HOOKS[registry_key] = hook
        return hook

    return decorator
def autotest_params(
    *,
    target: Callable[..., Any],
    parent: type[object] | None = None,
    depends_on: Sequence[Callable[..., Any]] | None = None,
) -> Callable[[AutotestParamProvider], AutotestParamProvider]:
    """Register a provider that supplies call arguments for *target*.

    When *depends_on* is non-empty, it is also recorded as the case policy for
    ``(parent, target)`` at the moment the decorator is applied.

    Raises:
        TypeError: if *parent* is neither a class nor None, or if *target* or
            *depends_on* is invalid.
    """
    if not (parent is None or inspect.isclass(parent)):
        raise TypeError("autotest_params parent must be a class or None.")

    target_func = _as_function(target)
    required_first = _normalize_depends_on(depends_on)

    def decorator(provider: AutotestParamProvider) -> AutotestParamProvider:
        _PARAM_PROVIDERS[(parent, target_func)] = provider
        if not required_first:
            return provider
        _register_case_policy(
            parent=parent,
            target_func=target_func,
            depends_on=required_first,
        )
        return provider

    return decorator
def autotest_policy(
    *,
    target: Callable[..., Any],
    parent: type[object] | None = None,
    depends_on: Sequence[Callable[..., Any]] | None = None,
) -> Callable[[F], F]:
    """Register an execution policy for *target*.

    Registration happens when ``autotest_policy(...)`` itself is called, not
    when the returned decorator is applied. The decorator returns its argument
    untouched, so it can decorate any marker object.

    Raises:
        TypeError: if *parent* is neither a class nor None, or if *target* or
            *depends_on* is invalid.
    """
    if not (parent is None or inspect.isclass(parent)):
        raise TypeError("autotest_policy parent must be a class or None.")

    _register_case_policy(
        parent=parent,
        target_func=_as_function(target),
        depends_on=_normalize_depends_on(depends_on),
    )

    def decorator(marker: F) -> F:
        return marker

    return decorator
def autotest_data(
    *,
    name: SnapshotName,
) -> Callable[[AutotestDataProvider], AutotestDataProvider]:
    """Register a data provider whose payload is snapshotted under *name*.

    Every application of the decorator appends a new case. The executors run
    the cases in registration order.
    """

    def decorator(provider: AutotestDataProvider) -> AutotestDataProvider:
        new_case = AutotestDataCase(name=name, provider=provider)
        _DATA_CASES.append(new_case)
        return provider

    return decorator
def clear_autotest_hooks() -> None:
    """Empty every module-level registry: hooks, params providers, policies and data cases.

    Markers that ``@autotest`` and ``@autotest_depends_on`` store on function
    objects are not touched.
    """
    for registry in (_HOOKS, _PARAM_PROVIDERS, _CASE_POLICIES, _DATA_CASES):
        registry.clear()
_RegistryValue = TypeVar("_RegistryValue")


def _find_registered(
    registry: Mapping[HookKey, _RegistryValue],
    func: Callable[..., Any],
    parent_object: object | None,
) -> _RegistryValue | None:
    """Look up *func* in *registry*, preferring an entry scoped to the parent's class.

    The scoped key is ``(parent_object.__class__, func)`` and matches the exact
    class only, with no MRO walk. The fallback key is ``(None, func)``. Returns
    None when neither key is registered.
    """
    target_func = _as_function(func)
    if parent_object is not None:
        scoped = registry.get((parent_object.__class__, target_func))
        if scoped is not None:
            return scoped
    return registry.get((None, target_func))


def find_autotest_hook(
    func: Callable[..., Any],
    parent_object: object | None,
) -> AutotestHook | None:
    """Return the hook for *func*, preferring one scoped to *parent_object*'s class."""
    return _find_registered(_HOOKS, func, parent_object)


def find_autotest_params_provider(
    func: Callable[..., Any],
    parent_object: object | None,
) -> AutotestParamProvider | None:
    """Return the params provider for *func*, preferring one scoped to *parent_object*'s class."""
    return _find_registered(_PARAM_PROVIDERS, func, parent_object)


def find_autotest_policy(
    func: Callable[..., Any],
    parent_object: object | None,
) -> AutotestCasePolicy:
    """Return the case policy for *func*; an empty ``AutotestCasePolicy`` if none is registered."""
    policy = _find_registered(_CASE_POLICIES, func, parent_object)
    return AutotestCasePolicy() if policy is None else policy
def find_autotest_hook_dependencies(
    func: Callable[..., Any],
    parent_object: object | None,
) -> tuple[AutotestFunction, ...]:
    """Return the ``@autotest_depends_on`` dependencies of the hook resolved for *func*.

    Returns an empty tuple when no hook is registered.
    """
    resolved_hook = find_autotest_hook(func, parent_object)
    return () if resolved_hook is None else _get_callback_dependencies(resolved_hook)
def find_autotest_params_dependencies(
    func: Callable[..., Any],
    parent_object: object | None,
) -> tuple[AutotestFunction, ...]:
    """Return the ``@autotest_depends_on`` dependencies of the params provider for *func*.

    Returns an empty tuple when no provider is registered.
    """
    resolved_provider = find_autotest_params_provider(func, parent_object)
    return () if resolved_provider is None else _get_callback_dependencies(resolved_provider)
def discover_autotest_methods(api: object) -> list[AutotestMethodCase]:
    """Collect every ``@autotest`` method reachable from *api*.

    Public attributes (names not starting with ``_``) are inspected in sorted
    order, and attributes whose access raises are ignored. Bound methods marked
    with ``@autotest`` become cases. Other values accepted by
    ``_is_resource_object`` are walked recursively, with the current object
    recorded as their parent. Each object is visited at most once, tracked by
    ``id``. The cases are returned in dependency order via ``_order_cases``.

    Raises:
        TypeError: if an autotest method has required parameters but no
            ``@autotest_params`` provider is registered for it.
    """
    found: list[AutotestMethodCase] = []
    seen: set[int] = set()

    def build_case(
        owner: object,
        parent: object | None,
        bound_method: Callable[..., Any],
    ) -> AutotestMethodCase | None:
        # Returns None when the bound method is not marked with @autotest.
        func = _as_function(bound_method)
        if not _is_autotest(func):
            return None

        required_parameters = _required_parameters(bound_method)
        provider = find_autotest_params_provider(func, parent)
        policy = find_autotest_policy(func, parent)
        dependencies = _merge_dependencies(
            policy.depends_on,
            find_autotest_params_dependencies(func, parent),
            find_autotest_hook_dependencies(func, parent),
        )
        if provider is None and required_parameters:
            joined = ", ".join(required_parameters)
            raise TypeError(
                f"Autotest method {func.__qualname__} requires arguments ({joined}). "
                "Register @autotest_params(target=...) for this method."
            )
        return AutotestMethodCase(
            owner=owner,
            parent=parent,
            method=cast(Callable[..., Awaitable[Any]], bound_method),
            func=func,
            required_parameters=required_parameters,
            depends_on=dependencies,
        )

    def walk(owner: object, parent: object | None) -> None:
        marker = id(owner)
        if marker in seen:
            return
        seen.add(marker)

        public_names = (name for name in sorted(dir(owner)) if not name.startswith("_"))
        for attr_name in public_names:
            try:
                value = getattr(owner, attr_name)
            except Exception:
                # Attribute access can run descriptor/property code; skip failures.
                continue

            bound_method = _as_bound_method(value)
            if bound_method is None:
                if _is_resource_object(value):
                    walk(value, owner)
                continue

            case = build_case(owner, parent, bound_method)
            if case is not None:
                found.append(case)

    walk(api, None)
    return _order_cases(found)
async def execute_autotests(
    api: object,
    schemashot: Any,
    *,
    typecheck_mode: AutotestTypecheckMode | str = "off",
) -> int:
    """Discover and run every autotest method on *api*, then every data case.

    Method cases run in the order produced by ``discover_autotest_methods``. A
    case is skipped, and recorded in the ``autotest_skipped_funcs`` state set,
    when any of its dependencies has not completed or when it raises pytest's
    skip exception. Any other exception propagates and aborts the run.

    Returns:
        The number of method cases that completed, plus the count returned by
        ``execute_autotest_data_cases``.
    """
    _validate_schemashot(schemashot)
    mode = _normalize_typecheck_mode(typecheck_mode)
    state, completed_funcs, skipped_funcs = _initialize_runtime_state()

    completed_count = 0
    for case in discover_autotest_methods(api):
        blocked = any(dep not in completed_funcs for dep in case.depends_on)
        if blocked:
            skipped_funcs.add(case.func)
            continue

        try:
            await execute_autotest_case(
                case=case,
                api=api,
                schemashot=schemashot,
                state=state,
                typecheck_mode=mode,
            )
        except BaseException as error:  # pragma: no cover - runtime-only branch for skip semantics
            if not _is_pytest_skip_exception(error):
                raise
            skipped_funcs.add(case.func)
            continue

        completed_funcs.add(case.func)
        completed_count += 1

    data_count = await execute_autotest_data_cases(api=api, schemashot=schemashot, state=state)
    return completed_count + data_count
async def execute_autotests_with_subtests(
    api: object,
    schemashot: Any,
    *,
    subtests: AutotestSubtests,
    typecheck_mode: AutotestTypecheckMode | str = "off",
) -> int:
    """Run every autotest method and data case on *api*, each in its own subtest.

    A case whose dependencies have not completed is reported as a pytest skip
    inside its subtest. A case only counts as completed (and so unblocks its
    dependents) when its body finishes without raising.

    Returns:
        The number of cases processed, whether passed, failed or skipped,
        including data cases.
    """
    _validate_schemashot(schemashot)
    mode = _normalize_typecheck_mode(typecheck_mode)

    state, completed_funcs, skipped_funcs = _initialize_runtime_state()
    cases = discover_autotest_methods(api)

    processed = 0
    for case in cases:
        processed += 1
        succeeded = False
        with subtests.test(**_subtest_label_for_case(case)):
            unmet = tuple(dep for dep in case.depends_on if dep not in completed_funcs)
            if unmet:
                skipped_funcs.add(case.func)
                _skip_current_case(_format_dependency_skip_reason(unmet))

            try:
                await execute_autotest_case(
                    case=case,
                    api=api,
                    schemashot=schemashot,
                    state=state,
                    typecheck_mode=mode,
                )
            except BaseException as error:  # pragma: no cover - runtime-only skip branch
                if _is_pytest_skip_exception(error):
                    skipped_funcs.add(case.func)
                raise

            succeeded = True

        # The subtest context may swallow failures; only record real successes.
        if succeeded:
            completed_funcs.add(case.func)

    processed += await _execute_autotest_data_cases_with_subtests(
        api=api,
        schemashot=schemashot,
        state=state,
        subtests=subtests,
    )
    return processed
async def execute_autotest_case(
    *,
    case: AutotestMethodCase,
    api: object,
    schemashot: Any,
    state: dict[str, Any] | None = None,
    typecheck_mode: AutotestTypecheckMode | str = "off",
) -> None:
    """Run a single discovered autotest method and snapshot its JSON payload.

    Steps:
      1. Resolve call arguments: ``@autotest_params`` provider, signature
         binding, and the optional type check selected by *typecheck_mode*.
      2. Await the method and read ``response.json()``. If ``json()`` returns an
         awaitable, as async HTTP clients such as aiohttp do, it is awaited too.
      3. If a hook is registered, call ``hook(response, data, ctx)`` and await it
         if needed; a non-None result replaces ``data``.
      4. Call ``schemashot.assert_json_match(data, case.func)``.

    Args:
        case: The method case to execute.
        api: Root API object, forwarded to providers and hooks.
        schemashot: Snapshot fixture with a callable ``assert_json_match``.
        state: Shared run state; a fresh dict is used when None.
        typecheck_mode: "off", "warn" or "strict" (case and surrounding
            whitespace are ignored).

    Raises:
        TypeError: invalid *schemashot*, non-async method, invalid invocation,
            strict-mode type mismatch, or a response without a callable ``json``.
        ValueError: unknown *typecheck_mode*.
    """
    _validate_schemashot(schemashot)
    resolved_typecheck_mode = _normalize_typecheck_mode(typecheck_mode)
    runtime_state = state if state is not None else {}
    invocation = await _resolve_invocation(
        case=case,
        api=api,
        schemashot=schemashot,
        state=runtime_state,
        typecheck_mode=resolved_typecheck_mode,
    )

    response = await _invoke_method(case.method, case.func, invocation)

    if not hasattr(response, "json") or not callable(response.json):
        raise TypeError(
            f"Autotest method {case.func.__qualname__} must return an object with json()."
        )

    data = response.json()
    if inspect.isawaitable(data):
        # Async clients return a coroutine from json(); snapshotting the
        # un-awaited coroutine object itself would be meaningless.
        data = await data

    ctx = AutotestContext(
        api=api,
        owner=case.owner,
        parent=case.parent,
        method=case.method,
        func=case.func,
        schemashot=schemashot,
        state=runtime_state,
    )

    hook = find_autotest_hook(case.func, case.parent)
    if hook is not None:
        hook_result = hook(response, data, ctx)
        if inspect.isawaitable(hook_result):
            hook_result = await hook_result
        if hook_result is not None:
            data = hook_result

    schemashot.assert_json_match(data, case.func)
async def execute_autotest_data_cases(
    *,
    api: object,
    schemashot: Any,
    state: dict[str, Any] | None = None,
) -> int:
    """Run every registered ``@autotest_data`` provider and snapshot its payload.

    The registry is copied at call time, and the function returns the number of
    cases in that copy. The count therefore always matches the work actually
    done, even if a provider registers more data cases while running. This is
    consistent with the subtests variant.

    Args:
        api: Root API object exposed to providers via ``AutotestDataContext``.
        schemashot: Snapshot fixture with a callable ``assert_json_match``.
        state: Shared run state; a fresh dict is used when None.

    Raises:
        TypeError: if *schemashot* is invalid.
    """
    _validate_schemashot(schemashot)
    runtime_state = state if state is not None else {}
    ctx = AutotestDataContext(api=api, schemashot=schemashot, state=runtime_state)

    data_cases = list(_DATA_CASES)
    for case in data_cases:
        payload = case.provider(ctx)
        if inspect.isawaitable(payload):
            payload = await payload
        schemashot.assert_json_match(payload, case.name)

    return len(data_cases)
async def _execute_autotest_data_cases_with_subtests(
    *,
    api: object,
    schemashot: Any,
    state: dict[str, Any],
    subtests: AutotestSubtests,
) -> int:
    """Run each registered data case inside its own subtest.

    Returns the number of data cases started. Failures are handled by the
    subtest context.
    """
    _validate_schemashot(schemashot)
    ctx = AutotestDataContext(api=api, schemashot=schemashot, state=state)

    data_cases = list(_DATA_CASES)
    for data_case in data_cases:
        with subtests.test(**_subtest_label_for_data(data_case)):
            payload = data_case.provider(ctx)
            if inspect.isawaitable(payload):
                payload = await payload
            schemashot.assert_json_match(payload, data_case.name)

    return len(data_cases)
def _is_autotest(func: AutotestFunction) -> bool:
    """Return True when *func* carries the marker set by ``@autotest``."""
    marker = getattr(func, _AUTOTEST_ATTR, False)
    return True if marker else False
def _as_function(target: Callable[..., Any]) -> AutotestFunction:
    """Normalize a function or bound method to its innermost plain function.

    Bound methods are replaced by ``__func__``, and ``__wrapped__`` chains are
    followed with ``inspect.unwrap``, so registrations and lookups agree on a key.

    Raises:
        TypeError: if *target* is not callable.
    """
    candidate: Any = target
    if inspect.ismethod(candidate):
        candidate = candidate.__func__
    if callable(candidate):
        return cast("AutotestFunction", inspect.unwrap(candidate))
    raise TypeError("Target must be a function or method.")
def _as_bound_method(value: Any) -> Callable[..., Any] | None:
    """Return *value* if it is a bound method with a ``__self__``, otherwise None."""
    if not inspect.ismethod(value):
        return None
    if value.__self__ is None:
        return None
    return cast(Callable[..., Any], value)
def _is_resource_object(value: Any) -> bool:
    """Decide heuristically whether *value* is a nested API resource worth walking.

    Rejects None, primitives, builtin containers, modules, classes, functions,
    methods and builtins. Accepts any remaining object that has ``__dict__`` or
    ``__slots__``.
    """
    if value is None or isinstance(value, _PrimitiveTypes):
        return False
    if isinstance(value, (dict, list, tuple, set, frozenset)):
        return False
    excluded_kinds = (
        inspect.ismodule,
        inspect.isclass,
        inspect.isfunction,
        inspect.ismethod,
        inspect.isbuiltin,
    )
    if any(check(value) for check in excluded_kinds):
        return False
    return hasattr(value, "__dict__") or hasattr(value, "__slots__")
def _order_cases(cases: list[AutotestMethodCase]) -> list[AutotestMethodCase]:
    """Topologically sort *cases* so dependencies run before their dependents.

    This is Kahn's algorithm with a min-heap keyed by
    ``(func.__qualname__, index)``, so ties are broken alphabetically and
    deterministically. Dependencies on functions that have no case are ignored.
    Cases that cannot be ordered (cycles and anything downstream of them) are
    appended at the end in their original order. Lists with fewer than two
    cases are returned as-is.
    """
    count = len(cases)
    if count < 2:
        return cases

    positions: dict[AutotestFunction, list[int]] = {}
    for position, case in enumerate(cases):
        positions.setdefault(case.func, []).append(position)

    successors: list[set[int]] = [set() for _ in range(count)]
    pending = [0] * count
    for dependent, case in enumerate(cases):
        for dependency in case.depends_on:
            for prerequisite in positions.get(dependency, ()):
                if prerequisite == dependent or dependent in successors[prerequisite]:
                    continue
                successors[prerequisite].add(dependent)
                pending[dependent] += 1

    def heap_key(position: int) -> tuple[str, int]:
        return (cases[position].func.__qualname__, position)

    ready = [heap_key(position) for position in range(count) if pending[position] == 0]
    heapq.heapify(ready)

    result: list[AutotestMethodCase] = []
    while ready:
        _, position = heapq.heappop(ready)
        result.append(cases[position])
        for follower in successors[position]:
            pending[follower] -= 1
            if pending[follower] == 0:
                heapq.heappush(ready, heap_key(follower))

    if len(result) < count:
        result.extend(case for position, case in enumerate(cases) if pending[position] > 0)
    return result
def _required_parameters(method: Callable[..., Any]) -> tuple[str, ...]:
    """Return the names of *method*'s parameters that have no default value.

    ``*args`` and ``**kwargs`` never count as required. For a bound method,
    ``inspect.signature`` already omits the bound ``self``.
    """
    named_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    return tuple(
        name
        for name, parameter in inspect.signature(method).parameters.items()
        if parameter.kind in named_kinds and parameter.default is inspect.Signature.empty
    )
def _normalize_depends_on(
    depends_on: Sequence[Callable[..., Any]] | None,
) -> tuple[AutotestFunction, ...]:
    """Normalize a ``depends_on`` argument into a tuple of unwrapped functions.

    Args:
        depends_on: None (no dependencies) or a sequence of functions/methods.

    Raises:
        TypeError: if *depends_on* is not a sequence, is a str/bytes/bytearray,
            or contains a non-callable entry.
    """
    if depends_on is None:
        return ()

    # str/bytes satisfy Sequence, but they are never a list of callables.
    # Rejecting them up front stops "" from being silently accepted as
    # "no dependencies".
    if isinstance(depends_on, (str, bytes, bytearray)) or not isinstance(depends_on, Sequence):
        raise TypeError("depends_on must be a sequence of functions or methods.")

    return tuple(_as_function(target) for target in depends_on)
def _register_case_policy(
    *,
    parent: type[object] | None,
    target_func: AutotestFunction,
    depends_on: Iterable[AutotestFunction] = (),
) -> None:
    """Create or update the case policy stored for ``(parent, target_func)``.

    A non-empty *depends_on* replaces the stored dependencies; an empty one
    keeps whatever was registered before.
    """
    key = (parent, target_func)
    current = _CASE_POLICIES.get(key, AutotestCasePolicy())
    # Materialize first: an iterator or generator is always truthy, so testing
    # the raw iterable would treat an empty iterator as "replace with nothing".
    requested = tuple(depends_on)
    _CASE_POLICIES[key] = AutotestCasePolicy(
        depends_on=requested or current.depends_on,
    )
def _get_callback_dependencies(callback: Callable[..., Any]) -> tuple[AutotestFunction, ...]:
    """Read the dependencies recorded by ``@autotest_depends_on`` on *callback*.

    Any non-tuple value stored under the marker attribute is ignored.
    """
    recorded = getattr(callback, _DEPENDS_ON_ATTR, ())
    if isinstance(recorded, tuple):
        return tuple(map(_as_function, recorded))
    return ()
def _merge_dependencies(
    *sources: Iterable[AutotestFunction],
) -> tuple[AutotestFunction, ...]:
    """Concatenate *sources* and drop repeats, keeping first-seen order.

    Membership is checked with ``in`` on a list, i.e. by equality, not hashing.
    """
    unique: list[AutotestFunction] = []
    for candidate in (item for source in sources for item in source):
        if candidate in unique:
            continue
        unique.append(candidate)
    return tuple(unique)
def _initialize_runtime_state() -> tuple[
    dict[str, Any],
    set[AutotestFunction],
    set[AutotestFunction],
]:
    """Create fresh per-run state: a shared dict plus completed/skipped function sets.

    Both sets are also stored in the dict, under ``autotest_completed_funcs``
    and ``autotest_skipped_funcs``, so anything that receives ``state`` can
    read them.
    """
    completed: set[AutotestFunction] = set()
    skipped: set[AutotestFunction] = set()
    shared_state: dict[str, Any] = {
        "autotest_completed_funcs": completed,
        "autotest_skipped_funcs": skipped,
    }
    return shared_state, completed, skipped
def _subtest_label_for_case(case: AutotestMethodCase) -> dict[str, str]:
    """Build subtest labels: the method's qualname and its parent's class name ("None" at root)."""
    parent = case.parent
    return dict(
        method=case.func.__qualname__,
        parent="None" if parent is None else type(parent).__name__,
    )
def _subtest_label_for_data(case: AutotestDataCase) -> dict[str, str]:
    """Build the subtest label for a data case from its rendered snapshot name."""
    rendered = _snapshot_name_repr(case.name)
    return {"data": rendered}
def _snapshot_name_repr(name: object) -> str:
    """Render a snapshot name for display in subtest labels.

    str/int values use ``str()``, callables their ``__qualname__`` (falling
    back to ``repr``), and lists are rendered element-wise as ``[a, b]``.
    Anything else uses ``repr()``.
    """
    if isinstance(name, (str, int)):
        return str(name)
    if callable(name):
        return getattr(name, "__qualname__", repr(name))
    if not isinstance(name, list):
        return repr(name)
    parts = [_snapshot_name_repr(item) for item in name]
    return f"[{', '.join(parts)}]"
def _format_dependency_skip_reason(missing: tuple[AutotestFunction, ...]) -> str:
    """Build the skip message that lists the dependencies which did not complete."""
    joined = ", ".join([dependency.__qualname__ for dependency in missing])
    return "Dependency was not executed: " + joined
def _skip_current_case(reason: str) -> None:
    """Skip the currently running (sub)test via ``pytest.skip(reason)``.

    pytest is imported lazily, so the module itself does not require it.

    Raises:
        RuntimeError: if pytest cannot be imported.
    """
    try:
        import pytest
    except Exception as error:  # pragma: no cover - pytest runtime always has pytest installed
        raise RuntimeError("pytest is required to skip autotest subtest cases.") from error
    else:
        pytest.skip(reason)
def _is_pytest_skip_exception(error: BaseException) -> bool:
    """Return True if *error* is pytest's skip exception; False when pytest is unavailable."""
    try:
        import pytest
    except Exception:
        return False

    skip_type = getattr(pytest.skip, "Exception", None)
    if not skip_type:
        return False
    return isinstance(error, skip_type)
def _validate_schemashot(schemashot: Any) -> None:
    """Ensure *schemashot* exposes a callable ``assert_json_match``.

    Raises:
        TypeError: if the attribute is missing or not callable.
    """
    if callable(getattr(schemashot, "assert_json_match", None)):
        return
    raise TypeError(
        "schemashot fixture must provide a callable assert_json_match(data, name) method."
    )
async def _invoke_method(
    method: Callable[..., Any],
    func: AutotestFunction,
    invocation: AutotestInvocation,
) -> Any:
    """Call *method* with *invocation*'s arguments and await the result.

    Raises:
        TypeError: if the call does not return an awaitable (the method is not async).
    """
    pending = method(*invocation.args, **invocation.kwargs)
    if inspect.isawaitable(pending):
        return await pending
    raise TypeError(f"Autotest method {func.__qualname__} must be async.")
async def _resolve_invocation(
    *,
    case: AutotestMethodCase,
    api: object,
    schemashot: Any,
    state: dict[str, Any],
    typecheck_mode: AutotestTypecheckMode,
) -> AutotestInvocation:
    """Work out the arguments for *case*'s method and validate them.

    Without a registered ``@autotest_params`` provider, the method is called
    with no arguments. Otherwise the provider receives an
    ``AutotestCallContext``, and its result (awaited if needed) is normalized by
    ``_normalize_invocation``. In both cases the invocation is checked against
    the method signature.
    """
    provider = find_autotest_params_provider(case.func, case.parent)
    if provider is None:
        invocation = AutotestInvocation()
    else:
        call_ctx = AutotestCallContext(
            api=api,
            owner=case.owner,
            parent=case.parent,
            method=case.method,
            func=case.func,
            schemashot=schemashot,
            state=state,
        )
        provided = provider(call_ctx)
        if inspect.isawaitable(provided):
            provided = await provided
        invocation = _normalize_invocation(provided, case.func)

    _validate_invocation(case.method, case.func, invocation, typecheck_mode=typecheck_mode)
    return invocation
760def _normalize_invocation(raw: Any, func: AutotestFunction) -> AutotestInvocation:
761 if raw is None:
762 return AutotestInvocation()
763 if isinstance(raw, AutotestInvocation):
764 return AutotestInvocation(args=tuple(raw.args), kwargs=dict(raw.kwargs))
765 if isinstance(raw, dict):
766 return AutotestInvocation(kwargs=dict(raw))
767 if isinstance(raw, (tuple, list)):
768 return AutotestInvocation(args=tuple(raw))
770 raise TypeError(
771 f"autotest_params provider for {func.__qualname__} must return one of: "
772 "None, dict (kwargs), tuple/list (args), AutotestInvocation."
773 )
def _validate_invocation(
    method: Callable[..., Any],
    func: AutotestFunction,
    invocation: AutotestInvocation,
    *,
    typecheck_mode: AutotestTypecheckMode = "off",
) -> None:
    """Check that *invocation* binds to *method*'s signature, then type-check it.

    Raises:
        TypeError: if binding fails, or on type mismatches in "strict" mode.
    """
    signature = inspect.signature(method)
    try:
        bound = signature.bind(*invocation.args, **invocation.kwargs)
    except TypeError as error:
        raise TypeError(f"Invalid invocation for {func.__qualname__}: {error}") from error

    _validate_invocation_types(
        signature=signature,
        bound_arguments=bound.arguments,
        method=method,
        func=func,
        typecheck_mode=typecheck_mode,
    )
def _normalize_typecheck_mode(mode: AutotestTypecheckMode | str) -> AutotestTypecheckMode:
    """Validate a typecheck mode and return it trimmed and lower-cased.

    Raises:
        TypeError: if *mode* is not a string.
        ValueError: if *mode* is not one of the supported modes.
    """
    if isinstance(mode, str):
        normalized = mode.strip().lower()
        if normalized in _VALID_TYPECHECK_MODES:
            return cast("AutotestTypecheckMode", normalized)
        expected = ", ".join(sorted(_VALID_TYPECHECK_MODES))
        raise ValueError(f"autotest typecheck mode must be one of: {expected}.")
    raise TypeError("autotest typecheck mode must be a string.")
def _validate_invocation_types(
    *,
    signature: inspect.Signature,
    bound_arguments: Mapping[str, Any],
    method: Callable[..., Any],
    func: AutotestFunction,
    typecheck_mode: AutotestTypecheckMode,
) -> None:
    """Compare bound argument values with the parameter annotations.

    Does nothing in "off" mode. Parameters without a resolvable annotation are
    skipped. Mismatches raise TypeError in "strict" mode and emit a
    RuntimeWarning in "warn" mode.
    """
    if typecheck_mode == "off":
        return

    problems: list[str] = []
    for name, value in bound_arguments.items():
        parameter = signature.parameters.get(name)
        annotation = (
            inspect.Signature.empty
            if parameter is None
            else _resolve_annotation(parameter.annotation, method)
        )
        if annotation is inspect.Signature.empty or _matches_annotation(value, annotation):
            continue
        problems.append(
            f"parameter {name!r} expects {_format_annotation(annotation)}, got {type(value).__name__}"
        )

    if not problems:
        return

    message = f"Invalid invocation types for {func.__qualname__}: {'; '.join(problems)}."
    if typecheck_mode == "strict":
        raise TypeError(message)
    # stacklevel=4 attributes the warning past the internal chain
    # (_validate_invocation -> _resolve_invocation) to execute_autotest_case.
    warnings.warn(message, RuntimeWarning, stacklevel=4)
def _resolve_annotation(annotation: Any, method: Callable[..., Any]) -> Any:
    """Turn a string (postponed) annotation into an object using the right globals.

    ``inspect.signature`` follows ``__wrapped__`` and so reports the innermost
    function's annotations. Those strings must be evaluated in that same
    function's module globals, not in the globals of a decorator's wrapper.
    Otherwise, names imported only in the method's module cannot be found, and
    the type check would be silently skipped.

    Returns:
        The annotation unchanged if it is not a string, the evaluated object
        otherwise, or ``inspect.Signature.empty`` when it cannot be resolved,
        so that callers skip the check.
    """
    if annotation is inspect.Signature.empty or not isinstance(annotation, str):
        return annotation

    try:
        source = inspect.unwrap(method)
    except ValueError:
        # Cycle in the __wrapped__ chain: fall back to the method itself.
        source = method
    globals_dict = getattr(source, "__globals__", {})
    if not isinstance(globals_dict, dict):
        return inspect.Signature.empty

    # NOTE: eval runs annotation strings taken from function signatures in
    # source code, not external input.
    try:
        return eval(annotation, globals_dict, {})
    except Exception:
        return inspect.Signature.empty
861def _matches_annotation(value: Any, annotation: Any) -> bool:
862 if annotation in (Any, object):
863 return True
864 if annotation in (None, type(None)):
865 return value is None
867 supertype = getattr(annotation, "__supertype__", None)
868 if supertype is not None:
869 return _matches_annotation(value, supertype)
871 origin = get_origin(annotation)
872 if origin is None:
873 if isinstance(annotation, type):
874 return isinstance(value, annotation)
875 return True
877 if origin in (types.UnionType, Union):
878 return any(_matches_annotation(value, arg) for arg in get_args(annotation))
880 if origin is Annotated:
881 args = get_args(annotation)
882 if not args:
883 return True
884 return _matches_annotation(value, args[0])
886 if origin is Literal:
887 return any(value == option for option in get_args(annotation))
889 if origin is list:
890 return _matches_iterable(value, annotation, list)
891 if origin is set:
892 return _matches_iterable(value, annotation, set)
893 if origin is frozenset:
894 return _matches_iterable(value, annotation, frozenset)
895 if origin is tuple:
896 return _matches_tuple(value, annotation)
897 if origin is dict:
898 return _matches_mapping(value, annotation)
900 if isinstance(origin, type):
901 return isinstance(value, origin)
903 return True
def _matches_iterable(value: Any, annotation: Any, container_type: type[object]) -> bool:
    """Check that *value* is a *container_type* whose items all match the first type argument.

    An unparameterized annotation accepts any items.
    """
    if not isinstance(value, container_type):
        return False
    type_args = get_args(annotation)
    if not type_args:
        return True
    element_type = type_args[0]
    for element in cast(Iterable[Any], value):
        if not _matches_annotation(element, element_type):
            return False
    return True
def _matches_mapping(value: Any, annotation: Any) -> bool:
    """Check that *value* is a dict whose keys and values match ``dict[K, V]``.

    Without exactly two type arguments, every key and value is accepted.
    """
    if not isinstance(value, dict):
        return False
    type_args = get_args(annotation)
    key_type, value_type = type_args if len(type_args) == 2 else (Any, Any)
    return all(
        _matches_annotation(key, key_type) and _matches_annotation(item, value_type)
        for key, item in value.items()
    )
def _matches_tuple(value: Any, annotation: Any) -> bool:
    """Check *value* against ``tuple[...]``.

    Handles the homogeneous form ``tuple[T, ...]`` and the fixed-length form
    ``tuple[A, B]``. An annotation without type arguments accepts any tuple.
    """
    if not isinstance(value, tuple):
        return False
    type_args = get_args(annotation)
    if not type_args:
        return True
    is_homogeneous = len(type_args) == 2 and type_args[1] is Ellipsis
    if is_homogeneous:
        element_type = type_args[0]
        return all(_matches_annotation(element, element_type) for element in value)
    if len(type_args) != len(value):
        return False
    return all(map(_matches_annotation, value, type_args))
946def _format_annotation(annotation: Any) -> str:
947 try:
948 return inspect.formatannotation(annotation)
949 except Exception:
950 return str(annotation)
# Public API exported by ``from <module> import *``. The callable type aliases
# (AutotestFunction, AutotestHook, AutotestParamProvider, AutotestDataProvider,
# SnapshotName) are not listed here; only AutotestTypecheckMode is.
__all__ = [
    "AutotestCallContext",
    "AutotestContext",
    "AutotestSubtests",
    "AutotestDataContext",
    "AutotestDataCase",
    "AutotestInvocation",
    "AutotestMethodCase",
    "AutotestCasePolicy",
    "AutotestTypecheckMode",
    "autotest",
    "autotest_depends_on",
    "autotest_data",
    "autotest_hook",
    "autotest_policy",
    "autotest_params",
    "clear_autotest_hooks",
    "discover_autotest_methods",
    "execute_autotest_case",
    "execute_autotest_data_cases",
    "execute_autotests",
    "execute_autotests_with_subtests",
    "find_autotest_hook",
    "find_autotest_hook_dependencies",
    "find_autotest_policy",
    "find_autotest_params_dependencies",
    "find_autotest_params_provider",
]