Coverage for human_requests/autotest.py: 80%

586 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-28 00:39 +0000

1from __future__ import annotations 

2 

3import heapq 

4import inspect 

5import types 

6import warnings 

7from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence 

8from dataclasses import dataclass, field 

9from typing import ( 

10 Annotated, 

11 Any, 

12 ContextManager, 

13 Literal, 

14 Protocol, 

15 TypeVar, 

16 Union, 

17 cast, 

18 get_args, 

19 get_origin, 

20) 

21 

22from .autotest_report import ( 

23 raise_autotest_hook_crash, 

24 raise_autotest_method_crash, 

25 raise_autotest_method_output_error, 

26 raise_autotest_params_crash, 

27) 

28 

29AutotestFunction = Callable[..., Any] 

30AutotestHook = Callable[[Any, Any, "AutotestContext"], Any] 

31AutotestParamProvider = Callable[["AutotestCallContext"], Any] 

32AutotestDataProvider = Callable[["AutotestDataContext"], Any] 

33AutotestTypecheckMode = Literal["off", "warn", "strict"] 

34SnapshotName = str | int | Callable[..., Any] | list[str | int | Callable[..., Any]] 

35HookKey = tuple[type[object] | None, AutotestFunction] 

36DependencyMarker = Callable[..., Any] 

37 

38_AUTOTEST_ATTR = "__autotest__" 

39_DEPENDS_ON_ATTR = "__autotest_depends_on__" 

40_HOOKS: dict[HookKey, AutotestHook] = {} 

41_PARAM_PROVIDERS: dict[HookKey, AutotestParamProvider] = {} 

42_CASE_POLICIES: dict[HookKey, "AutotestCasePolicy"] = {} 

43_DATA_CASES: list["AutotestDataCase"] = [] 

44_VALID_TYPECHECK_MODES: frozenset[str] = frozenset({"off", "warn", "strict"}) 

45 

46_PrimitiveTypes = (str, bytes, bytearray, bool, int, float, complex, range, memoryview) 

47 

48F = TypeVar("F", bound=Callable[..., Any]) 

49 

50 

51def _format_autotest_object_type(value: Any) -> str: 

52 value_type = type(value) 

53 module = value_type.__module__ 

54 qualname = value_type.__qualname__ 

55 if module == "builtins": 

56 return qualname 

57 return f"{module}.{qualname}" 

58 

59 

def _format_autotest_wrong_object_message(response: Any) -> str:
    """Explain that an autotest method returned something other than an Output object.

    The message names the type actually received, as rendered by
    ``_format_autotest_object_type``.
    """
    received = _format_autotest_object_type(response)
    return (
        "The method completed successfully, but autotest expected the returned "
        "human_requests.abstraction.Output object. Instead it received "
        f"{received}."
    )

66 

67 

68def _format_autotest_invalid_json_message() -> str: 

69 return ( 

70 "The method completed successfully, but the returned " 

71 "human_requests.abstraction.Output object contains invalid JSON." 

72 ) 

73 

74 

@dataclass(frozen=True)
class AutotestContext:
    """Immutable context passed to an autotest hook as its third argument."""

    api: object  # root API object the autotests were discovered on
    owner: object  # object the autotest method is bound to
    parent: object | None  # object holding ``owner`` (None for the API root)
    method: Callable[..., Awaitable[Any]]  # bound method under test
    func: AutotestFunction  # unwrapped function behind ``method``
    schemashot: Any  # snapshot fixture exposing assert_json_match()
    state: dict[str, Any]  # mutable state shared by every case of one run

84 

85 

@dataclass(frozen=True)
class AutotestMethodCase:
    """One discovered ``@autotest`` method, ready to be executed."""

    owner: object  # object the method is bound to
    parent: object | None  # object holding ``owner`` (None for the API root)
    method: Callable[..., Awaitable[Any]]  # bound method to call
    func: AutotestFunction  # unwrapped function; used as the registry/snapshot key
    required_parameters: tuple[str, ...]  # parameter names without defaults
    depends_on: tuple[AutotestFunction, ...]  # cases that must pass before this one

94 

95 

@dataclass(frozen=True)
class AutotestCallContext:
    """Immutable context passed to an ``@autotest_params`` provider."""

    api: object  # root API object the autotests were discovered on
    owner: object  # object the target method is bound to
    parent: object | None  # object holding ``owner`` (None for the API root)
    method: Callable[..., Awaitable[Any]]  # bound method that will be called
    func: AutotestFunction  # unwrapped function behind ``method``
    schemashot: Any  # snapshot fixture exposing assert_json_match()
    state: dict[str, Any]  # mutable state shared by every case of one run

105 

106 

@dataclass(frozen=True)
class AutotestDataContext:
    """Immutable context passed to an ``@autotest_data`` provider."""

    api: object  # root API object the autotests run against
    schemashot: Any  # snapshot fixture exposing assert_json_match()
    state: dict[str, Any]  # mutable state shared by every case of one run

112 

113 

@dataclass(frozen=True)
class AutotestInvocation:
    """Positional and keyword arguments for calling an autotest method."""

    args: tuple[Any, ...] = ()  # positional arguments
    kwargs: dict[str, Any] = field(default_factory=dict)  # keyword arguments (fresh dict per instance)

118 

119 

@dataclass(frozen=True)
class AutotestDataCase:
    """A registered ``@autotest_data`` provider and the snapshot name it is checked under."""

    name: SnapshotName  # passed as the name to schemashot.assert_json_match()
    provider: AutotestDataProvider  # called with an AutotestDataContext to build the payload

124 

125 

@dataclass(frozen=True)
class AutotestCasePolicy:
    """Execution policy registered for an autotest target."""

    depends_on: tuple[AutotestFunction, ...] = ()  # targets that must pass first

129 

130 

class AutotestSubtests(Protocol):
    """Structural type of the ``subtests`` object used by execute_autotests_with_subtests."""

    def test(self, msg: str | None = None, **kwargs: Any) -> ContextManager[None]:
        """Return a context manager that scopes one subtest labelled ``msg``."""
        ...

133 

134 

def autotest(func: F) -> F:
    """Mark ``func`` as an autotest method and return it unchanged.

    The marker attribute is set on ``func`` and also on the innermost function
    reached by following ``__wrapped__`` (as set by ``functools.wraps``).
    Discovery unwraps each method with ``inspect.unwrap`` before checking the
    marker. Without the second mark, ``@autotest`` stacked *above* a wrapping
    decorator would tag only the wrapper, and the method would be silently
    skipped by discovery.
    """
    setattr(func, _AUTOTEST_ATTR, True)
    try:
        innermost = inspect.unwrap(func)
    except ValueError:
        # A cyclic ``__wrapped__`` chain has no innermost function to mark.
        return func
    if innermost is not func:
        try:
            setattr(innermost, _AUTOTEST_ATTR, True)
        except AttributeError:
            # Objects that reject new attributes (e.g. builtins) cannot be
            # marked; fall back to marking ``func`` only.
            pass
    return func

