Coverage for human_requests / autotest.py: 80%

535 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-07 17:38 +0000

1from __future__ import annotations 

2 

3import heapq 

4import inspect 

5import types 

6import warnings 

7from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence 

8from dataclasses import dataclass, field 

9from typing import ( 

10 Annotated, 

11 Any, 

12 ContextManager, 

13 Literal, 

14 Protocol, 

15 TypeVar, 

16 Union, 

17 cast, 

18 get_args, 

19 get_origin, 

20) 

21 

# Any plain function/method that can be registered as an autotest target.
AutotestFunction = Callable[..., Any]
# Hook signature: hook(response, data, ctx). A non-None return value (awaited
# first if it is awaitable) replaces ``data`` before the snapshot comparison.
AutotestHook = Callable[[Any, Any, "AutotestContext"], Any]
# Provider returning the invocation arguments (None, dict, tuple/list or
# AutotestInvocation); the return value may be awaitable.
AutotestParamProvider = Callable[["AutotestCallContext"], Any]
# Provider returning a (possibly awaitable) payload compared against a snapshot.
AutotestDataProvider = Callable[["AutotestDataContext"], Any]
# Accepted values for the ``typecheck_mode`` argument of the runners.
AutotestTypecheckMode = Literal["off", "warn", "strict"]
# Snapshot identifier passed as ``name`` to ``schemashot.assert_json_match``.
SnapshotName = str | int | Callable[..., Any] | list[str | int | Callable[..., Any]]
# Registry key: (parent class, or None as the any-parent fallback; unwrapped target).
HookKey = tuple[type[object] | None, AutotestFunction]
# Any callable that may carry ``@autotest_depends_on`` markers.
DependencyMarker = Callable[..., Any]

# Marker attribute set by ``@autotest``.
_AUTOTEST_ATTR = "__autotest__"
# Attribute holding the dependency tuple attached by ``@autotest_depends_on``.
_DEPENDS_ON_ATTR = "__autotest_depends_on__"
# Module-level registries filled by the decorators below and emptied by
# ``clear_autotest_hooks``.
_HOOKS: dict[HookKey, AutotestHook] = {}
_PARAM_PROVIDERS: dict[HookKey, AutotestParamProvider] = {}
_CASE_POLICIES: dict[HookKey, "AutotestCasePolicy"] = {}
_DATA_CASES: list["AutotestDataCase"] = []
_VALID_TYPECHECK_MODES: frozenset[str] = frozenset({"off", "warn", "strict"})

# Values of these types are never scanned as nested resource objects.
_PrimitiveTypes = (str, bytes, bytearray, bool, int, float, complex, range, memoryview)

# Type variable preserving the exact type of decorated callables.
F = TypeVar("F", bound=Callable[..., Any])

42 

43 

@dataclass(frozen=True)
class AutotestContext:
    """Immutable context handed to an ``@autotest_hook`` callback.

    Built by ``execute_autotest_case`` after the autotest method has returned
    its response.
    """

    api: object  # root API object given to the runner
    owner: object  # object the autotest method is bound to
    parent: object | None  # object that exposed ``owner``; None for top-level methods
    method: Callable[..., Awaitable[Any]]  # the bound autotest method
    func: AutotestFunction  # unwrapped function behind ``method``
    schemashot: Any  # snapshot fixture exposing ``assert_json_match``
    state: dict[str, Any]  # runtime state dict shared between cases of one run

53 

54 

@dataclass(frozen=True)
class AutotestMethodCase:
    """One discovered ``@autotest`` method, ready to be executed."""

    owner: object  # object the method is bound to
    parent: object | None  # object that exposed ``owner``; None for top-level methods
    method: Callable[..., Awaitable[Any]]  # the bound autotest method
    func: AutotestFunction  # unwrapped function, used as the identity key
    required_parameters: tuple[str, ...]  # parameter names without defaults
    depends_on: tuple[AutotestFunction, ...]  # functions that must complete first

63 

64 

@dataclass(frozen=True)
class AutotestCallContext:
    """Immutable context handed to an ``@autotest_params`` provider.

    Built by ``_resolve_invocation`` before the autotest method is called.
    """

    api: object  # root API object given to the runner
    owner: object  # object the autotest method is bound to
    parent: object | None  # object that exposed ``owner``; None for top-level methods
    method: Callable[..., Awaitable[Any]]  # the bound autotest method
    func: AutotestFunction  # unwrapped function behind ``method``
    schemashot: Any  # snapshot fixture exposing ``assert_json_match``
    state: dict[str, Any]  # runtime state dict shared between cases of one run

74 

75 

@dataclass(frozen=True)
class AutotestDataContext:
    """Immutable context handed to an ``@autotest_data`` provider."""

    api: object  # root API object given to the runner
    schemashot: Any  # snapshot fixture exposing ``assert_json_match``
    state: dict[str, Any]  # runtime state dict shared with the method cases

81 

82 

83@dataclass(frozen=True) 

84class AutotestInvocation: 

85 args: tuple[Any, ...] = () 

86 kwargs: dict[str, Any] = field(default_factory=dict) 

87 

88 

@dataclass(frozen=True)
class AutotestDataCase:
    """A registered ``@autotest_data`` provider together with its snapshot name."""

    name: SnapshotName  # passed as ``name`` to ``schemashot.assert_json_match``
    provider: AutotestDataProvider  # builds the payload from an AutotestDataContext

