Coverage for genschema / postprocessing / schema_references.py: 79%

342 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-25 09:44 +0000

1from __future__ import annotations 

2 

3import copy 

4import re 

5from collections import Counter 

6from dataclasses import dataclass, field 

7from typing import Callable, Iterable, Literal, TypeAlias 

8 

9from ..comparators import ( 

10 DeleteElement, 

11 EmptyComparator, 

12 EnumComparator, 

13 FormatComparator, 

14 PreserveCommonKeywordsComparator, 

15 RequiredComparator, 

16) 

17from ..comparators.template import Comparator 

18from ..comparators.type import infer_schema_type, infer_schema_types 

19from ..pipeline import Converter 

20from ..pseudo_arrays import PseudoArrayHandlerBase 

21 

# --- Shared type aliases -----------------------------------------------------

# One step in a schema path: a dict key (str) or a list index (int).
PathSegment: TypeAlias = str | int
# Absolute location of a subschema inside the root document.
SchemaPath: TypeAlias = tuple[PathSegment, ...]
# Zero-argument callable producing a fresh Comparator instance.
ComparatorFactory: TypeAlias = Callable[[], Comparator]
# Scores two structural-token sets; expected to return a float (0.0 .. 1.0).
SimilarityMetric: TypeAlias = Callable[[frozenset[str], frozenset[str]], float]
# Merges a list of candidate schemas into one definition schema.
MergeStrategy: TypeAlias = Callable[[list[dict], "SchemaReferenceExtractionConfig"], dict]
# Produces a definition name from (1-based index, group, config).
NameFactory: TypeAlias = Callable[[int, "CandidateGroup", "SchemaReferenceExtractionConfig"], str]

# Comparators registered on the internal merge Converter by default.
# The lambda wraps DeleteElement so each run gets a fresh instance targeting
# the "isPseudoArray" marker key.
DEFAULT_COMPARATOR_FACTORIES: tuple[ComparatorFactory, ...] = (
    FormatComparator,
    EnumComparator,
    RequiredComparator,
    EmptyComparator,
    DeleteElement,
    lambda: DeleteElement("isPseudoArray"),
)

# Keys whose children are definition sections (draft 2019-09+ and legacy).
DEFINITION_SECTION_KEYS = {"$defs", "definitions"}
# Keys whose value is a single nested schema to recurse into.
STRUCTURAL_CONTAINER_KEYS = (
    "items",
    "additionalProperties",
    "contains",
    "if",
    "then",
    "else",
    "not",
    "propertyNames",
    "unevaluatedProperties",
    "unevaluatedItems",
)
# Keys whose value is a list of nested schemas to recurse into.
STRUCTURAL_VARIANT_KEYS = ("anyOf", "oneOf", "allOf", "prefixItems")
# Path segments that carry no naming information when deriving a definition
# name from a candidate's path (structural JSON Schema keywords only).
MEANINGFUL_PATH_PARTS_BLACKLIST = {
    "properties",
    "patternProperties",
    "items",
    "anyOf",
    "oneOf",
    "allOf",
    "prefixItems",
    "additionalProperties",
    "contains",
    "if",
    "then",
    "else",
    "not",
    "propertyNames",
    "unevaluatedProperties",
    "unevaluatedItems",
    "$defs",
    "definitions",
}

72 

73 

74def _default_similarity(left: frozenset[str], right: frozenset[str]) -> float: 

75 if not left and not right: 

76 return 1.0 

77 if not left or not right: 

78 return 0.0 

79 intersection_size = len(left & right) 

80 return (2 * intersection_size) / (len(left) + len(right)) 

81 

82 