138 

139 

def autotest_depends_on(
    target: Callable[..., Any],
) -> Callable[[DependencyMarker], DependencyMarker]:
    """Declare that a hook or params provider requires ``target`` to pass first.

    The unwrapped ``target`` is appended to the callback's dependency tuple
    (stored under ``__autotest_depends_on__``) unless it is already there. The
    callback itself is returned unchanged.
    """
    required = _as_function(target)

    def mark(callback: DependencyMarker) -> DependencyMarker:
        current = _get_callback_dependencies(callback)
        if required not in current:
            setattr(callback, _DEPENDS_ON_ATTR, current + (required,))
        return callback

    return mark

153 

154 

def autotest_hook(
    *,
    target: Callable[..., Any],
    parent: type[object] | None = None,
) -> Callable[[AutotestHook], AutotestHook]:
    """Register a hook called with ``(response, data, ctx)`` after ``target`` runs.

    ``parent`` restricts the hook to methods whose holder object is an instance
    of exactly that class. ``None`` registers the fallback hook used when no
    parent-specific hook exists.

    Raises:
        TypeError: if ``parent`` is neither a class nor None, or ``target`` is
            not callable.
    """
    if not (parent is None or inspect.isclass(parent)):
        raise TypeError("autotest_hook parent must be a class or None.")

    registry_key = (parent, _as_function(target))

    def register(hook: AutotestHook) -> AutotestHook:
        _HOOKS[registry_key] = hook
        return hook

    return register

170 

171 

def autotest_params(
    *,
    target: Callable[..., Any],
    parent: type[object] | None = None,
    depends_on: Sequence[Callable[..., Any]] | None = None,
) -> Callable[[AutotestParamProvider], AutotestParamProvider]:
    """Register a provider that supplies call arguments for ``target``.

    The provider receives an ``AutotestCallContext``. It may return None, a dict
    of keyword arguments, a tuple/list of positional arguments, or an
    ``AutotestInvocation``. A non-empty ``depends_on`` is also stored as the
    case policy for the same ``(parent, target)`` key when the provider is
    registered.

    Raises:
        TypeError: if ``parent`` is neither a class nor None, or ``target`` /
            ``depends_on`` are not valid callables.
    """
    if not (parent is None or inspect.isclass(parent)):
        raise TypeError("autotest_params parent must be a class or None.")

    target_func = _as_function(target)
    dependencies = _normalize_depends_on(depends_on)
    registry_key = (parent, target_func)

    def register(provider: AutotestParamProvider) -> AutotestParamProvider:
        _PARAM_PROVIDERS[registry_key] = provider
        if dependencies:
            _register_case_policy(
                parent=parent,
                target_func=target_func,
                depends_on=dependencies,
            )
        return provider

    return register

195 

196 

def autotest_policy(
    *,
    target: Callable[..., Any],
    parent: type[object] | None = None,
    depends_on: Sequence[Callable[..., Any]] | None = None,
) -> Callable[[F], F]:
    """Register the dependency policy for ``target`` as soon as this is called.

    The returned decorator is a pass-through that hands back whatever it
    decorates. An empty or missing ``depends_on`` leaves any previously
    registered dependencies for the same key in place.

    Raises:
        TypeError: if ``parent`` is neither a class nor None, or ``target`` /
            ``depends_on`` are not valid callables.
    """
    if not (parent is None or inspect.isclass(parent)):
        raise TypeError("autotest_policy parent must be a class or None.")

    # Registration happens eagerly, at decorator-factory time.
    _register_case_policy(
        parent=parent,
        target_func=_as_function(target),
        depends_on=_normalize_depends_on(depends_on),
    )

    def passthrough(marker: F) -> F:
        return marker

    return passthrough

218 

219 

def autotest_data(
    *,
    name: SnapshotName,
) -> Callable[[AutotestDataProvider], AutotestDataProvider]:
    """Register a standalone data provider whose payload is snapshotted under ``name``.

    Providers are appended in registration order and run after the method
    cases. The decorated provider is returned unchanged.
    """

    def register(provider: AutotestDataProvider) -> AutotestDataProvider:
        entry = AutotestDataCase(name=name, provider=provider)
        _DATA_CASES.append(entry)
        return provider

    return register

229 

230 

def clear_autotest_hooks() -> None:
    """Empty every autotest registry: hooks, params providers, policies and data cases."""
    for registry in (_HOOKS, _PARAM_PROVIDERS, _CASE_POLICIES, _DATA_CASES):
        registry.clear()

236 

237 

def find_autotest_hook(
    func: Callable[..., Any],
    parent_object: object | None,
) -> AutotestHook | None:
    """Return the hook registered for ``func``, or None.

    A hook registered for exactly ``parent_object.__class__`` (no MRO lookup)
    takes precedence over the generic hook registered with ``parent=None``.
    """
    target_func = _as_function(func)
    candidate_parents = (None,) if parent_object is None else (parent_object.__class__, None)
    for parent_class in candidate_parents:
        hook = _HOOKS.get((parent_class, target_func))
        if hook is not None:
            return hook
    return None

251 

252 

def find_autotest_params_provider(
    func: Callable[..., Any],
    parent_object: object | None,
) -> AutotestParamProvider | None:
    """Return the params provider registered for ``func``, or None.

    A provider registered for exactly ``parent_object.__class__`` takes
    precedence over the generic one registered with ``parent=None``.
    """
    target_func = _as_function(func)
    candidate_parents = (None,) if parent_object is None else (parent_object.__class__, None)
    for parent_class in candidate_parents:
        provider = _PARAM_PROVIDERS.get((parent_class, target_func))
        if provider is not None:
            return provider
    return None

266 

267 

def find_autotest_policy(
    func: Callable[..., Any],
    parent_object: object | None,
) -> AutotestCasePolicy:
    """Return the case policy for ``func``; an empty policy when none is registered.

    A policy registered for exactly ``parent_object.__class__`` takes
    precedence over the generic one registered with ``parent=None``.
    """
    target_func = _as_function(func)
    if parent_object is not None:
        scoped_policy = _CASE_POLICIES.get((parent_object.__class__, target_func))
        if scoped_policy is not None:
            return scoped_policy
    return _CASE_POLICIES.get((None, target_func), AutotestCasePolicy())

281 

282 

