Coverage for human_requests/autotest.py: 80%
586 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-28 00:39 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-28 00:39 +0000
1from __future__ import annotations
3import heapq
4import inspect
5import types
6import warnings
7from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
8from dataclasses import dataclass, field
9from typing import (
10 Annotated,
11 Any,
12 ContextManager,
13 Literal,
14 Protocol,
15 TypeVar,
16 Union,
17 cast,
18 get_args,
19 get_origin,
20)
22from .autotest_report import (
23 raise_autotest_hook_crash,
24 raise_autotest_method_crash,
25 raise_autotest_method_output_error,
26 raise_autotest_params_crash,
27)
# Callable aliases for the objects users register through the decorators below.
AutotestFunction = Callable[..., Any]
# Hooks are invoked as ``hook(response, data, ctx)``.
AutotestHook = Callable[[Any, Any, "AutotestContext"], Any]
# Parameter providers receive an AutotestCallContext.
AutotestParamProvider = Callable[["AutotestCallContext"], Any]
# Data providers receive an AutotestDataContext.
AutotestDataProvider = Callable[["AutotestDataContext"], Any]
AutotestTypecheckMode = Literal["off", "warn", "strict"]
SnapshotName = str | int | Callable[..., Any] | list[str | int | Callable[..., Any]]
# Registry key: (parent class, or None for the parent-agnostic entry; target function).
HookKey = tuple[type[object] | None, AutotestFunction]
DependencyMarker = Callable[..., Any]

# Attribute names set on functions by @autotest and @autotest_depends_on.
_AUTOTEST_ATTR = "__autotest__"
_DEPENDS_ON_ATTR = "__autotest_depends_on__"
# Module-level registries filled by the decorators; clear_autotest_hooks() empties them.
_HOOKS: dict[HookKey, AutotestHook] = {}
_PARAM_PROVIDERS: dict[HookKey, AutotestParamProvider] = {}
_CASE_POLICIES: dict[HookKey, "AutotestCasePolicy"] = {}
_DATA_CASES: list["AutotestDataCase"] = []
_VALID_TYPECHECK_MODES: frozenset[str] = frozenset({"off", "warn", "strict"})

# Values of these types are never walked as nested resources during discovery.
_PrimitiveTypes = (str, bytes, bytearray, bool, int, float, complex, range, memoryview)

F = TypeVar("F", bound=Callable[..., Any])
51def _format_autotest_object_type(value: Any) -> str:
52 value_type = type(value)
53 module = value_type.__module__
54 qualname = value_type.__qualname__
55 if module == "builtins":
56 return qualname
57 return f"{module}.{qualname}"
def _format_autotest_wrong_object_message(response: Any) -> str:
    """Build the detail text used when a method returned an object without ``json()``.

    The message names the actual type of *response*.
    """
    received_type = _format_autotest_object_type(response)
    return (
        "The method completed successfully, but autotest expected the returned "
        "human_requests.abstraction.Output object. Instead it received "
        f"{received_type}."
    )
68def _format_autotest_invalid_json_message() -> str:
69 return (
70 "The method completed successfully, but the returned "
71 "human_requests.abstraction.Output object contains invalid JSON."
72 )
@dataclass(frozen=True)
class AutotestContext:
    """Immutable context passed to an autotest hook as its third argument."""

    # API root object the run was started with.
    api: object
    # Object the tested method was discovered on.
    owner: object
    # Object on which ``owner`` was found; None when ``owner`` is the API root.
    parent: object | None
    # Bound method under test.
    method: Callable[..., Awaitable[Any]]
    # Unwrapped function behind ``method``; used as the registry and snapshot key.
    func: AutotestFunction
    # Object providing ``assert_json_match(data, name)``.
    schemashot: Any
    # Runtime state dict handed in by the runner; the dict itself stays mutable.
    state: dict[str, Any]
@dataclass(frozen=True)
class AutotestMethodCase:
    """One discovered ``@autotest`` method, as produced by discover_autotest_methods()."""

    # Object the method was discovered on.
    owner: object
    # Object on which ``owner`` was found; None when ``owner`` is the API root.
    parent: object | None
    # Bound method to call.
    method: Callable[..., Awaitable[Any]]
    # Unwrapped function behind ``method``.
    func: AutotestFunction
    # Names of parameters that have no default value.
    required_parameters: tuple[str, ...]
    # Functions that must have completed before this case runs
    # (merged from the case policy, the params provider and the hook).
    depends_on: tuple[AutotestFunction, ...]
@dataclass(frozen=True)
class AutotestCallContext:
    """Immutable context passed to an ``@autotest_params`` provider."""

    # API root object the run was started with.
    api: object
    # Object the tested method was discovered on.
    owner: object
    # Object on which ``owner`` was found; None when ``owner`` is the API root.
    parent: object | None
    # Bound method that will be invoked with the provider's result.
    method: Callable[..., Awaitable[Any]]
    # Unwrapped function behind ``method``.
    func: AutotestFunction
    # Object providing ``assert_json_match(data, name)``.
    schemashot: Any
    # Runtime state dict handed in by the runner; the dict itself stays mutable.
    state: dict[str, Any]
@dataclass(frozen=True)
class AutotestDataContext:
    """Immutable context passed to an ``@autotest_data`` provider."""

    # API root object the run was started with.
    api: object
    # Object providing ``assert_json_match(data, name)``.
    schemashot: Any
    # Runtime state dict handed in by the runner; the dict itself stays mutable.
    state: dict[str, Any]