@dataclass(slots=True, frozen=True)
class SchemaReferenceExtractionConfig:
    """Immutable knobs controlling reference extraction.

    Similarity / size gates decide which subschemas qualify; the merge
    settings feed the internal ``Converter`` run that combines a group into
    one definition; naming settings shape the generated ``$defs`` entries.
    """

    similarity_threshold: float = 0.85
    min_total_keys: int = 3
    min_occurrences: int = 2
    defs_key: str = "$defs"
    ref_prefix: str | None = None
    merge_base_of: Literal["anyOf", "oneOf", "allOf"] = "anyOf"
    merge_pseudo_handler: PseudoArrayHandlerBase | None = None
    merge_comparator_factories: tuple[ComparatorFactory, ...] = field(
        default_factory=lambda: DEFAULT_COMPARATOR_FACTORIES
    )
    similarity_metric: SimilarityMetric = _default_similarity
    merge_strategy: MergeStrategy | None = None
    name_factory: NameFactory | None = None
    allowed_root_types: tuple[str, ...] = ("object", "array")
    preserve_common_keywords: bool = True
    include_root: bool = False
    skip_existing_definitions: bool = True

    def __post_init__(self) -> None:
        """Validate invariants, raising ``ValueError`` at the first violation."""
        # Checks run in declaration order; attribute reads are side-effect free,
        # so evaluating every condition up front is equivalent to sequential ifs.
        validations = (
            (0 < self.similarity_threshold <= 1, "similarity_threshold must be in the (0, 1] range"),
            (self.min_total_keys >= 0, "min_total_keys must be >= 0"),
            (self.min_occurrences >= 2, "min_occurrences must be >= 2"),
            (bool(self.defs_key), "defs_key must not be empty"),
            (bool(self.allowed_root_types), "allowed_root_types must not be empty"),
        )
        for is_valid, message in validations:
            if not is_valid:
                raise ValueError(message)

    @property
    def normalized_ref_prefix(self) -> str:
        """JSON-pointer prefix for emitted ``$ref`` values, without a trailing slash."""
        if self.ref_prefix is None:
            return f"#/{self.defs_key}"
        return self.ref_prefix.rstrip("/")

120 

121 

@dataclass(slots=True)
class SchemaCandidate:
    """A subschema considered for extraction into a shared definition."""

    # Location of the subschema inside the root document.
    path: SchemaPath
    # Deep copy of the subschema taken at collection time.
    schema: dict
    # Sorted tuple of inferred JSON types (see _type_signature).
    type_signature: tuple[str, ...]
    # Structural fingerprint used by the similarity metric.
    tokens: frozenset[str]
    # Count of property/patternProperty tokens; proxy for schema "size".
    total_keys: int

129 

130 

@dataclass(slots=True)
class CandidateGroup:
    """A set of similar candidates plus the definition merged from them."""

    # Candidates whose occurrences will be replaced by one $ref.
    members: list[SchemaCandidate]
    # Definition schema produced by the merge strategy.
    merged_schema: dict
    # total_keys recomputed on the merged schema.
    total_keys: int
    # sum(member.total_keys) - merged total_keys; must be > 0 to keep the group.
    benefit: int
    # Filled in by process() once a unique $defs name is assigned.
    definition_name: str = ""

138 

139 