def find_autotest_hook_dependencies(
    func: Callable[..., Any],
    parent_object: object | None,
) -> tuple[AutotestFunction, ...]:
    """Return dependencies declared with ``@autotest_depends_on`` on the hook for ``func``.

    Returns an empty tuple when no hook applies.
    """
    hook = find_autotest_hook(func, parent_object)
    return () if hook is None else _get_callback_dependencies(hook)

291 

292 

def find_autotest_params_dependencies(
    func: Callable[..., Any],
    parent_object: object | None,
) -> tuple[AutotestFunction, ...]:
    """Return dependencies declared with ``@autotest_depends_on`` on the params provider for ``func``.

    Returns an empty tuple when no provider applies.
    """
    provider = find_autotest_params_provider(func, parent_object)
    return () if provider is None else _get_callback_dependencies(provider)

301 

302 

def discover_autotest_methods(api: object) -> list[AutotestMethodCase]:
    """Collect every ``@autotest`` method reachable from ``api``.

    Public attributes (no leading underscore) of each object are visited in
    sorted order, and attributes whose lookup raises are ignored. Bound methods
    carrying the autotest marker become cases. Other resource-like objects are
    searched recursively, each object at most once (tracked by ``id``). The
    collected cases are returned in dependency order via ``_order_cases``.

    Raises:
        TypeError: if an autotest method has required parameters but no
            ``@autotest_params`` provider is registered for it.
    """
    found: list[AutotestMethodCase] = []
    seen_ids: set[int] = set()

    def build_case(
        owner: object,
        parent: object | None,
        bound_method: Callable[..., Any],
        func: AutotestFunction,
    ) -> AutotestMethodCase:
        # Lookups are keyed by the *holder* (parent) of the owner object.
        required = _required_parameters(bound_method)
        provider = find_autotest_params_provider(func, parent)
        policy = find_autotest_policy(func, parent)
        depends_on = _merge_dependencies(
            policy.depends_on,
            find_autotest_params_dependencies(func, parent),
            find_autotest_hook_dependencies(func, parent),
        )
        if required and provider is None:
            joined = ", ".join(required)
            raise TypeError(
                f"Autotest method {func.__qualname__} requires arguments ({joined}). "
                "Register @autotest_params(target=...) for this method."
            )
        return AutotestMethodCase(
            owner=owner,
            parent=parent,
            method=cast(Callable[..., Awaitable[Any]], bound_method),
            func=func,
            required_parameters=required,
            depends_on=depends_on,
        )

    def visit(owner: object, parent: object | None) -> None:
        if id(owner) in seen_ids:
            return
        seen_ids.add(id(owner))

        for attr_name in sorted(dir(owner)):
            if attr_name.startswith("_"):
                continue
            try:
                value = getattr(owner, attr_name)
            except Exception:
                continue

            bound_method = _as_bound_method(value)
            if bound_method is None:
                if _is_resource_object(value):
                    visit(value, owner)
                continue

            func = _as_function(bound_method)
            if _is_autotest(func):
                found.append(build_case(owner, parent, bound_method, func))

    visit(api, None)
    return _order_cases(found)

357 

358 

async def execute_autotests(
    api: object,
    schemashot: Any,
    *,
    typecheck_mode: AutotestTypecheckMode | str = "off",
    trace_limit: int = 3,
    truncation_context_lines: int = 3,
    case_status_recorder: Callable[[str, str], None] | None = None,
    success_recorder: Callable[[str], None] | None = None,
) -> int:
    """Run every discovered autotest method, then every registered data case.

    Method cases run in dependency order. Cases are recorded as follows:
    - A case whose dependencies have not all passed is recorded as
      ``"skipped"`` without being run.
    - A case that raises a pytest skip is also recorded as ``"skipped"``.
    - Any other exception is recorded as ``"failed"`` and propagates,
      stopping the run.

    Returns:
        The number of method cases that passed plus the count returned by
        ``execute_autotest_data_cases``.
    """
    _validate_schemashot(schemashot)
    mode = _normalize_typecheck_mode(typecheck_mode)
    state, completed_funcs, skipped_funcs = _initialize_runtime_state()

    passed_count = 0
    for case in discover_autotest_methods(api):
        label = case.func.__qualname__
        dependencies_met = all(dep in completed_funcs for dep in case.depends_on)
        if not dependencies_met:
            skipped_funcs.add(case.func)
            _record_case_status(case_status_recorder, label, "skipped")
            continue

        try:
            await execute_autotest_case(
                case=case,
                api=api,
                schemashot=schemashot,
                state=state,
                typecheck_mode=mode,
                trace_limit=trace_limit,
                truncation_context_lines=truncation_context_lines,
                success_recorder=success_recorder,
            )
        except BaseException as error:  # pragma: no cover - runtime-only branch for skip semantics
            if not _is_pytest_skip_exception(error):
                _record_case_status(case_status_recorder, label, "failed")
                raise
            skipped_funcs.add(case.func)
            _record_case_status(case_status_recorder, label, "skipped")
            continue

        completed_funcs.add(case.func)
        passed_count += 1
        _record_case_status(case_status_recorder, label, "passed")

    data_count = await execute_autotest_data_cases(
        api=api,
        schemashot=schemashot,
        state=state,
        case_status_recorder=case_status_recorder,
    )
    return passed_count + data_count

414 

415 

async def execute_autotests_with_subtests(
    api: object,
    schemashot: Any,
    *,
    subtests: AutotestSubtests,
    typecheck_mode: AutotestTypecheckMode | str = "off",
    trace_limit: int = 3,
    truncation_context_lines: int = 3,
    case_status_recorder: Callable[[str, str], None] | None = None,
    success_recorder: Callable[[str], None] | None = None,
) -> int:
    """Run every autotest method and data case, each inside its own subtest.

    Within each ``subtests.test`` block:
    - A case with unmet dependencies is recorded as ``"skipped"`` and
      skipped with a reason naming the missing dependencies.
    - Exceptions raised by a case are recorded (``"skipped"`` for pytest
      skips, otherwise ``"failed"``) and re-raised inside the subtest.

    Only cases that pass are added to the completed set that later
    dependency checks read.

    Returns:
        The number of method cases processed plus the number of data cases
        processed.
    """
    _validate_schemashot(schemashot)
    mode = _normalize_typecheck_mode(typecheck_mode)
    state, completed_funcs, skipped_funcs = _initialize_runtime_state()

    cases = discover_autotest_methods(api)
    for case in cases:
        label = case.func.__qualname__
        passed = False
        with subtests.test(**_subtest_label_for_case(case)):
            missing = tuple(dep for dep in case.depends_on if dep not in completed_funcs)
            if missing:
                skipped_funcs.add(case.func)
                _record_case_status(case_status_recorder, label, "skipped")
                _skip_current_case(_format_dependency_skip_reason(missing))

            try:
                await execute_autotest_case(
                    case=case,
                    api=api,
                    schemashot=schemashot,
                    state=state,
                    typecheck_mode=mode,
                    trace_limit=trace_limit,
                    truncation_context_lines=truncation_context_lines,
                    success_recorder=success_recorder,
                )
            except BaseException as error:  # pragma: no cover - runtime-only skip branch
                outcome = "skipped" if _is_pytest_skip_exception(error) else "failed"
                if outcome == "skipped":
                    skipped_funcs.add(case.func)
                _record_case_status(case_status_recorder, label, outcome)
                raise

            passed = True
            _record_case_status(case_status_recorder, label, "passed")

        # Only a case that actually passed may unlock its dependents.
        if passed:
            completed_funcs.add(case.func)

    data_processed = await _execute_autotest_data_cases_with_subtests(
        api=api,
        schemashot=schemashot,
        state=state,
        subtests=subtests,
        case_status_recorder=case_status_recorder,
    )
    return len(cases) + data_processed