93 

94 

@dataclass(frozen=True)
class AutotestCasePolicy:
    """Execution policy for an autotest method (currently only its dependencies)."""

    # Functions that must complete successfully before the method runs.
    depends_on: tuple[AutotestFunction, ...] = ()

98 

99 

class AutotestSubtests(Protocol):
    """Structural type for a subtests fixture: ``test(...)`` returns a context manager."""

    def test(self, msg: str | None = None, **kwargs: Any) -> ContextManager[None]: ...

102 

103 

def autotest(func: F) -> F:
    """Mark *func* as an autotest method and hand it back unchanged.

    ``discover_autotest_methods`` only collects bound methods whose unwrapped
    function carries this marker.
    """
    setattr(func, _AUTOTEST_ATTR, True)
    return func

107 

108 

def autotest_depends_on(
    target: Callable[..., Any],
) -> Callable[[DependencyMarker], DependencyMarker]:
    """Build a decorator declaring that a callback depends on *target*.

    The decorated callback (typically a hook or params provider) gets the
    unwrapped *target* appended to its dependency tuple; a dependency that is
    already present is not added again. Discovery merges these dependencies
    into each case, which is then ordered after, and only executed once, the
    dependency has completed.

    Raises:
        TypeError: If *target* is not callable.
    """
    required = _as_function(target)

    def decorator(callback: DependencyMarker) -> DependencyMarker:
        current = _get_callback_dependencies(callback)
        if required not in current:
            setattr(callback, _DEPENDS_ON_ATTR, current + (required,))
        return callback

    return decorator

122 

123 

def autotest_hook(
    *,
    target: Callable[..., Any],
    parent: type[object] | None = None,
) -> Callable[[AutotestHook], AutotestHook]:
    """Build a decorator registering a response hook for *target*.

    The hook is called as ``hook(response, data, ctx)`` after the method runs.
    With *parent* set, it applies only when the method's parent object is an
    instance of exactly that class; ``parent=None`` registers the fallback.
    Registering again for the same key replaces the previous hook.

    Raises:
        TypeError: If *parent* is neither a class nor None, or *target* is not
            callable.
    """
    if not (parent is None or inspect.isclass(parent)):
        raise TypeError("autotest_hook parent must be a class or None.")

    registry_key = (parent, _as_function(target))

    def decorator(hook: AutotestHook) -> AutotestHook:
        _HOOKS[registry_key] = hook
        return hook

    return decorator

139 

140 

def autotest_params(
    *,
    target: Callable[..., Any],
    parent: type[object] | None = None,
    depends_on: Sequence[Callable[..., Any]] | None = None,
) -> Callable[[AutotestParamProvider], AutotestParamProvider]:
    """Build a decorator registering an argument provider for *target*.

    The provider receives an ``AutotestCallContext`` and returns the call
    arguments. With a non-empty *depends_on*, applying the decorator also
    records those functions as the case's dependencies.

    Raises:
        TypeError: If *parent* is neither a class nor None, *target* is not
            callable, or *depends_on* is not a sequence of callables.
    """
    if not (parent is None or inspect.isclass(parent)):
        raise TypeError("autotest_params parent must be a class or None.")

    target_func = _as_function(target)
    dependencies = _normalize_depends_on(depends_on)

    def decorator(provider: AutotestParamProvider) -> AutotestParamProvider:
        _PARAM_PROVIDERS[(parent, target_func)] = provider
        if dependencies:
            _register_case_policy(
                parent=parent,
                target_func=target_func,
                depends_on=dependencies,
            )
        return provider

    return decorator

164 

165 

def autotest_policy(
    *,
    target: Callable[..., Any],
    parent: type[object] | None = None,
    depends_on: Sequence[Callable[..., Any]] | None = None,
) -> Callable[[F], F]:
    """Register an execution policy for *target* and return a no-op decorator.

    Unlike the other registration helpers, the policy is recorded immediately
    when this function is called; the returned decorator leaves whatever it
    decorates untouched.

    Raises:
        TypeError: If *parent* is neither a class nor None, *target* is not
            callable, or *depends_on* is not a sequence of callables.
    """
    if not (parent is None or inspect.isclass(parent)):
        raise TypeError("autotest_policy parent must be a class or None.")

    target_func = _as_function(target)
    _register_case_policy(
        parent=parent,
        target_func=target_func,
        depends_on=_normalize_depends_on(depends_on),
    )

    def decorator(obj: F) -> F:
        return obj

    return decorator

187 

188 

def autotest_data(
    *,
    name: SnapshotName,
) -> Callable[[AutotestDataProvider], AutotestDataProvider]:
    """Build a decorator registering a data provider checked against snapshot *name*.

    Each decoration appends a new ``AutotestDataCase``; duplicates are not
    filtered out.
    """

    def decorator(provider: AutotestDataProvider) -> AutotestDataProvider:
        data_case = AutotestDataCase(name=name, provider=provider)
        _DATA_CASES.append(data_case)
        return provider

    return decorator

197 return decorator 

198 

199 

def clear_autotest_hooks() -> None:
    """Empty every autotest registry: hooks, params providers, policies and data cases."""
    for registry in (_HOOKS, _PARAM_PROVIDERS, _CASE_POLICIES, _DATA_CASES):
        registry.clear()

205 

206 