@dataclass(frozen=True)
class AutotestInvocation:
    """Arguments for one call of a tested method: ``method(*args, **kwargs)``."""

    # Positional arguments.
    args: tuple[Any, ...] = ()
    # Keyword arguments; a fresh dict is created per instance.
    kwargs: dict[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class AutotestDataCase:
    """A data case registered through ``@autotest_data``."""

    # Snapshot name passed to ``schemashot.assert_json_match(payload, name)``.
    name: SnapshotName
    # Callable producing the payload (may return an awaitable).
    provider: AutotestDataProvider
@dataclass(frozen=True)
class AutotestCasePolicy:
    """Per-target policy stored in the case-policy registry."""

    # Functions that must have completed before the target case runs.
    depends_on: tuple[AutotestFunction, ...] = ()
class AutotestSubtests(Protocol):
    """Structural type of the ``subtests`` object used by the subtest runners.

    The runners only use it as ``with subtests.test(msg=...):``.
    """

    def test(self, msg: str | None = None, **kwargs: Any) -> ContextManager[None]: ...
def autotest(func: F) -> F:
    """Mark *func* as an autotest method and return it unchanged.

    The mark is a truthy attribute that discovery checks on the unwrapped
    function.
    """
    setattr(func, _AUTOTEST_ATTR, True)
    return func
def autotest_depends_on(
    target: Callable[..., Any],
) -> Callable[[DependencyMarker], DependencyMarker]:
    """Return a decorator recording *target* as a dependency of the decorated callback.

    The dependency is appended to the tuple stored on the callback; a
    dependency that is already recorded is not added twice.
    """
    required = _as_function(target)

    def decorator(callback: DependencyMarker) -> DependencyMarker:
        recorded = _get_callback_dependencies(callback)
        if required not in recorded:
            setattr(callback, _DEPENDS_ON_ATTR, (*recorded, required))
        return callback

    return decorator
def autotest_hook(
    *,
    target: Callable[..., Any],
    parent: type[object] | None = None,
) -> Callable[[AutotestHook], AutotestHook]:
    """Return a decorator registering a hook for *target*.

    Args:
        target: Function or method the hook applies to.
        parent: Class of the parent object the hook is limited to, or None
            to register the parent-agnostic hook.

    Raises:
        TypeError: If *parent* is neither a class nor None, or *target* is
            not callable.
    """
    if not (parent is None or inspect.isclass(parent)):
        raise TypeError("autotest_hook parent must be a class or None.")

    registry_key = (parent, _as_function(target))

    def decorator(hook: AutotestHook) -> AutotestHook:
        # A later registration for the same key replaces the earlier one.
        _HOOKS[registry_key] = hook
        return hook

    return decorator
def autotest_params(
    *,
    target: Callable[..., Any],
    parent: type[object] | None = None,
    depends_on: Sequence[Callable[..., Any]] | None = None,
) -> Callable[[AutotestParamProvider], AutotestParamProvider]:
    """Return a decorator registering a parameter provider for *target*.

    Args:
        target: Function or method the provider supplies arguments for.
        parent: Class of the parent object the provider is limited to, or
            None for the parent-agnostic provider.
        depends_on: Optional functions that must complete before the target
            case; when non-empty they are stored as the target's case policy
            at decoration time.

    Raises:
        TypeError: If *parent* is neither a class nor None, or *target* /
            *depends_on* are not valid callables.
    """
    if not (parent is None or inspect.isclass(parent)):
        raise TypeError("autotest_params parent must be a class or None.")

    resolved_target = _as_function(target)
    dependencies = _normalize_depends_on(depends_on)

    def decorator(provider: AutotestParamProvider) -> AutotestParamProvider:
        _PARAM_PROVIDERS[(parent, resolved_target)] = provider
        if dependencies:
            _register_case_policy(
                parent=parent,
                target_func=resolved_target,
                depends_on=dependencies,
            )
        return provider

    return decorator
def autotest_policy(
    *,
    target: Callable[..., Any],
    parent: type[object] | None = None,
    depends_on: Sequence[Callable[..., Any]] | None = None,
) -> Callable[[F], F]:
    """Register a case policy for *target* and return a pass-through decorator.

    The policy is registered when this factory is called, not when the
    returned decorator is applied; the decorator returns its argument as is.

    Raises:
        TypeError: If *parent* is neither a class nor None, or *target* /
            *depends_on* are not valid callables.
    """
    if not (parent is None or inspect.isclass(parent)):
        raise TypeError("autotest_policy parent must be a class or None.")

    resolved_target = _as_function(target)
    dependencies = _normalize_depends_on(depends_on)
    _register_case_policy(
        parent=parent,
        target_func=resolved_target,
        depends_on=dependencies,
    )

    def decorator(marker: F) -> F:
        return marker

    return decorator
def autotest_data(
    *,
    name: SnapshotName,
) -> Callable[[AutotestDataProvider], AutotestDataProvider]:
    """Return a decorator registering the decorated provider as a data case called *name*."""

    def decorator(provider: AutotestDataProvider) -> AutotestDataProvider:
        data_case = AutotestDataCase(name=name, provider=provider)
        _DATA_CASES.append(data_case)
        return provider

    return decorator
def clear_autotest_hooks() -> None:
    """Empty every module-level registry (hooks, providers, policies, data cases)."""
    for registry in (_HOOKS, _PARAM_PROVIDERS, _CASE_POLICIES, _DATA_CASES):
        registry.clear()
def find_autotest_hook(
    func: Callable[..., Any],
    parent_object: object | None,
) -> AutotestHook | None:
    """Look up the hook registered for *func*.

    A hook registered for the class of *parent_object* wins over the
    parent-agnostic one. Returns None when neither exists.
    """
    target_func = _as_function(func)

    if parent_object is not None:
        scoped_hook = _HOOKS.get((parent_object.__class__, target_func))
        if scoped_hook is not None:
            return scoped_hook

    return _HOOKS.get((None, target_func))
def find_autotest_params_provider(
    func: Callable[..., Any],
    parent_object: object | None,
) -> AutotestParamProvider | None:
    """Look up the parameter provider registered for *func*.

    A provider registered for the class of *parent_object* wins over the
    parent-agnostic one. Returns None when neither exists.
    """
    target_func = _as_function(func)

    if parent_object is not None:
        scoped_provider = _PARAM_PROVIDERS.get((parent_object.__class__, target_func))
        if scoped_provider is not None:
            return scoped_provider

    return _PARAM_PROVIDERS.get((None, target_func))
def find_autotest_policy(
    func: Callable[..., Any],
    parent_object: object | None,
) -> AutotestCasePolicy:
    """Look up the case policy registered for *func*.

    A policy registered for the class of *parent_object* wins over the
    parent-agnostic one. An empty default policy is returned when neither
    exists.
    """
    target_func = _as_function(func)

    if parent_object is not None:
        scoped_policy = _CASE_POLICIES.get((parent_object.__class__, target_func))
        if scoped_policy is not None:
            return scoped_policy

    return _CASE_POLICIES.get((None, target_func), AutotestCasePolicy())
def find_autotest_hook_dependencies(
    func: Callable[..., Any],
    parent_object: object | None,
) -> tuple[AutotestFunction, ...]:
    """Return the dependencies declared on the hook for *func*, or ``()`` without a hook."""
    hook = find_autotest_hook(func, parent_object)
    return () if hook is None else _get_callback_dependencies(hook)
def find_autotest_params_dependencies(
    func: Callable[..., Any],
    parent_object: object | None,
) -> tuple[AutotestFunction, ...]:
    """Return the dependencies declared on the params provider for *func*, or ``()`` without one."""
    provider = find_autotest_params_provider(func, parent_object)
    return () if provider is None else _get_callback_dependencies(provider)
def discover_autotest_methods(api: object) -> list[AutotestMethodCase]:
    """Collect every ``@autotest`` method reachable from *api*.

    Public attributes (names not starting with ``_``) are visited in sorted
    order. Bound methods marked with ``@autotest`` become cases; other
    attribute values that look like resource objects are walked recursively.

    Args:
        api: Root object to walk.

    Returns:
        The discovered cases, ordered so that dependencies come first.

    Raises:
        TypeError: If an autotest method has required parameters but no
            params provider is registered for it.
    """
    cases: list[AutotestMethodCase] = []
    # Keyed by id() but holding the object itself: the strong reference keeps
    # every visited owner alive for the whole walk, so the id of a temporary
    # sub-resource (e.g. one built by a property) cannot be reused by a later,
    # different object and make it look already visited.
    visited: dict[int, object] = {}

    def walk(owner: object, parent: object | None) -> None:
        owner_id = id(owner)
        if owner_id in visited:
            return
        visited[owner_id] = owner

        for attr_name in sorted(dir(owner)):
            if attr_name.startswith("_"):
                continue

            try:
                value = getattr(owner, attr_name)
            except Exception:
                # Best effort: attributes whose getter fails are skipped.
                continue

            bound_method = _as_bound_method(value)
            if bound_method is not None:
                func = _as_function(bound_method)
                if _is_autotest(func):
                    required_parameters = _required_parameters(bound_method)
                    provider = find_autotest_params_provider(func, parent)
                    policy = find_autotest_policy(func, parent)
                    dependencies = _merge_dependencies(
                        policy.depends_on,
                        find_autotest_params_dependencies(func, parent),
                        find_autotest_hook_dependencies(func, parent),
                    )
                    if required_parameters and provider is None:
                        joined = ", ".join(required_parameters)
                        raise TypeError(
                            f"Autotest method {func.__qualname__} requires arguments ({joined}). "
                            "Register @autotest_params(target=...) for this method."
                        )
                    cases.append(
                        AutotestMethodCase(
                            owner=owner,
                            parent=parent,
                            method=cast(Callable[..., Awaitable[Any]], bound_method),
                            func=func,
                            required_parameters=required_parameters,
                            depends_on=dependencies,
                        )
                    )
                # Bound methods are never walked as resources.
                continue

            if _is_resource_object(value):
                walk(value, owner)

    walk(api, None)
    return _order_cases(cases)
async def execute_autotests(
    api: object,
    schemashot: Any,
    *,
    typecheck_mode: AutotestTypecheckMode | str = "off",
    trace_limit: int = 3,
    truncation_context_lines: int = 3,
    case_status_recorder: Callable[[str, str], None] | None = None,
    success_recorder: Callable[[str], None] | None = None,
) -> int:
    """Run every discovered method case, then every registered data case.

    A case whose dependencies have not all completed is recorded as
    ``"skipped"`` and not run. A pytest skip raised by a case is recorded as
    ``"skipped"`` and the run continues; any other exception is recorded as
    ``"failed"`` and re-raised.

    Args:
        api: Root object to discover autotest methods on.
        schemashot: Object providing ``assert_json_match(data, name)``.
        typecheck_mode: ``"off"``, ``"warn"`` or ``"strict"``.
        trace_limit: Forwarded to the crash reporters.
        truncation_context_lines: Forwarded to the crash reporters.
        case_status_recorder: Optional ``recorder(label, status)`` callback.
        success_recorder: Optional callback forwarded to each method case.

    Returns:
        Number of method cases that passed plus the count reported by the
        data-case runner.
    """
    _validate_schemashot(schemashot)
    mode = _normalize_typecheck_mode(typecheck_mode)

    passed = 0
    state, completed_funcs, skipped_funcs = _initialize_runtime_state()

    for case in discover_autotest_methods(api):
        label = case.func.__qualname__

        if any(dep not in completed_funcs for dep in case.depends_on):
            skipped_funcs.add(case.func)
            _record_case_status(case_status_recorder, label, "skipped")
            continue

        try:
            await execute_autotest_case(
                case=case,
                api=api,
                schemashot=schemashot,
                state=state,
                typecheck_mode=mode,
                trace_limit=trace_limit,
                truncation_context_lines=truncation_context_lines,
                success_recorder=success_recorder,
            )
        except BaseException as error:  # pragma: no cover - runtime-only branch for skip semantics
            if not _is_pytest_skip_exception(error):
                _record_case_status(case_status_recorder, label, "failed")
                raise
            skipped_funcs.add(case.func)
            _record_case_status(case_status_recorder, label, "skipped")
        else:
            completed_funcs.add(case.func)
            passed += 1
            _record_case_status(case_status_recorder, label, "passed")

    passed += await execute_autotest_data_cases(
        api=api,
        schemashot=schemashot,
        state=state,
        case_status_recorder=case_status_recorder,
    )
    return passed
async def execute_autotests_with_subtests(
    api: object,
    schemashot: Any,
    *,
    subtests: AutotestSubtests,
    typecheck_mode: AutotestTypecheckMode | str = "off",
    trace_limit: int = 3,
    truncation_context_lines: int = 3,
    case_status_recorder: Callable[[str, str], None] | None = None,
    success_recorder: Callable[[str], None] | None = None,
) -> int:
    """Run every method case and data case, each inside ``subtests.test(...)``.

    Inside the subtest context a case with unmet dependencies is recorded as
    ``"skipped"`` and skipped through pytest; a pytest skip raised by the
    case is recorded as ``"skipped"`` and any other exception as
    ``"failed"``. In both cases the exception is re-raised into the subtest
    context.

    Args:
        api: Root object to discover autotest methods on.
        schemashot: Object providing ``assert_json_match(data, name)``.
        subtests: Object whose ``test(msg=...)`` yields a context manager.
        typecheck_mode: ``"off"``, ``"warn"`` or ``"strict"``.
        trace_limit: Forwarded to the crash reporters.
        truncation_context_lines: Forwarded to the crash reporters.
        case_status_recorder: Optional ``recorder(label, status)`` callback.
        success_recorder: Optional callback forwarded to each method case.

    Returns:
        Number of processed method cases (whatever their outcome) plus the
        count reported by the data-case subtest runner.
    """
    _validate_schemashot(schemashot)
    mode = _normalize_typecheck_mode(typecheck_mode)

    processed = 0
    state, completed_funcs, skipped_funcs = _initialize_runtime_state()

    for case in discover_autotest_methods(api):
        processed += 1
        label = case.func.__qualname__
        case_succeeded = False
        with subtests.test(**_subtest_label_for_case(case)):
            unmet = tuple(dep for dep in case.depends_on if dep not in completed_funcs)
            if unmet:
                skipped_funcs.add(case.func)
                _record_case_status(case_status_recorder, label, "skipped")
                _skip_current_case(_format_dependency_skip_reason(unmet))

            try:
                await execute_autotest_case(
                    case=case,
                    api=api,
                    schemashot=schemashot,
                    state=state,
                    typecheck_mode=mode,
                    trace_limit=trace_limit,
                    truncation_context_lines=truncation_context_lines,
                    success_recorder=success_recorder,
                )
            except BaseException as error:  # pragma: no cover - runtime-only skip branch
                if _is_pytest_skip_exception(error):
                    skipped_funcs.add(case.func)
                    _record_case_status(case_status_recorder, label, "skipped")
                else:
                    _record_case_status(case_status_recorder, label, "failed")
                raise
            else:
                case_succeeded = True
                _record_case_status(case_status_recorder, label, "passed")

        # Evaluated after the subtest context has exited.
        if case_succeeded:
            completed_funcs.add(case.func)

    processed += await _execute_autotest_data_cases_with_subtests(
        api=api,
        schemashot=schemashot,
        state=state,
        subtests=subtests,
        case_status_recorder=case_status_recorder,
    )
    return processed
async def execute_autotest_case(
    *,
    case: AutotestMethodCase,
    api: object,
    schemashot: Any,
    state: dict[str, Any] | None = None,
    typecheck_mode: AutotestTypecheckMode | str = "off",
    trace_limit: int = 3,
    truncation_context_lines: int = 3,
    success_recorder: Callable[[str], None] | None = None,
) -> None:
    """Run one autotest method case and match its JSON against the snapshot.

    Steps, in order: resolve the call arguments, await the method, read
    ``response.json()``, run the registered hook (if any), then call
    ``schemashot.assert_json_match(data, case.func)``.

    Args:
        case: Discovered method case to run.
        api: Root API object, forwarded to contexts and crash reporters.
        schemashot: Object providing ``assert_json_match(data, name)``.
        state: Runtime state dict; a new empty dict is used when None.
        typecheck_mode: ``"off"``, ``"warn"`` or ``"strict"``.
        trace_limit: Forwarded to the crash reporters.
        truncation_context_lines: Forwarded to the crash reporters.
        success_recorder: Optional callback invoked with the function's
            qualified name after the snapshot assertion returned.

    Raises:
        TypeError: If *schemashot* is invalid or *typecheck_mode* is not a string.
        ValueError: If *typecheck_mode* is not a known mode.
        BaseException: ``KeyboardInterrupt``, ``SystemExit`` and pytest skips
            are re-raised unchanged; other failures are passed to the
            ``raise_autotest_*`` reporters.
    """
    _validate_schemashot(schemashot)
    resolved_typecheck_mode = _normalize_typecheck_mode(typecheck_mode)
    runtime_state = state if state is not None else {}
    crash_error: BaseException | None = None
    invocation = await _resolve_invocation(
        case=case,
        api=api,
        schemashot=schemashot,
        state=runtime_state,
        typecheck_mode=resolved_typecheck_mode,
        trace_limit=trace_limit,
        truncation_context_lines=truncation_context_lines,
    )

    try:
        response = await _invoke_method(case.method, case.func, invocation)
    except BaseException as error:  # pragma: no cover - runtime-only branch for crash reporting
        if isinstance(error, (KeyboardInterrupt, SystemExit)):
            raise
        if _is_pytest_skip_exception(error):
            raise
        # Only captured here; the reporter is called after the handler has exited.
        # NOTE(review): presumably so the reported error is not raised while the
        # original is still being handled - confirm.
        crash_error = error

    if crash_error is not None:
        raise_autotest_method_crash(
            api=api,
            func=case.func,
            error=crash_error,
            source_func=case.func,
            trace_limit=trace_limit,
            truncation_context_lines=truncation_context_lines,
        )

    ctx = AutotestContext(
        api=api,
        owner=case.owner,
        parent=case.parent,
        method=case.method,
        func=case.func,
        schemashot=schemashot,
        state=runtime_state,
    )

    response_json_error: BaseException | None = None
    response_json_detail_message: str | None = None
    hook = find_autotest_hook(case.func, case.parent)
    if hook is not None:
        try:
            # The response must expose a callable json(); anything else is
            # reported as a wrong-object error rather than an AttributeError.
            if not hasattr(response, "json") or not callable(response.json):
                response_json_error = TypeError(
                    f"Autotest method {case.func.__qualname__} must return an object with json()."
                )
                response_json_detail_message = _format_autotest_wrong_object_message(response)
            else:
                data = response.json()
        except BaseException as error:  # pragma: no cover - runtime-only branch for crash reporting
            if isinstance(error, (KeyboardInterrupt, SystemExit)):
                raise
            if _is_pytest_skip_exception(error):
                raise
            response_json_error = error
            response_json_detail_message = _format_autotest_invalid_json_message()

        if response_json_error is not None:
            # NOTE(review): unlike the no-hook branch below, this call does not
            # pass source_func=case.func - confirm whether that is intentional.
            raise_autotest_method_output_error(
                api=api,
                func=case.func,
                error=response_json_error,
                detail_message=response_json_detail_message,
                trace_limit=trace_limit,
                truncation_context_lines=truncation_context_lines,
            )

        hook_error: BaseException | None = None
        hook_result: Any = None
        try:
            # Hooks may be sync or async.
            hook_result = hook(response, data, ctx)
            if inspect.isawaitable(hook_result):
                hook_result = await hook_result
        except BaseException as error:  # pragma: no cover - runtime-only branch for crash reporting
            if isinstance(error, (KeyboardInterrupt, SystemExit)):
                raise
            if _is_pytest_skip_exception(error):
                raise
            hook_error = error
        if hook_error is not None:
            raise_autotest_hook_crash(
                api=api,
                hook=hook,
                error=hook_error,
                source_func=hook,
                trace_limit=trace_limit,
                truncation_context_lines=truncation_context_lines,
            )
        # A non-None hook result replaces the data that gets snapshotted.
        if hook_result is not None:
            data = hook_result
    else:
        try:
            if not hasattr(response, "json") or not callable(response.json):
                response_json_error = TypeError(
                    f"Autotest method {case.func.__qualname__} must return an object with json()."
                )
                response_json_detail_message = _format_autotest_wrong_object_message(response)
            else:
                data = response.json()
        except BaseException as error:  # pragma: no cover - runtime-only branch for crash reporting
            if isinstance(error, (KeyboardInterrupt, SystemExit)):
                raise
            if _is_pytest_skip_exception(error):
                raise
            response_json_error = error
            response_json_detail_message = _format_autotest_invalid_json_message()

        if response_json_error is not None:
            raise_autotest_method_output_error(
                api=api,
                func=case.func,
                error=response_json_error,
                source_func=case.func,
                detail_message=response_json_detail_message,
                trace_limit=trace_limit,
                truncation_context_lines=truncation_context_lines,
            )

    # The unwrapped function itself is used as the snapshot name.
    schemashot.assert_json_match(data, case.func)
    if success_recorder is not None:
        success_recorder(case.func.__qualname__)
async def execute_autotest_data_cases(
    *,
    api: object,
    schemashot: Any,
    state: dict[str, Any] | None = None,
    case_status_recorder: Callable[[str, str], None] | None = None,
) -> int:
    """Run every registered data case and match its payload against the snapshot.

    Each provider is called with an AutotestDataContext (its result is
    awaited when awaitable) and the payload is passed to
    ``schemashot.assert_json_match(payload, case.name)``.

    Args:
        api: Root API object exposed to providers through the context.
        schemashot: Object providing ``assert_json_match(data, name)``.
        state: Runtime state dict; a new empty dict is used when None.
        case_status_recorder: Optional ``recorder(label, status)`` callback.

    Returns:
        Number of data cases that were executed.

    Raises:
        TypeError: If *schemashot* lacks a callable ``assert_json_match``.
        BaseException: Any exception from a provider or the snapshot
            assertion is re-raised after its status has been recorded;
            ``KeyboardInterrupt`` and ``SystemExit`` are re-raised without
            recording.
    """
    _validate_schemashot(schemashot)
    runtime_state = state if state is not None else {}
    ctx = AutotestDataContext(api=api, schemashot=schemashot, state=runtime_state)

    # Snapshot the registry once: providers may register or clear data cases
    # while running, and the returned count must describe the cases that were
    # actually executed, not the registry's size afterwards.
    data_cases = list(_DATA_CASES)
    for case in data_cases:
        label = _snapshot_name_repr(case.name)
        try:
            payload = case.provider(ctx)
            if inspect.isawaitable(payload):
                payload = await payload
            schemashot.assert_json_match(payload, case.name)
        except BaseException as error:  # pragma: no cover - runtime-only branch for skip semantics
            if isinstance(error, (KeyboardInterrupt, SystemExit)):
                raise
            if _is_pytest_skip_exception(error):
                _record_case_status(case_status_recorder, label, "skipped")
                raise
            _record_case_status(case_status_recorder, label, "failed")
            raise

        _record_case_status(case_status_recorder, label, "passed")

    return len(data_cases)
async def _execute_autotest_data_cases_with_subtests(
    *,
    api: object,
    schemashot: Any,
    state: dict[str, Any],
    subtests: AutotestSubtests,
    case_status_recorder: Callable[[str, str], None] | None = None,
) -> int:
    """Run every registered data case inside its own ``subtests.test(...)`` context.

    A pytest skip is recorded as ``"skipped"``, any other exception as
    ``"failed"``; both are re-raised into the subtest context.

    Returns:
        Number of data cases processed, whatever their outcome.
    """
    _validate_schemashot(schemashot)
    ctx = AutotestDataContext(api=api, schemashot=schemashot, state=state)

    # Iterate a copy so providers can touch the registry while running.
    data_cases = list(_DATA_CASES)
    for case in data_cases:
        label = _snapshot_name_repr(case.name)
        with subtests.test(**_subtest_label_for_data(case)):
            try:
                payload = case.provider(ctx)
                if inspect.isawaitable(payload):
                    payload = await payload
                schemashot.assert_json_match(payload, case.name)
            except BaseException as error:  # pragma: no cover - runtime-only skip branch
                outcome = "skipped" if _is_pytest_skip_exception(error) else "failed"
                _record_case_status(case_status_recorder, label, outcome)
                raise
            else:
                _record_case_status(case_status_recorder, label, "passed")

    return len(data_cases)
def _is_autotest(func: AutotestFunction) -> bool:
    """Return True when *func* carries a truthy autotest marker attribute."""
    marker = getattr(func, _AUTOTEST_ATTR, False)
    return bool(marker)
691def _as_function(target: Callable[..., Any]) -> AutotestFunction:
692 unbound: Any = target.__func__ if inspect.ismethod(target) else target
693 if not callable(unbound):
694 raise TypeError("Target must be a function or method.")
695 return cast(AutotestFunction, inspect.unwrap(unbound))
698def _as_bound_method(value: Any) -> Callable[..., Any] | None:
699 if inspect.ismethod(value) and value.__self__ is not None:
700 return cast(Callable[..., Any], value)
701 return None
704def _is_resource_object(value: Any) -> bool:
705 if value is None:
706 return False
707 if isinstance(value, _PrimitiveTypes):
708 return False
709 if isinstance(value, (dict, list, tuple, set, frozenset)):
710 return False
711 if inspect.ismodule(value) or inspect.isclass(value) or inspect.isfunction(value):
712 return False
713 if inspect.ismethod(value) or inspect.isbuiltin(value):
714 return False
715 return hasattr(value, "__dict__") or hasattr(value, "__slots__")
718def _order_cases(cases: list[AutotestMethodCase]) -> list[AutotestMethodCase]:
719 if len(cases) < 2:
720 return cases
722 index_by_func: dict[AutotestFunction, list[int]] = {}
723 for index, case in enumerate(cases):
724 index_by_func.setdefault(case.func, []).append(index)
726 edges: dict[int, set[int]] = {index: set() for index in range(len(cases))}
727 indegree = [0] * len(cases)
729 for target_index, case in enumerate(cases):
730 for dependency in case.depends_on:
731 for source_index in index_by_func.get(dependency, []):
732 if source_index == target_index:
733 continue
734 if target_index in edges[source_index]:
735 continue
736 edges[source_index].add(target_index)
737 indegree[target_index] += 1
739 queue: list[tuple[str, int]] = []
740 for index, case in enumerate(cases):
741 if indegree[index] == 0:
742 heapq.heappush(queue, (case.func.__qualname__, index))
744 ordered: list[AutotestMethodCase] = []
745 while queue:
746 _, current_index = heapq.heappop(queue)
747 ordered.append(cases[current_index])
748 for dependent_index in edges[current_index]:
749 indegree[dependent_index] -= 1
750 if indegree[dependent_index] == 0:
751 heapq.heappush(
752 queue,
753 (cases[dependent_index].func.__qualname__, dependent_index),
754 )
756 if len(ordered) == len(cases):
757 return ordered
759 for index, case in enumerate(cases):
760 if indegree[index] > 0:
761 ordered.append(case)
762 return ordered
765def _required_parameters(method: Callable[..., Any]) -> tuple[str, ...]:
766 required_arguments: list[str] = []
767 signature = inspect.signature(method)
768 for parameter in signature.parameters.values():
769 if (
770 parameter.kind
771 in (
772 inspect.Parameter.POSITIONAL_ONLY,
773 inspect.Parameter.POSITIONAL_OR_KEYWORD,
774 inspect.Parameter.KEYWORD_ONLY,
775 )
776 and parameter.default is inspect.Signature.empty
777 ):
778 required_arguments.append(parameter.name)
780 return tuple(required_arguments)
783def _normalize_depends_on(
784 depends_on: Sequence[Callable[..., Any]] | None,
785) -> tuple[AutotestFunction, ...]:
786 if depends_on is None:
787 return ()
789 if not isinstance(depends_on, Sequence):
790 raise TypeError("depends_on must be a sequence of functions or methods.")
792 return tuple(_as_function(target) for target in depends_on)
def _register_case_policy(
    *,
    parent: type[object] | None,
    target_func: AutotestFunction,
    depends_on: Iterable[AutotestFunction] = (),
) -> None:
    """Store the case policy for ``(parent, target_func)``.

    A falsy *depends_on* keeps the dependencies of an already registered
    policy (or none when no policy exists yet).
    """
    registry_key = (parent, target_func)
    existing = _CASE_POLICIES.get(registry_key, AutotestCasePolicy())
    if depends_on:
        resolved = tuple(depends_on)
    else:
        resolved = existing.depends_on
    _CASE_POLICIES[registry_key] = AutotestCasePolicy(depends_on=resolved)
def _get_callback_dependencies(callback: Callable[..., Any]) -> tuple[AutotestFunction, ...]:
    """Return the dependencies recorded on *callback* as unwrapped functions.

    A missing attribute, or one that is not a tuple, yields an empty tuple.
    """
    recorded = getattr(callback, _DEPENDS_ON_ATTR, ())
    if isinstance(recorded, tuple):
        return tuple(_as_function(dependency) for dependency in recorded)
    return ()
815def _merge_dependencies(
816 *sources: Iterable[AutotestFunction],
817) -> tuple[AutotestFunction, ...]:
818 merged: list[AutotestFunction] = []
819 for source in sources:
820 for dependency in source:
821 if dependency not in merged:
822 merged.append(dependency)
823 return tuple(merged)
826def _initialize_runtime_state() -> tuple[
827 dict[str, Any],
828 set[AutotestFunction],
829 set[AutotestFunction],
830]:
831 state: dict[str, Any] = {}
832 completed_funcs: set[AutotestFunction] = set()
833 skipped_funcs: set[AutotestFunction] = set()
834 state["autotest_completed_funcs"] = completed_funcs
835 state["autotest_skipped_funcs"] = skipped_funcs
836 return state, completed_funcs, skipped_funcs
839def _subtest_label_for_case(case: AutotestMethodCase) -> dict[str, str]:
840 return {"msg": case.func.__qualname__}
def _subtest_label_for_data(case: AutotestDataCase) -> dict[str, str]:
    """Return the keyword arguments for ``subtests.test`` that label a data case."""
    label = _snapshot_name_repr(case.name)
    return {"msg": label}
847def _snapshot_name_repr(name: object) -> str:
848 if isinstance(name, (str, int)):
849 return str(name)
850 if callable(name):
851 return getattr(name, "__qualname__", repr(name))
852 if isinstance(name, list):
853 return "[" + ", ".join(_snapshot_name_repr(item) for item in name) + "]"
854 return repr(name)
857def _record_case_status(
858 recorder: Callable[[str, str], None] | None,
859 label: str,
860 status: str,
861) -> None:
862 if recorder is not None:
863 recorder(label, status)
866def _format_dependency_skip_reason(missing: tuple[AutotestFunction, ...]) -> str:
867 names = ", ".join(dep.__qualname__ for dep in missing)
868 return f"Dependency was not executed: {names}"
871def _skip_current_case(reason: str) -> None:
872 try:
873 import pytest
874 except Exception as error: # pragma: no cover - pytest runtime always has pytest installed
875 raise RuntimeError("pytest is required to skip autotest subtest cases.") from error
877 pytest.skip(reason)
880def _is_pytest_skip_exception(error: BaseException) -> bool:
881 try:
882 import pytest
883 except Exception:
884 return False
886 skip_exception = getattr(pytest.skip, "Exception", None)
887 return bool(skip_exception and isinstance(error, skip_exception))
890def _validate_schemashot(schemashot: Any) -> None:
891 if not hasattr(schemashot, "assert_json_match") or not callable(schemashot.assert_json_match):
892 raise TypeError(
893 "schemashot fixture must provide a callable assert_json_match(data, name) method."
894 )
897async def _invoke_method(
898 method: Callable[..., Any],
899 func: AutotestFunction,
900 invocation: AutotestInvocation,
901) -> Any:
902 result = method(*invocation.args, **invocation.kwargs)
903 if not inspect.isawaitable(result):
904 raise TypeError(f"Autotest method {func.__qualname__} must be async.")
905 return await result
async def _resolve_invocation(
    *,
    case: AutotestMethodCase,
    api: object,
    schemashot: Any,
    state: dict[str, Any],
    typecheck_mode: AutotestTypecheckMode,
    trace_limit: int = 3,
    truncation_context_lines: int = 3,
) -> AutotestInvocation:
    """Build and validate the arguments for calling ``case.method``.

    Without a registered params provider the invocation is empty. Otherwise
    the provider is called with an AutotestCallContext (its result is awaited
    when awaitable) and the result is normalized into an AutotestInvocation.
    The invocation is then validated against the method signature.

    Args:
        case: Method case whose arguments are resolved.
        api: Root API object, forwarded to the context and crash reporter.
        schemashot: Snapshot object exposed through the context.
        state: Runtime state dict exposed through the context.
        typecheck_mode: Normalized type-check mode.
        trace_limit: Forwarded to the crash reporter.
        truncation_context_lines: Forwarded to the crash reporter.

    Returns:
        The validated invocation.

    Raises:
        TypeError: If the invocation does not fit the method signature (or
            its types, in ``"strict"`` mode).
        BaseException: ``KeyboardInterrupt``, ``SystemExit`` and pytest skips
            from the provider are re-raised unchanged; other provider
            failures are passed to ``raise_autotest_params_crash``.
    """
    provider = find_autotest_params_provider(case.func, case.parent)
    if provider is None:
        invocation = AutotestInvocation()
    else:
        ctx = AutotestCallContext(
            api=api,
            owner=case.owner,
            parent=case.parent,
            method=case.method,
            func=case.func,
            schemashot=schemashot,
            state=state,
        )
        provider_error: BaseException | None = None
        try:
            raw = provider(ctx)
            if inspect.isawaitable(raw):
                raw = await raw
            invocation = _normalize_invocation(raw, case.func)
        except BaseException as error:  # pragma: no cover - runtime-only branch for crash reporting
            if isinstance(error, (KeyboardInterrupt, SystemExit)):
                raise
            if _is_pytest_skip_exception(error):
                raise
            # Only capture here and report after the handler has exited, the
            # same way the method, output and hook crashes are reported.
            provider_error = error

        if provider_error is not None:
            raise_autotest_params_crash(
                api=api,
                params_provider=provider,
                error=provider_error,
                trace_limit=trace_limit,
                truncation_context_lines=truncation_context_lines,
            )

    _validate_invocation(
        case.method,
        case.func,
        invocation,
        typecheck_mode=typecheck_mode,
    )
    return invocation
def _normalize_invocation(raw: Any, func: AutotestFunction) -> AutotestInvocation:
    """Turn a params provider's return value into a fresh AutotestInvocation.

    Accepted shapes: None (no arguments), an AutotestInvocation (copied), a
    dict (keyword arguments) or a tuple/list (positional arguments).

    Raises:
        TypeError: For any other return type.
    """
    if raw is None:
        return AutotestInvocation()
    if isinstance(raw, AutotestInvocation):
        # Copy so the provider's own containers are not shared with the call.
        positional = tuple(raw.args)
        keyword = dict(raw.kwargs)
        return AutotestInvocation(args=positional, kwargs=keyword)
    if isinstance(raw, dict):
        return AutotestInvocation(kwargs=dict(raw))
    if isinstance(raw, (tuple, list)):
        return AutotestInvocation(args=tuple(raw))

    raise TypeError(
        f"autotest_params provider for {func.__qualname__} must return one of: "
        "None, dict (kwargs), tuple/list (args), AutotestInvocation."
    )
974def _validate_invocation(
975 method: Callable[..., Any],
976 func: AutotestFunction,
977 invocation: AutotestInvocation,
978 *,
979 typecheck_mode: AutotestTypecheckMode = "off",
980) -> None:
981 signature = inspect.signature(method)
982 try:
983 bound_arguments = signature.bind(*invocation.args, **invocation.kwargs)
984 except TypeError as error:
985 raise TypeError(f"Invalid invocation for {func.__qualname__}: {error}") from error
986 _validate_invocation_types(
987 signature=signature,
988 bound_arguments=bound_arguments.arguments,
989 method=method,
990 func=func,
991 typecheck_mode=typecheck_mode,
992 )
995def _normalize_typecheck_mode(mode: AutotestTypecheckMode | str) -> AutotestTypecheckMode:
996 if not isinstance(mode, str):
997 raise TypeError("autotest typecheck mode must be a string.")
998 normalized = mode.strip().lower()
999 if normalized not in _VALID_TYPECHECK_MODES:
1000 expected = ", ".join(sorted(_VALID_TYPECHECK_MODES))
1001 raise ValueError(f"autotest typecheck mode must be one of: {expected}.")
1002 return cast(AutotestTypecheckMode, normalized)
1005def _validate_invocation_types(
1006 *,
1007 signature: inspect.Signature,
1008 bound_arguments: Mapping[str, Any],
1009 method: Callable[..., Any],
1010 func: AutotestFunction,
1011 typecheck_mode: AutotestTypecheckMode,
1012) -> None:
1013 if typecheck_mode == "off":
1014 return
1016 mismatches: list[str] = []
1017 for name, value in bound_arguments.items():
1018 parameter = signature.parameters.get(name)
1019 if parameter is None:
1020 continue
1022 annotation = _resolve_annotation(parameter.annotation, method)
1023 if annotation is inspect.Signature.empty:
1024 continue
1026 if _matches_annotation(value, annotation):
1027 continue
1029 expected = _format_annotation(annotation)
1030 mismatches.append(f"parameter {name!r} expects {expected}, got {type(value).__name__}")
1032 if not mismatches:
1033 return
1035 details = "; ".join(mismatches)
1036 message = f"Invalid invocation types for {func.__qualname__}: {details}."
1037 if typecheck_mode == "strict":
1038 raise TypeError(message)
1039 warnings.warn(message, RuntimeWarning, stacklevel=4)
1042def _resolve_annotation(annotation: Any, method: Callable[..., Any]) -> Any:
1043 if annotation is inspect.Signature.empty:
1044 return annotation
1046 if not isinstance(annotation, str):
1047 return annotation
1049 globals_dict = getattr(method, "__globals__", {})
1050 if not isinstance(globals_dict, dict):
1051 return inspect.Signature.empty
1053 try:
1054 return eval(annotation, globals_dict, {})
1055 except Exception:
1056 return inspect.Signature.empty
1059def _matches_annotation(value: Any, annotation: Any) -> bool:
1060 if annotation in (Any, object):
1061 return True
1062 if annotation in (None, type(None)):
1063 return value is None
1065 supertype = getattr(annotation, "__supertype__", None)
1066 if supertype is not None:
1067 return _matches_annotation(value, supertype)
1069 origin = get_origin(annotation)
1070 if origin is None:
1071 if isinstance(annotation, type):
1072 return isinstance(value, annotation)
1073 return True
1075 if origin in (types.UnionType, Union):
1076 return any(_matches_annotation(value, arg) for arg in get_args(annotation))
1078 if origin is Annotated:
1079 args = get_args(annotation)
1080 if not args:
1081 return True
1082 return _matches_annotation(value, args[0])
1084 if origin is Literal:
1085 return any(value == option for option in get_args(annotation))
1087 if origin is list:
1088 return _matches_iterable(value, annotation, list)
1089 if origin is set:
1090 return _matches_iterable(value, annotation, set)
1091 if origin is frozenset:
1092 return _matches_iterable(value, annotation, frozenset)
1093 if origin is tuple:
1094 return _matches_tuple(value, annotation)
1095 if origin is dict:
1096 return _matches_mapping(value, annotation)
1098 if isinstance(origin, type):
1099 return isinstance(value, origin)
1101 return True
1104def _matches_iterable(value: Any, annotation: Any, container_type: type[object]) -> bool:
1105 if not isinstance(value, container_type):
1106 return False
1107 args = get_args(annotation)
1108 if not args:
1109 return True
1110 item_type = args[0]
1111 iterable_value = cast(Iterable[Any], value)
1112 return all(_matches_annotation(item, item_type) for item in iterable_value)
1115def _matches_mapping(value: Any, annotation: Any) -> bool:
1116 if not isinstance(value, dict):
1117 return False
1118 key_type, value_type = (Any, Any)
1119 args = get_args(annotation)
1120 if len(args) == 2:
1121 key_type, value_type = args
1123 for key, item in value.items():
1124 if not _matches_annotation(key, key_type):
1125 return False
1126 if not _matches_annotation(item, value_type):
1127 return False
1128 return True
1131def _matches_tuple(value: Any, annotation: Any) -> bool:
1132 if not isinstance(value, tuple):
1133 return False
1134 args = get_args(annotation)
1135 if not args:
1136 return True
1137 if len(args) == 2 and args[1] is Ellipsis:
1138 return all(_matches_annotation(item, args[0]) for item in value)
1139 if len(args) != len(value):
1140 return False
1141 return all(_matches_annotation(item, item_type) for item, item_type in zip(value, args))
1144def _format_annotation(annotation: Any) -> str:
1145 try:
1146 return inspect.formatannotation(annotation)
1147 except Exception:
1148 return str(annotation)
# Public API of this module; names not listed here are implementation details.
__all__ = [
    "AutotestCallContext",
    "AutotestContext",
    "AutotestSubtests",
    "AutotestDataContext",
    "AutotestDataCase",
    "AutotestInvocation",
    "AutotestMethodCase",
    "AutotestCasePolicy",
    "AutotestTypecheckMode",
    "autotest",
    "autotest_depends_on",
    "autotest_data",
    "autotest_hook",
    "autotest_policy",
    "autotest_params",
    "clear_autotest_hooks",
    "discover_autotest_methods",
    "execute_autotest_case",
    "execute_autotest_data_cases",
    "execute_autotests",
    "execute_autotests_with_subtests",
    "find_autotest_hook",
    "find_autotest_hook_dependencies",
    "find_autotest_policy",
    "find_autotest_params_dependencies",
    "find_autotest_params_provider",
]