480 

481 

async def execute_autotest_case(
    *,
    case: AutotestMethodCase,
    api: object,
    schemashot: Any,
    state: dict[str, Any] | None = None,
    typecheck_mode: AutotestTypecheckMode | str = "off",
    trace_limit: int = 3,
    truncation_context_lines: int = 3,
    success_recorder: Callable[[str], None] | None = None,
) -> None:
    """Run a single autotest method case and snapshot its JSON output.

    The steps are:
    1. Resolve the call arguments, through a registered ``@autotest_params``
       provider when one exists.
    2. Await the method and decode ``response.json()``.
    3. If an ``@autotest_hook`` is registered, call it as
       ``hook(response, data, ctx)``; a non-None result replaces ``data``.
    4. Call ``schemashot.assert_json_match(data, case.func)``. On success,
       ``success_recorder`` receives the method's qualified name.

    Args:
        case: The discovered method case to run.
        api: Root API object, forwarded to the contexts and crash reports.
        schemashot: Snapshot fixture exposing ``assert_json_match``.
        state: Shared run state; a fresh dict is used when None.
        typecheck_mode: ``"off"``, ``"warn"`` or ``"strict"`` (case-insensitive).
        trace_limit: Forwarded to the crash-report helpers.
        truncation_context_lines: Forwarded to the crash-report helpers.
        success_recorder: Optional callback invoked after a passing snapshot.

    Raises:
        TypeError / ValueError: for an invalid ``schemashot`` or typecheck mode.
        Crashes of the params provider, the method, the output decoding or the
        hook are reported through the ``raise_autotest_*`` helpers.
        ``KeyboardInterrupt``, ``SystemExit`` and pytest skip exceptions
        propagate unchanged.
    """
    _validate_schemashot(schemashot)
    resolved_typecheck_mode = _normalize_typecheck_mode(typecheck_mode)
    runtime_state = state if state is not None else {}
    invocation = await _resolve_invocation(
        case=case,
        api=api,
        schemashot=schemashot,
        state=runtime_state,
        typecheck_mode=resolved_typecheck_mode,
        trace_limit=trace_limit,
        truncation_context_lines=truncation_context_lines,
    )

    crash_error: BaseException | None = None
    try:
        response = await _invoke_method(case.method, case.func, invocation)
    except BaseException as error:  # pragma: no cover - runtime-only branch for crash reporting
        if isinstance(error, (KeyboardInterrupt, SystemExit)):
            raise
        if _is_pytest_skip_exception(error):
            raise
        crash_error = error

    if crash_error is not None:
        # Reported after the except block has exited, like the other crash paths.
        raise_autotest_method_crash(
            api=api,
            func=case.func,
            error=crash_error,
            source_func=case.func,
            trace_limit=trace_limit,
            truncation_context_lines=truncation_context_lines,
        )

    hook = find_autotest_hook(case.func, case.parent)
    # Output decoding is identical with or without a hook, so both paths use a
    # single helper and report output errors the same way (incl. source_func).
    data = _load_autotest_response_json(
        response=response,
        case=case,
        api=api,
        trace_limit=trace_limit,
        truncation_context_lines=truncation_context_lines,
    )

    if hook is not None:
        ctx = AutotestContext(
            api=api,
            owner=case.owner,
            parent=case.parent,
            method=case.method,
            func=case.func,
            schemashot=schemashot,
            state=runtime_state,
        )
        data = await _apply_autotest_hook(
            hook=hook,
            response=response,
            data=data,
            ctx=ctx,
            api=api,
            trace_limit=trace_limit,
            truncation_context_lines=truncation_context_lines,
        )

    schemashot.assert_json_match(data, case.func)
    if success_recorder is not None:
        success_recorder(case.func.__qualname__)


def _load_autotest_response_json(
    *,
    response: Any,
    case: AutotestMethodCase,
    api: object,
    trace_limit: int,
    truncation_context_lines: int,
) -> Any:
    """Return ``response.json()`` for the value an autotest method returned.

    A failure is reported via ``raise_autotest_method_output_error`` when
    either:
    - the return value has no callable ``json`` attribute, or
    - the ``json()`` call raises.

    ``KeyboardInterrupt``, ``SystemExit`` and pytest skips propagate unchanged.
    """
    output_error: BaseException
    detail_message: str
    try:
        if hasattr(response, "json") and callable(response.json):
            return response.json()
        output_error = TypeError(
            f"Autotest method {case.func.__qualname__} must return an object with json()."
        )
        detail_message = _format_autotest_wrong_object_message(response)
    except BaseException as error:  # pragma: no cover - runtime-only branch for crash reporting
        if isinstance(error, (KeyboardInterrupt, SystemExit)):
            raise
        if _is_pytest_skip_exception(error):
            raise
        output_error = error
        detail_message = _format_autotest_invalid_json_message()

    raise_autotest_method_output_error(
        api=api,
        func=case.func,
        error=output_error,
        source_func=case.func,
        detail_message=detail_message,
        trace_limit=trace_limit,
        truncation_context_lines=truncation_context_lines,
    )