def find_autotest_hook(
    func: Callable[..., Any],
    parent_object: object | None,
) -> AutotestHook | None:
    """Return the hook registered for *func*, or None.

    A hook registered for the exact class of *parent_object* wins over the
    ``parent=None`` fallback.
    """
    target_func = _as_function(func)
    candidate_parents: list[type[object] | None] = []
    if parent_object is not None:
        candidate_parents.append(parent_object.__class__)
    candidate_parents.append(None)

    for owner_class in candidate_parents:
        hook = _HOOKS.get((owner_class, target_func))
        if hook is not None:
            return hook
    return None

220 

221 

def find_autotest_params_provider(
    func: Callable[..., Any],
    parent_object: object | None,
) -> AutotestParamProvider | None:
    """Return the params provider registered for *func*, or None.

    A provider registered for the exact class of *parent_object* wins over
    the ``parent=None`` fallback.
    """
    target_func = _as_function(func)
    candidate_parents: list[type[object] | None] = []
    if parent_object is not None:
        candidate_parents.append(parent_object.__class__)
    candidate_parents.append(None)

    for owner_class in candidate_parents:
        provider = _PARAM_PROVIDERS.get((owner_class, target_func))
        if provider is not None:
            return provider
    return None

235 

236 

def find_autotest_policy(
    func: Callable[..., Any],
    parent_object: object | None,
) -> AutotestCasePolicy:
    """Return the case policy for *func*, falling back to an empty default policy.

    A policy registered for the exact class of *parent_object* wins over the
    ``parent=None`` entry.
    """
    target_func = _as_function(func)

    if parent_object is not None:
        specific = _CASE_POLICIES.get((parent_object.__class__, target_func))
        if specific is not None:
            return specific

    fallback_key = (None, target_func)
    if fallback_key in _CASE_POLICIES:
        return _CASE_POLICIES[fallback_key]
    return AutotestCasePolicy()

250 

251 

def find_autotest_hook_dependencies(
    func: Callable[..., Any],
    parent_object: object | None,
) -> tuple[AutotestFunction, ...]:
    """Return the ``@autotest_depends_on`` dependencies of *func*'s hook (empty if none)."""
    hook = find_autotest_hook(func, parent_object)
    return () if hook is None else _get_callback_dependencies(hook)

260 

261 

def find_autotest_params_dependencies(
    func: Callable[..., Any],
    parent_object: object | None,
) -> tuple[AutotestFunction, ...]:
    """Return the ``@autotest_depends_on`` dependencies of *func*'s params provider (empty if none)."""
    provider = find_autotest_params_provider(func, parent_object)
    return () if provider is None else _get_callback_dependencies(provider)

270 

271 

def discover_autotest_methods(api: object) -> list[AutotestMethodCase]:
    """Find every ``@autotest`` method reachable from *api*.

    Public (non-underscore) attributes are scanned in sorted name order.
    Bound methods whose unwrapped function is marked with ``@autotest`` become
    cases. Other attributes accepted by ``_is_resource_object`` are scanned
    recursively, with the object that exposed them recorded as their parent.
    Each object is scanned at most once (tracked by ``id``), and attributes
    whose lookup raises are skipped.

    Args:
        api: Root object to scan.

    Returns:
        The discovered cases, ordered by ``_order_cases``.

    Raises:
        TypeError: If an autotest method has required parameters but no
            ``@autotest_params`` provider is registered for it.
    """
    collected: list[AutotestMethodCase] = []
    scanned_ids: set[int] = set()

    def scan(owner: object, parent: object | None) -> None:
        if id(owner) in scanned_ids:
            return
        scanned_ids.add(id(owner))

        for attr_name in sorted(dir(owner)):
            if attr_name.startswith("_"):
                continue
            try:
                value = getattr(owner, attr_name)
            except Exception:
                # Properties/descriptors that raise are simply ignored.
                continue

            bound = _as_bound_method(value)
            if bound is None:
                if _is_resource_object(value):
                    scan(value, owner)
                continue

            func = _as_function(bound)
            if not _is_autotest(func):
                continue

            required = _required_parameters(bound)
            provider = find_autotest_params_provider(func, parent)
            policy = find_autotest_policy(func, parent)
            depends_on = _merge_dependencies(
                policy.depends_on,
                find_autotest_params_dependencies(func, parent),
                find_autotest_hook_dependencies(func, parent),
            )
            if required and provider is None:
                joined = ", ".join(required)
                raise TypeError(
                    f"Autotest method {func.__qualname__} requires arguments ({joined}). "
                    "Register @autotest_params(target=...) for this method."
                )
            collected.append(
                AutotestMethodCase(
                    owner=owner,
                    parent=parent,
                    method=cast(Callable[..., Awaitable[Any]], bound),
                    func=func,
                    required_parameters=required,
                    depends_on=depends_on,
                )
            )

    scan(api, None)
    return _order_cases(collected)

326 

327 