class SchemaReferencePostprocessor:
    """
    Standalone JSON Schema postprocessor that extracts repeated or highly similar
    structures into shared definitions and replaces occurrences with ``$ref``.

    The postprocessor is intentionally independent from ``Converter`` itself:
    it can be run on any already-built schema. Candidate groups are merged
    through a fresh internal ``Converter`` run so the resulting definition stays
    aligned with the project's normal schema-combination pipeline.
    """

    @classmethod
    def process(cls, schema: dict, config: SchemaReferenceExtractionConfig | None = None) -> dict:
        """Extract shared definitions from ``schema`` and return a new document.

        The input is deep-copied and never mutated. Pipeline: collect
        candidates -> group by similarity -> select non-overlapping groups ->
        install each merged schema under ``config.defs_key`` and replace every
        member occurrence with a ``$ref`` node.

        Raises:
            TypeError: if ``schema`` is not a dict, or an existing
                ``config.defs_key`` entry is present but not a dict.
        """
        if not isinstance(schema, dict):
            raise TypeError("schema must be a dict")

        config = config or SchemaReferenceExtractionConfig()
        prepared = copy.deepcopy(schema)

        candidates = cls._collect_candidates(prepared, config)
        # Fewer candidates than the minimum group size can never form a group.
        if len(candidates) < config.min_occurrences:
            return prepared

        groups = cls._build_groups(candidates, config)
        if not groups:
            return prepared

        selected_groups = cls._select_groups(groups, config)
        if not selected_groups:
            return prepared

        defs = prepared.setdefault(config.defs_key, {})
        if not isinstance(defs, dict):
            raise TypeError(f"{config.defs_key} must be a dict when present")

        for index, group in enumerate(selected_groups, start=1):
            name_factory = config.name_factory or cls._default_name_factory
            definition_name = cls._ensure_unique_definition_name(
                defs, name_factory(index, group, config)
            )
            group.definition_name = definition_name
            defs[definition_name] = group.merged_schema

            ref_node = {"$ref": f"{config.normalized_ref_prefix}/{definition_name}"}
            # Replace deepest paths first so a shallower replacement cannot
            # invalidate the path of a member nested beneath it.
            for member in sorted(group.members, key=lambda item: len(item.path), reverse=True):
                cls._replace_at_path(prepared, member.path, copy.deepcopy(ref_node))

        return prepared

    @classmethod
    def extract(cls, schema: dict, config: SchemaReferenceExtractionConfig | None = None) -> dict:
        """Alias for :meth:`process` kept for a friendlier public name."""
        return cls.process(schema, config)

    @classmethod
    def _collect_candidates(
        cls, schema: dict, config: SchemaReferenceExtractionConfig
    ) -> list[SchemaCandidate]:
        """Walk ``schema`` and return every subschema eligible for extraction.

        Recursion only follows structural keywords (properties, items,
        anyOf/oneOf/allOf/prefixItems, the container keywords, and definition
        sections), so annotation values are never treated as schemas.
        """
        candidates: list[SchemaCandidate] = []

        def walk(node: object, path: SchemaPath, inside_definition_section: bool) -> None:
            if not isinstance(node, dict):
                return

            # A node directly under "$defs"/"definitions" is itself a
            # definition; the flag then propagates to its whole subtree.
            local_inside_defs = inside_definition_section
            if path:
                parent_key = path[-1]
                if parent_key in DEFINITION_SECTION_KEYS:
                    local_inside_defs = True

            # The root (empty path) only qualifies when include_root is set;
            # existing definitions are skipped unless explicitly allowed.
            if (
                (config.include_root or path)
                and (not local_inside_defs or not config.skip_existing_definitions)
                and cls._is_schema_candidate(node, config)
            ):
                tokens = cls._collect_structural_tokens(node)
                total_keys = cls._count_total_keys(tokens)
                if total_keys >= config.min_total_keys:
                    type_signature = cls._type_signature(node)
                    candidates.append(
                        SchemaCandidate(
                            path=path,
                            # Deep copy: later $ref replacement mutates the
                            # document while candidate schemas must stay intact.
                            schema=copy.deepcopy(node),
                            type_signature=type_signature,
                            tokens=frozenset(tokens),
                            total_keys=total_keys,
                        )
                    )

            for key, value in node.items():
                next_inside_defs = local_inside_defs or key in DEFINITION_SECTION_KEYS
                # Map-of-schemas keywords: each value is a child schema keyed
                # by property name / pattern / definition name.
                if key in {
                    "properties",
                    "patternProperties",
                    config.defs_key,
                    "$defs",
                    "definitions",
                }:
                    if isinstance(value, dict):
                        for child_key, child_value in value.items():
                            walk(child_value, path + (key, child_key), next_inside_defs)
                    continue

                # Single-schema keywords.
                if key in STRUCTURAL_CONTAINER_KEYS:
                    walk(value, path + (key,), next_inside_defs)
                    continue

                # List-of-schemas keywords; list indices become int segments.
                if key in STRUCTURAL_VARIANT_KEYS and isinstance(value, list):
                    for index, item in enumerate(value):
                        walk(item, path + (key, index), next_inside_defs)

        walk(schema, (), False)
        return candidates

    @classmethod
    def _is_schema_candidate(cls, schema: dict, config: SchemaReferenceExtractionConfig) -> bool:
        """Return True when ``schema`` may be extracted at all.

        Nodes that are already references are never re-extracted; the node
        must also have an inferable type intersecting ``allowed_root_types``.
        """
        if "$ref" in schema:
            return False

        type_signature = cls._type_signature(schema)
        if not type_signature:
            return False

        return any(item in config.allowed_root_types for item in type_signature)

    @staticmethod
    def _type_signature(schema: dict) -> tuple[str, ...]:
        """Return a sorted tuple of inferred JSON types for ``schema``.

        Falls back, in order: multi-type inference, single-type inference,
        then structural hints (properties -> object, items -> array,
        combinators -> the synthetic "union" type). Empty tuple means no
        type could be determined.
        """
        types = infer_schema_types(schema)
        if types:
            return tuple(sorted(types))

        inferred = infer_schema_type(schema)
        if inferred is not None:
            return (inferred,)

        if isinstance(schema.get("properties"), dict) or isinstance(
            schema.get("patternProperties"), dict
        ):
            return ("object",)
        if "items" in schema or "prefixItems" in schema:
            return ("array",)
        if any(key in schema for key in ("anyOf", "oneOf", "allOf")):
            return ("union",)
        return ()

    @classmethod
    def _collect_structural_tokens(cls, schema: dict) -> set[str]:
        """Build the structural fingerprint of ``schema`` as a set of tokens.

        Each token encodes a path prefix plus one structural fact (type,
        format, enum presence, a property name, items shape, combinator
        arity, container presence). Combinator children share the wildcard
        prefix ``/{key}/*`` so variant order does not affect the fingerprint.
        """
        tokens: set[str] = set()

        def walk(node: object, prefix: str) -> None:
            if not isinstance(node, dict):
                return

            type_signature = cls._type_signature(node)
            if type_signature:
                tokens.add(f"{prefix}|type:{','.join(type_signature)}")

            format_value = node.get("format")
            if isinstance(format_value, str):
                tokens.add(f"{prefix}|format:{format_value}")

            if isinstance(node.get("enum"), list):
                tokens.add(f"{prefix}|enum")

            properties = node.get("properties")
            if isinstance(properties, dict):
                # sorted() keeps token emission deterministic across runs.
                for name, child in sorted(properties.items()):
                    child_prefix = f"{prefix}/properties/{name}"
                    tokens.add(f"{prefix}|prop:{name}")
                    walk(child, child_prefix)

            pattern_properties = node.get("patternProperties")
            if isinstance(pattern_properties, dict):
                for name, child in sorted(pattern_properties.items()):
                    child_prefix = f"{prefix}/patternProperties/{name}"
                    tokens.add(f"{prefix}|pattern:{name}")
                    walk(child, child_prefix)

            items = node.get("items")
            if isinstance(items, dict):
                tokens.add(f"{prefix}|items")
                walk(items, f"{prefix}/items")
            elif isinstance(items, list):
                # Legacy draft tuple-form items: index-position matters.
                for index, child in enumerate(items):
                    tokens.add(f"{prefix}|items:{index}")
                    walk(child, f"{prefix}/items/{index}")

            for key in ("anyOf", "oneOf", "allOf", "prefixItems"):
                variants = node.get(key)
                if not isinstance(variants, list):
                    continue
                tokens.add(f"{prefix}|{key}:{len(variants)}")
                for child in variants:
                    walk(child, f"{prefix}/{key}/*")

            for key in STRUCTURAL_CONTAINER_KEYS:
                child = node.get(key)
                if isinstance(child, dict):
                    tokens.add(f"{prefix}|{key}")
                    walk(child, f"{prefix}/{key}")

        walk(schema, "#")
        return tokens

    @staticmethod
    def _count_total_keys(tokens: Iterable[str]) -> int:
        """Count property/patternProperty tokens — the schema 'size' metric."""
        return sum(1 for token in tokens if "|prop:" in token or "|pattern:" in token)

    @classmethod
    def _build_groups(
        cls, candidates: list[SchemaCandidate], config: SchemaReferenceExtractionConfig
    ) -> list[CandidateGroup]:
        """Greedily cluster candidates into groups of mutually similar schemas.

        Candidates are partitioned by type signature first; within each
        partition, the largest unconsumed candidate seeds a group and absorbs
        every other candidate that is similar to *all* current members
        (complete-linkage). Groups that are too small or whose merge yields
        no key-count benefit are dropped.
        """
        by_type: dict[tuple[str, ...], list[SchemaCandidate]] = {}
        for candidate in candidates:
            by_type.setdefault(candidate.type_signature, []).append(candidate)

        groups: list[CandidateGroup] = []
        for type_signature_candidates in by_type.values():
            # Seed order: biggest schemas first, then shallowest path.
            # NOTE(review): the final `item.path` tie-break compares tuples
            # mixing str and int segments; comparing a str segment against an
            # int one raises TypeError — confirm paths can never tie that way.
            ordered = sorted(
                type_signature_candidates,
                key=lambda item: (-item.total_keys, len(item.path), item.path),
            )
            consumed: set[int] = set()

            for index, seed in enumerate(ordered):
                if index in consumed:
                    continue

                members = [seed]
                consumed.add(index)

                # Score every still-free candidate against the seed.
                scored: list[tuple[float, int, SchemaCandidate]] = []
                for other_index, other in enumerate(ordered):
                    if other_index in consumed:
                        continue
                    score = config.similarity_metric(seed.tokens, other.tokens)
                    if score >= config.similarity_threshold:
                        scored.append((score, other_index, other))

                # Best matches join first so weaker ones face a stricter
                # complete-linkage check below.
                scored.sort(key=lambda item: (-item[0], -item[2].total_keys, item[2].path))

                for _, other_index, other in scored:
                    # Complete linkage: must clear the threshold against
                    # every member already in the group, not just the seed.
                    if all(
                        config.similarity_metric(existing.tokens, other.tokens)
                        >= config.similarity_threshold
                        for existing in members
                    ):
                        members.append(other)
                        consumed.add(other_index)

                if len(members) < config.min_occurrences:
                    continue

                merged_schema = cls._merge_group([member.schema for member in members], config)
                merged_tokens = cls._collect_structural_tokens(merged_schema)
                merged_total_keys = cls._count_total_keys(merged_tokens)
                # Benefit: keys removed from occurrences minus keys spent on
                # the shared definition. Non-positive means extraction loses.
                benefit = sum(member.total_keys for member in members) - merged_total_keys
                if benefit <= 0:
                    continue

                groups.append(
                    CandidateGroup(
                        members=members,
                        merged_schema=merged_schema,
                        total_keys=merged_total_keys,
                        benefit=benefit,
                    )
                )

        # Highest-benefit groups first; _select_groups honors this priority.
        groups.sort(
            key=lambda group: (
                -group.benefit,
                -len(group.members),
                -group.total_keys,
                group.members[0].path,
            )
        )
        return groups

    @classmethod
    def _merge_group(cls, schemas: list[dict], config: SchemaReferenceExtractionConfig) -> dict:
        """Merge member schemas into one definition via the configured strategy."""
        merge_strategy = config.merge_strategy or cls._default_merge_strategy
        return merge_strategy(schemas, config)

    @staticmethod
    def _default_merge_strategy(
        schemas: list[dict], config: SchemaReferenceExtractionConfig
    ) -> dict:
        """Combine ``schemas`` through a fresh internal ``Converter`` run."""
        converter = Converter(
            pseudo_handler=config.merge_pseudo_handler,
            base_of=config.merge_base_of,
        )
        for schema in schemas:
            # Deep copy: the Converter may mutate its inputs.
            converter.add_schema(copy.deepcopy(schema))
        for factory in config.merge_comparator_factories:
            converter.register(factory())
        if config.preserve_common_keywords:
            converter.register(PreserveCommonKeywordsComparator())
        return converter.run()

    @classmethod
    def _select_groups(
        cls, groups: list[CandidateGroup], config: SchemaReferenceExtractionConfig
    ) -> list[CandidateGroup]:
        """Pick a non-overlapping subset of ``groups``, best-benefit first.

        A member whose path is an ancestor or descendant of an already
        claimed path is dropped; if a group shrinks it is re-merged and
        re-evaluated (and discarded if it no longer pays off).
        """
        selected: list[CandidateGroup] = []
        occupied_paths: list[SchemaPath] = []

        for group in groups:
            available_members = [
                member
                for member in group.members
                if not any(cls._paths_overlap(member.path, occupied) for occupied in occupied_paths)
            ]
            if len(available_members) < config.min_occurrences:
                continue

            if len(available_members) != len(group.members):
                # Membership changed: the merged schema and the benefit
                # computed in _build_groups are stale — redo both.
                merged_schema = cls._merge_group(
                    [member.schema for member in available_members], config
                )
                merged_total_keys = cls._count_total_keys(
                    cls._collect_structural_tokens(merged_schema)
                )
                benefit = sum(member.total_keys for member in available_members) - merged_total_keys
                if benefit <= 0:
                    continue
                group = CandidateGroup(
                    members=available_members,
                    merged_schema=merged_schema,
                    total_keys=merged_total_keys,
                    benefit=benefit,
                )

            selected.append(group)
            occupied_paths.extend(member.path for member in group.members)

        return selected

    @staticmethod
    def _paths_overlap(left: SchemaPath, right: SchemaPath) -> bool:
        """True when one path is an ancestor of (or equal to) the other."""
        shortest = min(len(left), len(right))
        return left[:shortest] == right[:shortest]

    @classmethod
    def _replace_at_path(cls, document: dict, path: SchemaPath, new_value: dict) -> None:
        """Replace the node at ``path`` inside ``document`` with ``new_value``.

        An empty path replaces the document root in place (clear + update)
        so existing references to the root dict stay valid.

        Raises:
            TypeError: when a segment's type (int vs str) does not match the
                container actually found at that position.
        """
        if not path:
            document.clear()
            document.update(new_value)
            return

        # Descend to the parent of the target node.
        parent: object = document
        for segment in path[:-1]:
            if isinstance(segment, int):
                if not isinstance(parent, list):
                    raise TypeError("Path points to a list index inside a non-list container")
                parent = parent[segment]
            else:
                if not isinstance(parent, dict):
                    raise TypeError("Path points to a dict key inside a non-dict container")
                parent = parent[segment]

        last = path[-1]
        if isinstance(last, int):
            if not isinstance(parent, list):
                raise TypeError("Path points to a list index inside a non-list container")
            parent[last] = new_value
            return

        if not isinstance(parent, dict):
            raise TypeError("Path points to a dict key inside a non-dict container")
        parent[last] = new_value

    @classmethod
    def _default_name_factory(
        cls,
        index: int,
        group: CandidateGroup,
        config: SchemaReferenceExtractionConfig,
    ) -> str:
        """Derive a definition name from member paths, else ``<Type><index>``.

        Takes the last "meaningful" string segment of each member's path
        (structural keywords and "$"-prefixed keys excluded) and uses the
        most common one when at least two members agree on it.
        """
        meaningful_tail_parts: list[str] = []
        for member in group.members:
            meaningful = [
                str(part)
                for part in member.path
                if isinstance(part, str)
                and part not in MEANINGFUL_PATH_PARTS_BLACKLIST
                and not part.startswith("$")
            ]
            if meaningful:
                meaningful_tail_parts.append(meaningful[-1])

        if meaningful_tail_parts:
            name, count = Counter(meaningful_tail_parts).most_common(1)[0]
            # NOTE(review): with min_occurrences >= 2 the first clause implies
            # the second whenever all members agree — kept as written.
            if count == len(group.members) or count >= 2:
                normalized = cls._normalize_definition_name(name)
                if normalized:
                    return normalized

        # Fallback: first member's primary type plus the 1-based group index.
        root_type = (
            group.members[0].type_signature[0] if group.members[0].type_signature else "schema"
        )
        return f"{cls._normalize_definition_name(root_type)}{index}"

    @staticmethod
    def _normalize_definition_name(value: str) -> str:
        """Turn an arbitrary string into a PascalCase identifier.

        Non-alphanumeric runs split words; a leading digit gets a "Shared"
        prefix; a value with no alphanumerics becomes "SharedSchema".
        """
        parts = [part for part in re.split(r"[^0-9A-Za-z]+", value) if part]
        if not parts:
            return "SharedSchema"
        normalized = "".join(part[:1].upper() + part[1:] for part in parts)
        if normalized[0].isdigit():
            return f"Shared{normalized}"
        return normalized

    @staticmethod
    def _ensure_unique_definition_name(defs: dict, base_name: str) -> str:
        """Return ``base_name``, suffixing 2, 3, ... until unused in ``defs``."""
        candidate = base_name
        index = 2
        while candidate in defs:
            candidate = f"{base_name}{index}"
            index += 1
        return candidate