async def _apply_autotest_hook(
    *,
    hook: AutotestHook,
    response: Any,
    data: Any,
    ctx: AutotestContext,
    api: object,
    trace_limit: int,
    truncation_context_lines: int,
) -> Any:
    """Run ``hook(response, data, ctx)`` and return the data to snapshot.

    An awaitable result is awaited. A non-None result replaces ``data``;
    otherwise ``data`` is returned unchanged. Hook crashes are reported via
    ``raise_autotest_hook_crash``. ``KeyboardInterrupt``, ``SystemExit`` and
    pytest skips propagate unchanged.
    """
    hook_error: BaseException | None = None
    hook_result: Any = None
    try:
        hook_result = hook(response, data, ctx)
        if inspect.isawaitable(hook_result):
            hook_result = await hook_result
    except BaseException as error:  # pragma: no cover - runtime-only branch for crash reporting
        if isinstance(error, (KeyboardInterrupt, SystemExit)):
            raise
        if _is_pytest_skip_exception(error):
            raise
        hook_error = error

    if hook_error is not None:
        raise_autotest_hook_crash(
            api=api,
            hook=hook,
            error=hook_error,
            source_func=hook,
            trace_limit=trace_limit,
            truncation_context_lines=truncation_context_lines,
        )
    return data if hook_result is None else hook_result

620 

621 

async def execute_autotest_data_cases(
    *,
    api: object,
    schemashot: Any,
    state: dict[str, Any] | None = None,
    case_status_recorder: Callable[[str, str], None] | None = None,
) -> int:
    """Run every registered ``@autotest_data`` provider and snapshot its payload.

    Each provider is called with an ``AutotestDataContext``; an awaitable
    result is awaited. The payload is then checked with
    ``schemashot.assert_json_match(payload, case.name)``.

    Exceptions are handled as follows:
    - A pytest skip is recorded as ``"skipped"`` and re-raised.
    - Any other exception is recorded as ``"failed"`` and re-raised.
    - ``KeyboardInterrupt`` and ``SystemExit`` propagate without being
      recorded.

    Returns:
        The number of data cases that were run.
    """
    _validate_schemashot(schemashot)
    runtime_state = state if state is not None else {}
    ctx = AutotestDataContext(api=api, schemashot=schemashot, state=runtime_state)

    # Iterate AND count the same snapshot: providers that register more data
    # cases while running must not change this run or its reported count.
    data_cases = list(_DATA_CASES)
    for case in data_cases:
        label = _snapshot_name_repr(case.name)
        try:
            payload = case.provider(ctx)
            if inspect.isawaitable(payload):
                payload = await payload
            schemashot.assert_json_match(payload, case.name)
        except BaseException as error:  # pragma: no cover - runtime-only branch for skip semantics
            if isinstance(error, (KeyboardInterrupt, SystemExit)):
                raise
            if _is_pytest_skip_exception(error):
                _record_case_status(case_status_recorder, label, "skipped")
                raise
            _record_case_status(case_status_recorder, label, "failed")
            raise

        _record_case_status(case_status_recorder, label, "passed")

    return len(data_cases)

652 

653 

async def _execute_autotest_data_cases_with_subtests(
    *,
    api: object,
    schemashot: Any,
    state: dict[str, Any],
    subtests: AutotestSubtests,
    case_status_recorder: Callable[[str, str], None] | None = None,
) -> int:
    """Run each registered data case inside its own ``subtests.test`` block.

    Outcomes are recorded as ``"passed"``, ``"skipped"`` (pytest skip) or
    ``"failed"``. Every exception is re-raised inside the subtest block.

    Returns:
        The number of data cases processed.
    """
    _validate_schemashot(schemashot)
    ctx = AutotestDataContext(api=api, schemashot=schemashot, state=state)

    snapshot = list(_DATA_CASES)
    for data_case in snapshot:
        label = _snapshot_name_repr(data_case.name)
        with subtests.test(**_subtest_label_for_data(data_case)):
            try:
                payload = data_case.provider(ctx)
                if inspect.isawaitable(payload):
                    payload = await payload
                schemashot.assert_json_match(payload, data_case.name)
            except BaseException as error:  # pragma: no cover - runtime-only skip branch
                outcome = "skipped" if _is_pytest_skip_exception(error) else "failed"
                _record_case_status(case_status_recorder, label, outcome)
                raise

            _record_case_status(case_status_recorder, label, "passed")

    return len(snapshot)

685 

686 

def _is_autotest(func: AutotestFunction) -> bool:
    """Return True when ``func`` carries a truthy ``@autotest`` marker attribute."""
    marker = getattr(func, _AUTOTEST_ATTR, False)
    return bool(marker)

689 

690 

691def _as_function(target: Callable[..., Any]) -> AutotestFunction: 

692 unbound: Any = target.__func__ if inspect.ismethod(target) else target 

693 if not callable(unbound): 

694 raise TypeError("Target must be a function or method.") 

695 return cast(AutotestFunction, inspect.unwrap(unbound)) 

696 

697 

698def _as_bound_method(value: Any) -> Callable[..., Any] | None: 

699 if inspect.ismethod(value) and value.__self__ is not None: 

700 return cast(Callable[..., Any], value) 

701 return None 

702 

703 

def _is_resource_object(value: Any) -> bool:
    """Return True if discovery should recurse into ``value``.

    The following are rejected:
    - None and primitive values
    - builtin containers (dict, list, tuple, set, frozenset)
    - modules, classes, functions, methods and builtins

    Any other object is accepted when it has a ``__dict__`` or ``__slots__``.
    """
    excluded = (
        value is None
        or isinstance(value, _PrimitiveTypes)
        or isinstance(value, (dict, list, tuple, set, frozenset))
        or inspect.ismodule(value)
        or inspect.isclass(value)
        or inspect.isfunction(value)
        or inspect.ismethod(value)
        or inspect.isbuiltin(value)
    )
    if excluded:
        return False
    return hasattr(value, "__dict__") or hasattr(value, "__slots__")

716 

717 

718def _order_cases(cases: list[AutotestMethodCase]) -> list[AutotestMethodCase]: 

719 if len(cases) < 2: 

720 return cases 

721 

722 index_by_func: dict[AutotestFunction, list[int]] = {} 

723 for index, case in enumerate(cases): 

724 index_by_func.setdefault(case.func, []).append(index) 

725 

726 edges: dict[int, set[int]] = {index: set() for index in range(len(cases))} 

727 indegree = [0] * len(cases) 

728 

729 for target_index, case in enumerate(cases): 

730 for dependency in case.depends_on: 

731 for source_index in index_by_func.get(dependency, []): 

732 if source_index == target_index: 

733 continue 

734 if target_index in edges[source_index]: 

735 continue 

736 edges[source_index].add(target_index) 

737 indegree[target_index] += 1 

738 

739 queue: list[tuple[str, int]] = [] 

740 for index, case in enumerate(cases): 

741 if indegree[index] == 0: 

742 heapq.heappush(queue, (case.func.__qualname__, index)) 

743 

744 ordered: list[AutotestMethodCase] = [] 

745 while queue: 

746 _, current_index = heapq.heappop(queue) 

747 ordered.append(cases[current_index]) 