async def execute_autotests(
    api: object,
    schemashot: Any,
    *,
    typecheck_mode: AutotestTypecheckMode | str = "off",
) -> int:
    """Run every discovered autotest method, then every registered data case.

    A method case is skipped silently when one of its dependencies has not
    completed, or when it raises pytest's skip exception; any other exception
    propagates.

    Returns:
        The number of method cases that completed plus the value returned by
        ``execute_autotest_data_cases``.

    Raises:
        TypeError: Invalid *schemashot* or typecheck mode type.
        ValueError: Unknown *typecheck_mode* value.
    """
    _validate_schemashot(schemashot)
    mode = _normalize_typecheck_mode(typecheck_mode)

    state, completed, skipped = _initialize_runtime_state()
    ran = 0

    for case in discover_autotest_methods(api):
        if any(dep not in completed for dep in case.depends_on):
            skipped.add(case.func)
            continue

        try:
            await execute_autotest_case(
                case=case,
                api=api,
                schemashot=schemashot,
                state=state,
                typecheck_mode=mode,
            )
        except BaseException as error:  # pragma: no cover - runtime-only branch for skip semantics
            if not _is_pytest_skip_exception(error):
                raise
            skipped.add(case.func)
            continue

        completed.add(case.func)
        ran += 1

    ran += await execute_autotest_data_cases(api=api, schemashot=schemashot, state=state)
    return ran

366 

367 

async def execute_autotests_with_subtests(
    api: object,
    schemashot: Any,
    *,
    subtests: AutotestSubtests,
    typecheck_mode: AutotestTypecheckMode | str = "off",
) -> int:
    """Run each autotest method and data case inside its own subtest.

    A method case whose dependencies have not completed is skipped via
    ``pytest.skip`` inside its subtest. Only cases whose subtest body ran to
    the end are recorded as completed (and can satisfy later dependencies).

    Returns:
        The number of method cases processed (run, skipped or failed) plus
        the number of data cases processed.

    Raises:
        TypeError: Invalid *schemashot* or typecheck mode type.
        ValueError: Unknown *typecheck_mode* value.
    """
    _validate_schemashot(schemashot)
    mode = _normalize_typecheck_mode(typecheck_mode)

    state, completed, skipped = _initialize_runtime_state()
    total = 0

    for case in discover_autotest_methods(api):
        total += 1
        passed = False
        with subtests.test(**_subtest_label_for_case(case)):
            missing = tuple(dep for dep in case.depends_on if dep not in completed)
            if missing:
                skipped.add(case.func)
                # Raises pytest's skip exception; the subtest context reports it.
                _skip_current_case(_format_dependency_skip_reason(missing))

            try:
                await execute_autotest_case(
                    case=case,
                    api=api,
                    schemashot=schemashot,
                    state=state,
                    typecheck_mode=mode,
                )
            except BaseException as error:  # pragma: no cover - runtime-only skip branch
                if _is_pytest_skip_exception(error):
                    skipped.add(case.func)
                raise
            else:
                passed = True

        if passed:
            completed.add(case.func)

    total += await _execute_autotest_data_cases_with_subtests(
        api=api,
        schemashot=schemashot,
        state=state,
        subtests=subtests,
    )
    return total

418 

419 

async def execute_autotest_case(
    *,
    case: AutotestMethodCase,
    api: object,
    schemashot: Any,
    state: dict[str, Any] | None = None,
    typecheck_mode: AutotestTypecheckMode | str = "off",
) -> None:
    """Call one autotest method and compare its JSON response with a snapshot.

    The arguments come from the registered params provider (if any). The
    method must be async and return an object with a callable ``json()``. If
    a hook is registered, it may replace the JSON data by returning a non-None
    value. The result is checked with
    ``schemashot.assert_json_match(data, case.func)``.

    Raises:
        TypeError: Invalid schemashot, invocation, non-async method, or a
            response without ``json()``.
        ValueError: Unknown *typecheck_mode* value.
    """
    _validate_schemashot(schemashot)
    mode = _normalize_typecheck_mode(typecheck_mode)
    shared_state = {} if state is None else state

    invocation = await _resolve_invocation(
        case=case,
        api=api,
        schemashot=schemashot,
        state=shared_state,
        typecheck_mode=mode,
    )
    response = await _invoke_method(case.method, case.func, invocation)

    if not (hasattr(response, "json") and callable(response.json)):
        raise TypeError(
            f"Autotest method {case.func.__qualname__} must return an object with json()."
        )

    payload = response.json()

    hook = find_autotest_hook(case.func, case.parent)
    if hook is not None:
        hook_ctx = AutotestContext(
            api=api,
            owner=case.owner,
            parent=case.parent,
            method=case.method,
            func=case.func,
            schemashot=schemashot,
            state=shared_state,
        )
        outcome = hook(response, payload, hook_ctx)
        if inspect.isawaitable(outcome):
            outcome = await outcome
        if outcome is not None:
            payload = outcome

    schemashot.assert_json_match(payload, case.func)

466 

467 

async def execute_autotest_data_cases(
    *,
    api: object,
    schemashot: Any,
    state: dict[str, Any] | None = None,
) -> int:
    """Run every registered ``@autotest_data`` provider and check its snapshot.

    Each provider receives an ``AutotestDataContext``; awaitable results are
    awaited, and the payload is compared with
    ``schemashot.assert_json_match(payload, case.name)``.

    Returns:
        The number of data cases executed in this call.

    Raises:
        TypeError: If *schemashot* has no callable ``assert_json_match``.
    """
    _validate_schemashot(schemashot)
    runtime_state = state if state is not None else {}
    ctx = AutotestDataContext(api=api, schemashot=schemashot, state=runtime_state)

    # Iterate over (and count) a snapshot of the registry. Previously the
    # count was taken from the live registry, so a provider that registered
    # further data cases mid-run inflated the reported number.
    data_cases = list(_DATA_CASES)
    for case in data_cases:
        payload = case.provider(ctx)
        if inspect.isawaitable(payload):
            payload = await payload
        schemashot.assert_json_match(payload, case.name)

    return len(data_cases)