748 for dependent_index in edges[current_index]: 

749 indegree[dependent_index] -= 1 

750 if indegree[dependent_index] == 0: 

751 heapq.heappush( 

752 queue, 

753 (cases[dependent_index].func.__qualname__, dependent_index), 

754 ) 

755 

756 if len(ordered) == len(cases): 

757 return ordered 

758 

759 for index, case in enumerate(cases): 

760 if indegree[index] > 0: 

761 ordered.append(case) 

762 return ordered 

763 

764 

765def _required_parameters(method: Callable[..., Any]) -> tuple[str, ...]: 

766 required_arguments: list[str] = [] 

767 signature = inspect.signature(method) 

768 for parameter in signature.parameters.values(): 

769 if ( 

770 parameter.kind 

771 in ( 

772 inspect.Parameter.POSITIONAL_ONLY, 

773 inspect.Parameter.POSITIONAL_OR_KEYWORD, 

774 inspect.Parameter.KEYWORD_ONLY, 

775 ) 

776 and parameter.default is inspect.Signature.empty 

777 ): 

778 required_arguments.append(parameter.name) 

779 

780 return tuple(required_arguments) 

781 

782 

def _normalize_depends_on(
    depends_on: Sequence[Callable[..., Any]] | None,
) -> tuple[AutotestFunction, ...]:
    """Convert a ``depends_on`` argument into a tuple of unwrapped functions.

    None becomes an empty tuple.

    Raises:
        TypeError: if ``depends_on`` is not a Sequence, or one of its items is
            not callable.
    """
    if depends_on is None:
        return ()
    if isinstance(depends_on, Sequence):
        return tuple(map(_as_function, depends_on))
    raise TypeError("depends_on must be a sequence of functions or methods.")

793 

794 

def _register_case_policy(
    *,
    parent: type[object] | None,
    target_func: AutotestFunction,
    depends_on: Iterable[AutotestFunction] = (),
) -> None:
    """Store the case policy for the ``(parent, target_func)`` key.

    A non-empty ``depends_on`` replaces the stored dependencies. An empty one
    keeps whatever was registered before, or none if nothing was.
    """
    key = (parent, target_func)
    current = _CASE_POLICIES.get(key, AutotestCasePolicy())
    # Materialize first. ``depends_on`` may be a one-shot iterator, which is
    # truthy even when it yields nothing, so testing the raw argument would
    # overwrite existing dependencies with an empty tuple.
    requested = tuple(depends_on)
    _CASE_POLICIES[key] = AutotestCasePolicy(
        depends_on=requested if requested else current.depends_on,
    )

806 

807 

def _get_callback_dependencies(callback: Callable[..., Any]) -> tuple[AutotestFunction, ...]:
    """Return the unwrapped dependencies stored on ``callback`` by ``@autotest_depends_on``.

    Missing or non-tuple values yield an empty tuple.
    """
    declared = getattr(callback, _DEPENDS_ON_ATTR, ())
    if isinstance(declared, tuple):
        return tuple(_as_function(item) for item in declared)
    return ()

813 

814 

815def _merge_dependencies( 

816 *sources: Iterable[AutotestFunction], 

817) -> tuple[AutotestFunction, ...]: 

818 merged: list[AutotestFunction] = [] 

819 for source in sources: 

820 for dependency in source: 

821 if dependency not in merged: 

822 merged.append(dependency) 

823 return tuple(merged) 

824 

825 

826def _initialize_runtime_state() -> tuple[ 

827 dict[str, Any], 

828 set[AutotestFunction], 

829 set[AutotestFunction], 

830]: 

831 state: dict[str, Any] = {} 

832 completed_funcs: set[AutotestFunction] = set() 

833 skipped_funcs: set[AutotestFunction] = set() 

834 state["autotest_completed_funcs"] = completed_funcs 

835 state["autotest_skipped_funcs"] = skipped_funcs 

836 return state, completed_funcs, skipped_funcs 

837 

838 

839def _subtest_label_for_case(case: AutotestMethodCase) -> dict[str, str]: 

840 return {"msg": case.func.__qualname__} 

841 

842 

def _subtest_label_for_data(case: AutotestDataCase) -> dict[str, str]:
    """Return the ``subtests.test`` keyword arguments labelling a data case."""
    label = _snapshot_name_repr(case.name)
    return {"msg": label}

845 

846 

847def _snapshot_name_repr(name: object) -> str: 

848 if isinstance(name, (str, int)): 

849 return str(name) 

850 if callable(name): 

851 return getattr(name, "__qualname__", repr(name)) 

852 if isinstance(name, list): 

853 return "[" + ", ".join(_snapshot_name_repr(item) for item in name) + "]" 

854 return repr(name) 

855 

856 

857def _record_case_status( 

858 recorder: Callable[[str, str], None] | None, 

859 label: str, 

860 status: str, 

861) -> None: 

862 if recorder is not None: 

863 recorder(label, status) 

864 

865 

866def _format_dependency_skip_reason(missing: tuple[AutotestFunction, ...]) -> str: 

867 names = ", ".join(dep.__qualname__ for dep in missing) 

868 return f"Dependency was not executed: {names}" 

869 

870 

871def _skip_current_case(reason: str) -> None: 

872 try: 

873 import pytest 

874 except Exception as error: # pragma: no cover - pytest runtime always has pytest installed 

875 raise RuntimeError("pytest is required to skip autotest subtest cases.") from error 

876 

877 pytest.skip(reason) 

878 

879 

880def _is_pytest_skip_exception(error: BaseException) -> bool: 

881 try: 

882 import pytest 

883 except Exception: 

884 return False 

885 

886 skip_exception = getattr(pytest.skip, "Exception", None) 

887 return bool(skip_exception and isinstance(error, skip_exception)) 

888 

889 

890def _validate_schemashot(schemashot: Any) -> None: 

891 if not hasattr(schemashot, "assert_json_match") or not callable(schemashot.assert_json_match): 

892 raise TypeError( 

893 "schemashot fixture must provide a callable assert_json_match(data, name) method." 

894 ) 

895 

896 

897async def _invoke_method( 

898 method: Callable[..., Any], 

899 func: AutotestFunction, 

900 invocation: AutotestInvocation, 

901) -> Any: 

902 result = method(*invocation.args, **invocation.kwargs) 

903 if not inspect.isawaitable(result): 

904 raise TypeError(f"Autotest method {func.__qualname__} must be async.") 

905 return await result 

906 

907 

async def _resolve_invocation(
    *,
    case: AutotestMethodCase,
    api: object,
    schemashot: Any,
    state: dict[str, Any],
    typecheck_mode: AutotestTypecheckMode,
    trace_limit: int = 3,
    truncation_context_lines: int = 3,
) -> AutotestInvocation:
    """Build and validate the arguments used to call ``case.method``.

    Without a registered params provider the method is called with no
    arguments. Otherwise the provider is called with an
    ``AutotestCallContext``; an awaitable result is awaited and the result is
    normalized. Provider crashes, including invalid return values, are
    reported via ``raise_autotest_params_crash``. ``KeyboardInterrupt``,
    ``SystemExit`` and pytest skips propagate unchanged. The final invocation
    is checked against the method signature.
    """
    provider = find_autotest_params_provider(case.func, case.parent)
    if provider is None:
        invocation = AutotestInvocation()
    else:
        call_ctx = AutotestCallContext(
            api=api,
            owner=case.owner,
            parent=case.parent,
            method=case.method,
            func=case.func,
            schemashot=schemashot,
            state=state,
        )
        try:
            produced = provider(call_ctx)
            if inspect.isawaitable(produced):
                produced = await produced
            invocation = _normalize_invocation(produced, case.func)
        except BaseException as error:  # pragma: no cover - runtime-only branch for crash reporting
            if isinstance(error, (KeyboardInterrupt, SystemExit)) or _is_pytest_skip_exception(error):
                raise
            raise_autotest_params_crash(
                api=api,
                params_provider=provider,
                error=error,
                trace_limit=trace_limit,
                truncation_context_lines=truncation_context_lines,
            )

    _validate_invocation(
        case.method,
        case.func,
        invocation,
        typecheck_mode=typecheck_mode,
    )
    return invocation

956 

957 

958def _normalize_invocation(raw: Any, func: AutotestFunction) -> AutotestInvocation: 

959 if raw is None: 

960 return AutotestInvocation() 

961 if isinstance(raw, AutotestInvocation): 

962 return AutotestInvocation(args=tuple(raw.args), kwargs=dict(raw.kwargs)) 

963 if isinstance(raw, dict): 

964 return AutotestInvocation(kwargs=dict(raw)) 

965 if isinstance(raw, (tuple, list)): 

966 return AutotestInvocation(args=tuple(raw)) 

967 

968 raise TypeError( 

969 f"autotest_params provider for {func.__qualname__} must return one of: " 

970 "None, dict (kwargs), tuple/list (args), AutotestInvocation." 

971 ) 

972 

973 

def _validate_invocation(
    method: Callable[..., Any],
    func: AutotestFunction,
    invocation: AutotestInvocation,
    *,
    typecheck_mode: AutotestTypecheckMode = "off",
) -> None:
    """Check that ``invocation`` binds to ``method``, then type-check its arguments.

    Values collected by ``*args`` / ``**kwargs`` parameters are excluded from
    the type check. For those parameters the bound value is the whole
    tuple/dict, while the annotation describes each individual element, so
    comparing the two would report spurious mismatches (and raise in
    ``"strict"`` mode for perfectly valid calls).

    Raises:
        TypeError: if the arguments cannot be bound to the signature, or (in
            ``"strict"`` mode) if a named argument does not match its annotation.
    """
    signature = inspect.signature(method)
    try:
        bound_arguments = signature.bind(*invocation.args, **invocation.kwargs)
    except TypeError as error:
        raise TypeError(f"Invalid invocation for {func.__qualname__}: {error}") from error

    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    # Every key of BoundArguments.arguments is a parameter name of ``signature``.
    checked_arguments = {
        name: value
        for name, value in bound_arguments.arguments.items()
        if signature.parameters[name].kind not in variadic_kinds
    }
    _validate_invocation_types(
        signature=signature,
        bound_arguments=checked_arguments,
        method=method,
        func=func,
        typecheck_mode=typecheck_mode,
    )

994 

def _normalize_typecheck_mode(mode: AutotestTypecheckMode | str) -> AutotestTypecheckMode:
    """Return ``mode`` stripped and lower-cased, after validating it.

    Raises:
        TypeError: if ``mode`` is not a ``str``.
        ValueError: if the normalized value is not in ``_VALID_TYPECHECK_MODES``;
            the message lists the allowed values in sorted order.
    """
    if isinstance(mode, str):
        candidate = mode.strip().lower()
        if candidate in _VALID_TYPECHECK_MODES:
            return cast(AutotestTypecheckMode, candidate)
        choices = ", ".join(sorted(_VALID_TYPECHECK_MODES))
        raise ValueError(f"autotest typecheck mode must be one of: {choices}.")
    raise TypeError("autotest typecheck mode must be a string.")

1003 

1004 

def _validate_invocation_types(
    *,
    signature: inspect.Signature,
    bound_arguments: Mapping[str, Any],
    method: Callable[..., Any],
    func: AutotestFunction,
    typecheck_mode: AutotestTypecheckMode,
) -> None:
    """Check bound argument values against the method's parameter annotations.

    Args:
        signature: Signature of ``method`` that produced ``bound_arguments``.
        bound_arguments: ``BoundArguments.arguments`` from binding the invocation.
        method: The callable being invoked; used to resolve string annotations.
        func: The autotest function; only used to name it in the message.
        typecheck_mode: ``"off"`` skips checking, ``"strict"`` raises, and any
            other value (i.e. ``"warn"``) emits a ``RuntimeWarning``.

    Parameters with no annotation, or whose annotation cannot be resolved, are
    skipped. For ``*args: T`` and ``**kwargs: T`` the annotation describes each
    element, so every collected positional value / keyword value is checked
    against ``T`` individually. The tuple / dict container itself is not checked.

    Raises:
        TypeError: in ``"strict"`` mode when at least one value does not match.
    """
    if typecheck_mode == "off":
        return

    mismatches: list[str] = []
    for name, value in bound_arguments.items():
        parameter = signature.parameters.get(name)
        if parameter is None:
            continue

        annotation = _resolve_annotation(parameter.annotation, method)
        if annotation is inspect.Signature.empty:
            continue

        if parameter.kind is inspect.Parameter.VAR_POSITIONAL:
            # Bound value is the tuple of surplus positionals; ``T`` applies per item.
            checked = [(f"{name}[{index}]", item) for index, item in enumerate(value)]
        elif parameter.kind is inspect.Parameter.VAR_KEYWORD:
            # Bound value is the dict of surplus keywords; ``T`` applies per value.
            checked = [(f"{name}[{key!r}]", item) for key, item in value.items()]
        else:
            checked = [(name, value)]

        for label, item in checked:
            if _matches_annotation(item, annotation):
                continue
            expected = _format_annotation(annotation)
            mismatches.append(
                f"parameter {label!r} expects {expected}, got {type(item).__name__}"
            )

    if not mismatches:
        return

    details = "; ".join(mismatches)
    message = f"Invalid invocation types for {func.__qualname__}: {details}."
    if typecheck_mode == "strict":
        raise TypeError(message)
    # NOTE(review): stacklevel=4 presumably attributes the warning above
    # _validate_invocation's caller; keep in sync if the call chain changes.
    warnings.warn(message, RuntimeWarning, stacklevel=4)