485 

486 

async def _execute_autotest_data_cases_with_subtests(
    *,
    api: object,
    schemashot: Any,
    state: dict[str, Any],
    subtests: AutotestSubtests,
) -> int:
    """Run each registered data provider inside its own subtest.

    Returns:
        The number of data cases processed.
    """
    _validate_schemashot(schemashot)
    ctx = AutotestDataContext(api=api, schemashot=schemashot, state=state)

    snapshot = list(_DATA_CASES)
    for data_case in snapshot:
        with subtests.test(**_subtest_label_for_data(data_case)):
            result = data_case.provider(ctx)
            if inspect.isawaitable(result):
                result = await result
            schemashot.assert_json_match(result, data_case.name)

    # Reached only after every item was processed, so this equals the count.
    return len(snapshot)

507 

508 

def _is_autotest(func: AutotestFunction) -> bool:
    """Return True when *func* carries a truthy ``@autotest`` marker attribute."""
    marker = getattr(func, _AUTOTEST_ATTR, False)
    return bool(marker)

511 

512 

def _as_function(target: Callable[..., Any]) -> AutotestFunction:
    """Normalize *target* to the innermost function behind it.

    A bound method is replaced by its ``__func__``, then ``inspect.unwrap``
    follows any ``__wrapped__`` chain. The result is used as the identity key
    in all autotest registries.

    Raises:
        TypeError: If *target* (after stripping a bound method) is not callable.
    """
    candidate: Any = target
    if inspect.ismethod(candidate):
        candidate = candidate.__func__
    if not callable(candidate):
        raise TypeError("Target must be a function or method.")
    return cast(AutotestFunction, inspect.unwrap(candidate))

518 

519 

def _as_bound_method(value: Any) -> Callable[..., Any] | None:
    """Return *value* if it is a bound method with a non-None ``__self__``, else None."""
    if not inspect.ismethod(value):
        return None
    if value.__self__ is None:
        return None
    return cast(Callable[..., Any], value)

524 

525 

def _is_resource_object(value: Any) -> bool:
    """Decide whether *value* should be scanned for nested autotest methods.

    None, primitive values, builtin containers, modules, classes, functions,
    methods and builtins are rejected. Anything else qualifies if it has a
    ``__dict__`` or ``__slots__``.
    """
    if value is None or isinstance(value, _PrimitiveTypes):
        return False
    if isinstance(value, (dict, list, tuple, set, frozenset)):
        return False
    non_resource_checks = (
        inspect.ismodule,
        inspect.isclass,
        inspect.isfunction,
        inspect.ismethod,
        inspect.isbuiltin,
    )
    if any(check(value) for check in non_resource_checks):
        return False
    return hasattr(value, "__dict__") or hasattr(value, "__slots__")

538 

539 

def _order_cases(cases: list[AutotestMethodCase]) -> list[AutotestMethodCase]:
    """Topologically sort *cases* so dependencies run before their dependents.

    Uses Kahn's algorithm. Among ready cases, the one with the smallest
    ``func.__qualname__`` (then input position) goes first. Cases left
    unordered by a dependency cycle are appended at the end in their original
    order. Lists with fewer than two items are returned as-is.
    """
    total = len(cases)
    if total < 2:
        return cases

    positions: dict[AutotestFunction, list[int]] = {}
    for pos, case in enumerate(cases):
        positions.setdefault(case.func, []).append(pos)

    successors: list[set[int]] = [set() for _ in range(total)]
    pending = [0] * total

    for target, case in enumerate(cases):
        for dep in case.depends_on:
            for source in positions.get(dep, ()):
                # Ignore self-dependencies and duplicate edges.
                if source != target and target not in successors[source]:
                    successors[source].add(target)
                    pending[target] += 1

    ready = [(case.func.__qualname__, pos) for pos, case in enumerate(cases) if pending[pos] == 0]
    heapq.heapify(ready)

    result: list[AutotestMethodCase] = []
    while ready:
        _, pos = heapq.heappop(ready)
        result.append(cases[pos])
        for nxt in successors[pos]:
            pending[nxt] -= 1
            if not pending[nxt]:
                heapq.heappush(ready, (cases[nxt].func.__qualname__, nxt))

    if len(result) < total:
        # A dependency cycle left some cases unprocessed.
        result.extend(case for pos, case in enumerate(cases) if pending[pos] > 0)
    return result

585 

586 

def _required_parameters(method: Callable[..., Any]) -> tuple[str, ...]:
    """Return the names of *method*'s parameters that have no default.

    ``*args`` and ``**kwargs`` parameters are never considered required.
    """
    kinds_needing_value = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    params = inspect.signature(method).parameters.values()
    return tuple(
        param.name
        for param in params
        if param.kind in kinds_needing_value and param.default is inspect.Parameter.empty
    )

603 

604 

def _normalize_depends_on(
    depends_on: Sequence[Callable[..., Any]] | None,
) -> tuple[AutotestFunction, ...]:
    """Convert a ``depends_on`` argument into a tuple of unwrapped functions.

    Raises:
        TypeError: If *depends_on* is neither None nor a sequence, or contains
            a non-callable item.
    """
    if depends_on is None:
        return ()
    if isinstance(depends_on, Sequence):
        return tuple(map(_as_function, depends_on))
    raise TypeError("depends_on must be a sequence of functions or methods.")

615 

616 

def _register_case_policy(
    *,
    parent: type[object] | None,
    target_func: AutotestFunction,
    depends_on: Iterable[AutotestFunction] = (),
) -> None:
    """Create or update the case policy stored under ``(parent, target_func)``.

    A non-empty *depends_on* replaces previously registered dependencies. An
    empty one keeps the existing dependencies, but still ensures an entry
    exists for the key.
    """
    key = (parent, target_func)
    current = _CASE_POLICIES.get(key, AutotestCasePolicy())
    # Materialize before testing emptiness: truth-testing an arbitrary Iterable
    # is unreliable (an empty generator is truthy), which previously let an
    # empty iterator wipe out already registered dependencies.
    requested = tuple(depends_on)
    _CASE_POLICIES[key] = AutotestCasePolicy(
        depends_on=requested or current.depends_on,
    )

628 

629 

def _get_callback_dependencies(callback: Callable[..., Any]) -> tuple[AutotestFunction, ...]:
    """Return the unwrapped dependencies attached to *callback* by ``@autotest_depends_on``.

    Anything other than a tuple stored in the marker attribute is ignored.
    """
    declared = getattr(callback, _DEPENDS_ON_ATTR, ())
    if isinstance(declared, tuple):
        return tuple(map(_as_function, declared))
    return ()

635 

636 

def _merge_dependencies(
    *sources: Iterable[AutotestFunction],
) -> tuple[AutotestFunction, ...]:
    """Concatenate *sources*, dropping repeats while keeping first-seen order."""
    unique: list[AutotestFunction] = []
    for dependency in (dep for source in sources for dep in source):
        if dependency not in unique:
            unique.append(dependency)
    return tuple(unique)

646 

647 

def _initialize_runtime_state() -> tuple[
    dict[str, Any],
    set[AutotestFunction],
    set[AutotestFunction],
]:
    """Create a fresh run state plus its completed/skipped function sets.

    The two sets are also exposed inside the state dict, so hooks and
    providers can inspect them.
    """
    completed: set[AutotestFunction] = set()
    skipped: set[AutotestFunction] = set()
    state: dict[str, Any] = {
        "autotest_completed_funcs": completed,
        "autotest_skipped_funcs": skipped,
    }
    return state, completed, skipped

659 

660 

def _subtest_label_for_case(case: AutotestMethodCase) -> dict[str, str]:
    """Build the subtest keyword labels (method qualname, parent class name) for *case*."""
    parent = case.parent
    return {
        "method": case.func.__qualname__,
        "parent": type(parent).__name__ if parent is not None else "None",
    }

667 

668 

def _subtest_label_for_data(case: AutotestDataCase) -> dict[str, str]:
    """Build the subtest keyword label for a data case from its snapshot name."""
    return dict(data=_snapshot_name_repr(case.name))

671 

672 

673def _snapshot_name_repr(name: object) -> str: 

674 if isinstance(name, (str, int)): 

675 return str(name) 

676 if callable(name): 

677 return getattr(name, "__qualname__", repr(name)) 

678 if isinstance(name, list): 

679 return "[" + ", ".join(_snapshot_name_repr(item) for item in name) + "]" 

680 return repr(name) 

681 

682 

def _format_dependency_skip_reason(missing: tuple[AutotestFunction, ...]) -> str:
    """Describe which dependency functions prevented a case from running."""
    listed = ", ".join(dep.__qualname__ for dep in missing)
    return "Dependency was not executed: " + listed

686 

687 

def _skip_current_case(reason: str) -> None:
    """Skip the running (sub)test with *reason* via ``pytest.skip``.

    Raises:
        RuntimeError: If pytest cannot be imported.
    """
    try:
        import pytest
    except Exception as error:  # pragma: no cover - pytest runtime always has pytest installed
        raise RuntimeError("pytest is required to skip autotest subtest cases.") from error
    else:
        pytest.skip(reason)

695 

696 

def _is_pytest_skip_exception(error: BaseException) -> bool:
    """Return True if *error* is pytest's skip exception; False when pytest is unavailable."""
    try:
        import pytest
    except Exception:
        return False
    else:
        skip_type = getattr(pytest.skip, "Exception", None)
        if not skip_type:
            return False
        return isinstance(error, skip_type)

705 

706 

def _validate_schemashot(schemashot: Any) -> None:
    """Ensure *schemashot* exposes a callable ``assert_json_match``.

    Raises:
        TypeError: If the method is missing or not callable.
    """
    checker = getattr(schemashot, "assert_json_match", None)
    if callable(checker):
        return
    raise TypeError(
        "schemashot fixture must provide a callable assert_json_match(data, name) method."
    )

712 

713 

async def _invoke_method(
    method: Callable[..., Any],
    func: AutotestFunction,
    invocation: AutotestInvocation,
) -> Any:
    """Call *method* with *invocation*'s arguments and await the result.

    Raises:
        TypeError: If the call did not return an awaitable. The method has
            already been called at that point.
    """
    outcome = method(*invocation.args, **invocation.kwargs)
    if inspect.isawaitable(outcome):
        return await outcome
    raise TypeError(f"Autotest method {func.__qualname__} must be async.")

723 

724 