1040 

1041 

def _resolve_annotation(annotation: Any, method: Callable[..., Any]) -> Any:
    """Turn a possibly-string annotation into a runtime object.

    Non-string annotations (including ``inspect.Signature.empty``) are returned
    unchanged. String annotations are evaluated against the ``__globals__`` of
    the function that defines them. ``inspect.signature`` follows ``__wrapped__``
    chains (stopping at objects that define ``__signature__``), so the annotation
    strings belong to that innermost function. Its globals are therefore used
    instead of those of an outer decorator wrapper, which may live in a
    different module.

    Returns:
        The evaluated annotation, or ``inspect.Signature.empty`` when it cannot
        be resolved, so the caller skips the parameter instead of failing.
    """
    if annotation is inspect.Signature.empty or not isinstance(annotation, str):
        return annotation

    try:
        # Mirror inspect.signature's unwrapping so globals match the annotations.
        target = inspect.unwrap(method, stop=lambda obj: hasattr(obj, "__signature__"))
    except ValueError:
        # Cyclic ``__wrapped__`` chain: fall back to the object as given.
        target = method
    # Bound methods keep the defining function in ``__func__``.
    target = getattr(target, "__func__", target)

    globals_dict = getattr(target, "__globals__", {})
    if not isinstance(globals_dict, dict):
        return inspect.Signature.empty

    try:
        # NOTE: evaluates annotation source text taken from the target's own code.
        return eval(annotation, globals_dict, {})
    except Exception:
        return inspect.Signature.empty

1057 

1058 

1059def _matches_annotation(value: Any, annotation: Any) -> bool: 

1060 if annotation in (Any, object): 

1061 return True 

1062 if annotation in (None, type(None)): 

1063 return value is None 

1064 

1065 supertype = getattr(annotation, "__supertype__", None) 

1066 if supertype is not None: 

1067 return _matches_annotation(value, supertype) 

1068 

1069 origin = get_origin(annotation) 

1070 if origin is None: 

1071 if isinstance(annotation, type): 

1072 return isinstance(value, annotation) 

1073 return True 

1074 

1075 if origin in (types.UnionType, Union): 

1076 return any(_matches_annotation(value, arg) for arg in get_args(annotation)) 

1077 

1078 if origin is Annotated: 

1079 args = get_args(annotation) 

1080 if not args: 

1081 return True 

1082 return _matches_annotation(value, args[0]) 

1083 

1084 if origin is Literal: 

1085 return any(value == option for option in get_args(annotation)) 

1086 

1087 if origin is list: 

1088 return _matches_iterable(value, annotation, list) 

1089 if origin is set: 

1090 return _matches_iterable(value, annotation, set) 

1091 if origin is frozenset: 

1092 return _matches_iterable(value, annotation, frozenset) 

1093 if origin is tuple: 

1094 return _matches_tuple(value, annotation) 

1095 if origin is dict: 

1096 return _matches_mapping(value, annotation) 

1097 

1098 if isinstance(origin, type): 

1099 return isinstance(value, origin) 

1100 

1101 return True 

1102 

1103 

def _matches_iterable(value: Any, annotation: Any, container_type: type[object]) -> bool:
    """Check that ``value`` is a ``container_type`` whose items match the element type.

    The element type is the first argument of ``annotation``. A bare
    (unparameterized) annotation only checks the container type. Iteration
    stops at the first non-matching item.
    """
    if isinstance(value, container_type):
        type_args = get_args(annotation)
        if type_args:
            element_annotation = type_args[0]
            for element in cast(Iterable[Any], value):
                if not _matches_annotation(element, element_annotation):
                    return False
        return True
    return False

1113 

1114 

def _matches_mapping(value: Any, annotation: Any) -> bool:
    """Check that ``value`` is a ``dict`` whose keys and values match ``annotation``.

    Key and value types come from a two-argument ``annotation`` (``dict[K, V]``).
    Otherwise both default to ``Any``. For each entry the key is checked
    before the value, and checking stops at the first mismatch.
    """
    if not isinstance(value, dict):
        return False
    type_args = get_args(annotation)
    key_annotation, item_annotation = type_args if len(type_args) == 2 else (Any, Any)
    return all(
        _matches_annotation(key, key_annotation) and _matches_annotation(item, item_annotation)
        for key, item in value.items()
    )

1129 

1130 

def _matches_tuple(value: Any, annotation: Any) -> bool:
    """Check ``value`` against a tuple annotation.

    * no type arguments: any tuple matches;
    * ``tuple[T, ...]``: every element must match ``T``;
    * ``tuple[A, B, ...]`` (fixed form): lengths must agree and each element
      must match the type at its position.
    """
    if not isinstance(value, tuple):
        return False
    element_types = get_args(annotation)
    if not element_types:
        return True
    homogeneous = len(element_types) == 2 and element_types[1] is Ellipsis
    if homogeneous:
        return all(_matches_annotation(element, element_types[0]) for element in value)
    return len(element_types) == len(value) and all(
        _matches_annotation(element, expected)
        for element, expected in zip(value, element_types)
    )

1142 

1143 

def _format_annotation(annotation: Any) -> str:
    """Render ``annotation`` for a mismatch message.

    Uses ``inspect.formatannotation`` and falls back to ``str(annotation)`` if
    that raises.
    """
    try:
        rendered = inspect.formatannotation(annotation)
    except Exception:
        rendered = str(annotation)
    return rendered

1149 

1150 

# Names exported by ``from <this module> import *``; listed in the original order.
__all__ = [
    # PascalCase exports
    "AutotestCallContext",
    "AutotestContext",
    "AutotestSubtests",
    "AutotestDataContext",
    "AutotestDataCase",
    "AutotestInvocation",
    "AutotestMethodCase",
    "AutotestCasePolicy",
    "AutotestTypecheckMode",
    # snake_case exports
    "autotest",
    "autotest_depends_on",
    "autotest_data",
    "autotest_hook",
    "autotest_policy",
    "autotest_params",
    "clear_autotest_hooks",
    "discover_autotest_methods",
    "execute_autotest_case",
    "execute_autotest_data_cases",
    "execute_autotests",
    "execute_autotests_with_subtests",
    "find_autotest_hook",
    "find_autotest_hook_dependencies",
    "find_autotest_policy",
    "find_autotest_params_dependencies",
    "find_autotest_params_provider",
]