async def _resolve_invocation(
    *,
    case: AutotestMethodCase,
    api: object,
    schemashot: Any,
    state: dict[str, Any],
    typecheck_mode: AutotestTypecheckMode,
) -> AutotestInvocation:
    """Build and validate the arguments used to call ``case.method``.

    Without a registered params provider, the method is called with no
    arguments. Otherwise the provider's (possibly awaitable) result is
    normalized into an ``AutotestInvocation``. Either way, the invocation is
    checked against the method signature.
    """
    provider = find_autotest_params_provider(case.func, case.parent)
    if provider is not None:
        call_ctx = AutotestCallContext(
            api=api,
            owner=case.owner,
            parent=case.parent,
            method=case.method,
            func=case.func,
            schemashot=schemashot,
            state=state,
        )
        produced = provider(call_ctx)
        if inspect.isawaitable(produced):
            produced = await produced
        invocation = _normalize_invocation(produced, case.func)
    else:
        invocation = AutotestInvocation()

    _validate_invocation(case.method, case.func, invocation, typecheck_mode=typecheck_mode)
    return invocation

758 

759 

760def _normalize_invocation(raw: Any, func: AutotestFunction) -> AutotestInvocation: 

761 if raw is None: 

762 return AutotestInvocation() 

763 if isinstance(raw, AutotestInvocation): 

764 return AutotestInvocation(args=tuple(raw.args), kwargs=dict(raw.kwargs)) 

765 if isinstance(raw, dict): 

766 return AutotestInvocation(kwargs=dict(raw)) 

767 if isinstance(raw, (tuple, list)): 

768 return AutotestInvocation(args=tuple(raw)) 

769 

770 raise TypeError( 

771 f"autotest_params provider for {func.__qualname__} must return one of: " 

772 "None, dict (kwargs), tuple/list (args), AutotestInvocation." 

773 ) 

774 

775 

def _validate_invocation(
    method: Callable[..., Any],
    func: AutotestFunction,
    invocation: AutotestInvocation,
    *,
    typecheck_mode: AutotestTypecheckMode = "off",
) -> None:
    """Check that *invocation* binds to *method*'s signature, then check value types.

    Raises:
        TypeError: If the arguments do not bind (wrapping the original error),
            or on type mismatches in ``"strict"`` mode.
    """
    signature = inspect.signature(method)
    try:
        bound = signature.bind(*invocation.args, **invocation.kwargs)
    except TypeError as error:
        raise TypeError(f"Invalid invocation for {func.__qualname__}: {error}") from error
    else:
        _validate_invocation_types(
            signature=signature,
            bound_arguments=bound.arguments,
            method=method,
            func=func,
            typecheck_mode=typecheck_mode,
        )

795 

796 

def _normalize_typecheck_mode(mode: AutotestTypecheckMode | str) -> AutotestTypecheckMode:
    """Return *mode* stripped and lower-cased after validating it.

    Raises:
        TypeError: If *mode* is not a string.
        ValueError: If *mode* is not one of ``off``, ``strict``, ``warn``.
    """
    if isinstance(mode, str):
        candidate = mode.strip().lower()
        if candidate in _VALID_TYPECHECK_MODES:
            return cast(AutotestTypecheckMode, candidate)
        allowed = ", ".join(sorted(_VALID_TYPECHECK_MODES))
        raise ValueError(f"autotest typecheck mode must be one of: {allowed}.")
    raise TypeError("autotest typecheck mode must be a string.")

805 

806 

def _validate_invocation_types(
    *,
    signature: inspect.Signature,
    bound_arguments: Mapping[str, Any],
    method: Callable[..., Any],
    func: AutotestFunction,
    typecheck_mode: AutotestTypecheckMode,
) -> None:
    """Check bound argument values against their parameters' annotations.

    Does nothing in ``"off"`` mode. Otherwise, every bound argument whose
    parameter has a resolvable annotation is checked with
    ``_matches_annotation``. For ``*args``/``**kwargs`` parameters, the
    annotation describes each element, so each collected positional value or
    keyword value is checked individually.

    In ``"warn"`` mode mismatches are reported as a ``RuntimeWarning``.

    Raises:
        TypeError: In ``"strict"`` mode, if at least one argument mismatches.
    """
    if typecheck_mode == "off":
        return

    mismatches: list[str] = []
    for name, value in bound_arguments.items():
        parameter = signature.parameters.get(name)
        if parameter is None:
            continue

        annotation = _resolve_annotation(parameter.annotation, method)
        if annotation is inspect.Signature.empty:
            continue

        # BoundArguments stores *args as a tuple and **kwargs as a dict, while
        # the annotation applies to each element, not the container itself.
        if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
            checks = [(f"{name}[{index}]", item) for index, item in enumerate(value)]
        elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
            checks = [(f"{name}[{key!r}]", item) for key, item in value.items()]
        else:
            checks = [(name, value)]

        for label, item in checks:
            if _matches_annotation(item, annotation):
                continue
            expected = _format_annotation(annotation)
            mismatches.append(
                f"parameter {label!r} expects {expected}, got {type(item).__name__}"
            )

    if not mismatches:
        return

    details = "; ".join(mismatches)
    message = f"Invalid invocation types for {func.__qualname__}: {details}."
    if typecheck_mode == "strict":
        raise TypeError(message)
    warnings.warn(message, RuntimeWarning, stacklevel=4)

842 

843 

def _resolve_annotation(annotation: Any, method: Callable[..., Any]) -> Any:
    """Turn a parameter annotation into a runtime object usable for type checks.

    Non-string annotations (including ``Signature.empty``) are returned as-is.
    String annotations (e.g. under ``from __future__ import annotations``) are
    evaluated in the module globals of the innermost function behind *method*.
    ``inspect.signature`` follows ``__wrapped__`` chains, so the strings come
    from that function, not from a decorator that may live in another module.
    If the globals are unavailable or evaluation fails, ``Signature.empty`` is
    returned, meaning "skip this check".

    Note: ``eval`` runs annotation strings written in the API's own source
    code; never feed this function untrusted text.
    """
    if annotation is inspect.Signature.empty or not isinstance(annotation, str):
        return annotation

    try:
        innermost = inspect.unwrap(getattr(method, "__func__", method))
    except ValueError:  # cyclic __wrapped__ chain
        innermost = method

    globals_dict = getattr(innermost, "__globals__", None)
    if not isinstance(globals_dict, dict):
        # e.g. unwrapping ended at a non-function callable; keep old behaviour.
        globals_dict = getattr(method, "__globals__", None)
    if not isinstance(globals_dict, dict):
        return inspect.Signature.empty

    try:
        return eval(annotation, globals_dict, {})
    except Exception:
        return inspect.Signature.empty

859 

860 

def _matches_annotation(value: Any, annotation: Any) -> bool:
    """Best-effort runtime check that *value* conforms to *annotation*.

    Handles ``Any``/``object``, None, ``NewType``, plain classes, unions,
    ``Annotated``, ``Literal``, and parametrized list/set/frozenset/tuple/dict
    (elements are checked recursively). Other subscripted generics are
    checked against their origin class only. Anything unrecognized is
    accepted.
    """
    if annotation in (Any, object):
        return True
    if annotation in (None, type(None)):
        return value is None

    supertype = getattr(annotation, "__supertype__", None)
    if supertype is not None:
        # typing.NewType: check against the wrapped type.
        return _matches_annotation(value, supertype)

    origin = get_origin(annotation)
    if origin is None:
        return isinstance(value, annotation) if isinstance(annotation, type) else True

    if origin in (types.UnionType, Union):
        return any(_matches_annotation(value, member) for member in get_args(annotation))
    if origin is Annotated:
        metadata = get_args(annotation)
        return _matches_annotation(value, metadata[0]) if metadata else True
    if origin is Literal:
        return any(value == option for option in get_args(annotation))
    if origin in (list, set, frozenset):
        return _matches_iterable(value, annotation, origin)
    if origin is tuple:
        return _matches_tuple(value, annotation)
    if origin is dict:
        return _matches_mapping(value, annotation)

    return isinstance(value, origin) if isinstance(origin, type) else True

904 

905 

def _matches_iterable(value: Any, annotation: Any, container_type: type[object]) -> bool:
    """Check *value* is a *container_type* whose items all match the element annotation."""
    if not isinstance(value, container_type):
        return False
    element_args = get_args(annotation)
    return not element_args or all(
        _matches_annotation(element, element_args[0]) for element in value
    )

915 

916 

def _matches_mapping(value: Any, annotation: Any) -> bool:
    """Check *value* is a dict whose keys and values match ``dict[K, V]``.

    Missing type parameters are treated as ``Any``.
    """
    if not isinstance(value, dict):
        return False
    params = get_args(annotation)
    key_type, value_type = params if len(params) == 2 else (Any, Any)
    return all(
        _matches_annotation(key, key_type) and _matches_annotation(item, value_type)
        for key, item in value.items()
    )

931 

932 

def _matches_tuple(value: Any, annotation: Any) -> bool:
    """Check *value* against ``tuple[X, ...]`` (homogeneous) or ``tuple[A, B, ...]`` (fixed)."""
    if not isinstance(value, tuple):
        return False
    element_types = get_args(annotation)
    if not element_types:
        return True
    if len(element_types) == 2 and element_types[1] is Ellipsis:
        homogeneous = element_types[0]
        return all(_matches_annotation(item, homogeneous) for item in value)
    return len(element_types) == len(value) and all(
        _matches_annotation(item, item_type) for item, item_type in zip(value, element_types)
    )

944 

945 

def _format_annotation(annotation: Any) -> str:
    """Render *annotation* for error messages, falling back to ``str()`` on failure."""
    try:
        rendered = inspect.formatannotation(annotation)
    except Exception:
        rendered = str(annotation)
    return rendered

951 

952 

# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    "AutotestCallContext",
    "AutotestContext",
    "AutotestSubtests",
    "AutotestDataContext",
    "AutotestDataCase",
    "AutotestInvocation",
    "AutotestMethodCase",
    "AutotestCasePolicy",
    "AutotestTypecheckMode",
    "autotest",
    "autotest_depends_on",
    "autotest_data",
    "autotest_hook",
    "autotest_policy",
    "autotest_params",
    "clear_autotest_hooks",
    "discover_autotest_methods",
    "execute_autotest_case",
    "execute_autotest_data_cases",
    "execute_autotests",
    "execute_autotests_with_subtests",
    "find_autotest_hook",
    "find_autotest_hook_dependencies",
    "find_autotest_policy",
    "find_autotest_params_dependencies",
    "find_autotest_params_provider",
]