Coverage for genschema / postprocessing / schema_references.py: 79%
342 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-25 09:44 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-25 09:44 +0000
1from __future__ import annotations
3import copy
4import re
5from collections import Counter
6from dataclasses import dataclass, field
7from typing import Callable, Iterable, Literal, TypeAlias
9from ..comparators import (
10 DeleteElement,
11 EmptyComparator,
12 EnumComparator,
13 FormatComparator,
14 PreserveCommonKeywordsComparator,
15 RequiredComparator,
16)
17from ..comparators.template import Comparator
18from ..comparators.type import infer_schema_type, infer_schema_types
19from ..pipeline import Converter
20from ..pseudo_arrays import PseudoArrayHandlerBase
# One step of a path into a schema document: a dict key or a list index.
PathSegment: TypeAlias = str | int
# Full path from the document root to a nested node; the empty tuple is the root.
SchemaPath: TypeAlias = tuple[PathSegment, ...]
# Zero-argument callable that builds a comparator for the merge pipeline.
ComparatorFactory: TypeAlias = Callable[[], Comparator]
# Scores two structural token sets; compared against ``similarity_threshold``.
SimilarityMetric: TypeAlias = Callable[[frozenset[str], frozenset[str]], float]
# Merges the schemas of one candidate group into a single shared definition.
MergeStrategy: TypeAlias = Callable[[list[dict], "SchemaReferenceExtractionConfig"], dict]
# Produces a definition name from (1-based group index, group, config).
NameFactory: TypeAlias = Callable[[int, "CandidateGroup", "SchemaReferenceExtractionConfig"], str]

# Default value of ``SchemaReferenceExtractionConfig.merge_comparator_factories``;
# each entry is called with no arguments and the result is registered on the
# converter used to merge a group.
DEFAULT_COMPARATOR_FACTORIES: tuple[ComparatorFactory, ...] = (
    FormatComparator,
    EnumComparator,
    RequiredComparator,
    EmptyComparator,
    DeleteElement,
    lambda: DeleteElement("isPseudoArray"),
)

# Keys whose value is a name -> schema map of definitions.
DEFINITION_SECTION_KEYS = {"$defs", "definitions"}
# Keys whose value is walked as a single nested schema.
STRUCTURAL_CONTAINER_KEYS = (
    "items",
    "additionalProperties",
    "contains",
    "if",
    "then",
    "else",
    "not",
    "propertyNames",
    "unevaluatedProperties",
    "unevaluatedItems",
)
# Keys whose value is walked as a list of nested schemas.
STRUCTURAL_VARIANT_KEYS = ("anyOf", "oneOf", "allOf", "prefixItems")
# Path parts that the default name factory ignores when deriving a definition
# name from member paths.
MEANINGFUL_PATH_PARTS_BLACKLIST = {
    "properties",
    "patternProperties",
    "items",
    "anyOf",
    "oneOf",
    "allOf",
    "prefixItems",
    "additionalProperties",
    "contains",
    "if",
    "then",
    "else",
    "not",
    "propertyNames",
    "unevaluatedProperties",
    "unevaluatedItems",
    "$defs",
    "definitions",
}
74def _default_similarity(left: frozenset[str], right: frozenset[str]) -> float:
75 if not left and not right:
76 return 1.0
77 if not left or not right:
78 return 0.0
79 intersection_size = len(left & right)
80 return (2 * intersection_size) / (len(left) + len(right))
@dataclass(slots=True, frozen=True)
class SchemaReferenceExtractionConfig:
    """Immutable settings for ``SchemaReferencePostprocessor``.

    Values are validated once in ``__post_init__``; an invalid value raises
    ``ValueError``.
    """

    # Minimum similarity score for two candidates to join the same group.
    similarity_threshold: float = 0.85
    # Minimum number of property/pattern tokens a candidate must have.
    min_total_keys: int = 3
    # Minimum number of members a group needs to be extracted.
    min_occurrences: int = 2
    # Key of the definitions section that receives extracted schemas.
    defs_key: str = "$defs"
    # Explicit ``$ref`` prefix; ``None`` means ``#/<defs_key>``.
    ref_prefix: str | None = None
    # Passed to the merging converter as ``base_of``.
    merge_base_of: Literal["anyOf", "oneOf", "allOf"] = "anyOf"
    # Passed to the merging converter as ``pseudo_handler``.
    merge_pseudo_handler: PseudoArrayHandlerBase | None = None
    # Factories of the comparators registered on the merging converter.
    merge_comparator_factories: tuple[ComparatorFactory, ...] = field(
        default_factory=lambda: DEFAULT_COMPARATOR_FACTORIES
    )
    # Scoring function applied to the structural token sets of two candidates.
    similarity_metric: SimilarityMetric = _default_similarity
    # Optional replacement for the default converter-based merge.
    merge_strategy: MergeStrategy | None = None
    # Optional replacement for the default definition naming.
    name_factory: NameFactory | None = None
    # A node is a candidate only if one of its inferred types is listed here.
    allowed_root_types: tuple[str, ...] = ("object", "array")
    # When true, ``PreserveCommonKeywordsComparator`` is registered for merges.
    preserve_common_keywords: bool = True
    # When true, the document root itself may become a candidate.
    include_root: bool = False
    # When true, nodes inside definition sections are not candidates.
    skip_existing_definitions: bool = True

    def __post_init__(self) -> None:
        """Reject out-of-range settings with ``ValueError``."""
        threshold = self.similarity_threshold
        if not 0 < threshold <= 1:
            raise ValueError("similarity_threshold must be in the (0, 1] range")

        if self.min_total_keys < 0:
            raise ValueError("min_total_keys must be >= 0")

        if self.min_occurrences < 2:
            raise ValueError("min_occurrences must be >= 2")

        if not self.defs_key:
            raise ValueError("defs_key must not be empty")

        if not self.allowed_root_types:
            raise ValueError("allowed_root_types must not be empty")

    @property
    def normalized_ref_prefix(self) -> str:
        """Prefix placed before the definition name in generated ``$ref`` values.

        Uses ``ref_prefix`` with trailing slashes removed when it is set,
        otherwise ``#/<defs_key>``.
        """
        explicit_prefix = self.ref_prefix
        if explicit_prefix is None:
            return f"#/{self.defs_key}"
        return explicit_prefix.rstrip("/")
@dataclass(slots=True)
class SchemaCandidate:
    """A schema node that may be extracted into a shared definition."""

    # Location of the node inside the processed document.
    path: SchemaPath
    # Deep copy of the node taken when the candidate was collected.
    schema: dict
    # Inferred types of the node; candidates are grouped by this value.
    type_signature: tuple[str, ...]
    # Structural tokens of the node, used for similarity scoring.
    tokens: frozenset[str]
    # Number of property / pattern-property tokens in ``tokens``.
    total_keys: int
@dataclass(slots=True)
class CandidateGroup:
    """A set of similar candidates together with their merged definition."""

    # Candidates that will be replaced by a reference to the definition.
    members: list[SchemaCandidate]
    # Result of merging the schemas of all members.
    merged_schema: dict
    # Number of property / pattern-property tokens in ``merged_schema``.
    total_keys: int
    # Sum of the members' ``total_keys`` minus ``total_keys`` of the merge.
    benefit: int
    # Name under which the definition is stored; filled in during processing.
    definition_name: str = ""
140class SchemaReferencePostprocessor:
141 """
142 Standalone JSON Schema postprocessor that extracts repeated or highly similar
143 structures into shared definitions and replaces occurrences with ``$ref``.
145 The postprocessor is intentionally independent from ``Converter`` itself:
146 it can be run on any already-built schema. Candidate groups are merged
147 through a fresh internal ``Converter`` run so the resulting definition stays
148 aligned with the project's normal schema-combination pipeline.
149 """
151 @classmethod
152 def process(cls, schema: dict, config: SchemaReferenceExtractionConfig | None = None) -> dict:
153 if not isinstance(schema, dict):
154 raise TypeError("schema must be a dict")
156 config = config or SchemaReferenceExtractionConfig()
157 prepared = copy.deepcopy(schema)
159 candidates = cls._collect_candidates(prepared, config)
160 if len(candidates) < config.min_occurrences:
161 return prepared
163 groups = cls._build_groups(candidates, config)
164 if not groups:
165 return prepared
167 selected_groups = cls._select_groups(groups, config)
168 if not selected_groups:
169 return prepared
171 defs = prepared.setdefault(config.defs_key, {})
172 if not isinstance(defs, dict):
173 raise TypeError(f"{config.defs_key} must be a dict when present")
175 for index, group in enumerate(selected_groups, start=1):
176 name_factory = config.name_factory or cls._default_name_factory
177 definition_name = cls._ensure_unique_definition_name(
178 defs, name_factory(index, group, config)
179 )
180 group.definition_name = definition_name
181 defs[definition_name] = group.merged_schema
183 ref_node = {"$ref": f"{config.normalized_ref_prefix}/{definition_name}"}
184 for member in sorted(group.members, key=lambda item: len(item.path), reverse=True):
185 cls._replace_at_path(prepared, member.path, copy.deepcopy(ref_node))
187 return prepared
    @classmethod
    def extract(cls, schema: dict, config: SchemaReferenceExtractionConfig | None = None) -> dict:
        """Alias of ``process``: forwards both arguments and returns its result."""
        return cls.process(schema, config)
193 @classmethod
194 def _collect_candidates(
195 cls, schema: dict, config: SchemaReferenceExtractionConfig
196 ) -> list[SchemaCandidate]:
197 candidates: list[SchemaCandidate] = []
199 def walk(node: object, path: SchemaPath, inside_definition_section: bool) -> None:
200 if not isinstance(node, dict):
201 return
203 local_inside_defs = inside_definition_section
204 if path:
205 parent_key = path[-1]
206 if parent_key in DEFINITION_SECTION_KEYS:
207 local_inside_defs = True
209 if (
210 (config.include_root or path)
211 and (not local_inside_defs or not config.skip_existing_definitions)
212 and cls._is_schema_candidate(node, config)
213 ):
214 tokens = cls._collect_structural_tokens(node)
215 total_keys = cls._count_total_keys(tokens)
216 if total_keys >= config.min_total_keys:
217 type_signature = cls._type_signature(node)
218 candidates.append(
219 SchemaCandidate(
220 path=path,
221 schema=copy.deepcopy(node),
222 type_signature=type_signature,
223 tokens=frozenset(tokens),
224 total_keys=total_keys,
225 )
226 )
228 for key, value in node.items():
229 next_inside_defs = local_inside_defs or key in DEFINITION_SECTION_KEYS
230 if key in {
231 "properties",
232 "patternProperties",
233 config.defs_key,
234 "$defs",
235 "definitions",
236 }:
237 if isinstance(value, dict):
238 for child_key, child_value in value.items():
239 walk(child_value, path + (key, child_key), next_inside_defs)
240 continue
242 if key in STRUCTURAL_CONTAINER_KEYS:
243 walk(value, path + (key,), next_inside_defs)
244 continue
246 if key in STRUCTURAL_VARIANT_KEYS and isinstance(value, list):
247 for index, item in enumerate(value):
248 walk(item, path + (key, index), next_inside_defs)
250 walk(schema, (), False)
251 return candidates
253 @classmethod
254 def _is_schema_candidate(cls, schema: dict, config: SchemaReferenceExtractionConfig) -> bool:
255 if "$ref" in schema:
256 return False
258 type_signature = cls._type_signature(schema)
259 if not type_signature:
260 return False
262 return any(item in config.allowed_root_types for item in type_signature)
264 @staticmethod
265 def _type_signature(schema: dict) -> tuple[str, ...]:
266 types = infer_schema_types(schema)
267 if types:
268 return tuple(sorted(types))
270 inferred = infer_schema_type(schema)
271 if inferred is not None:
272 return (inferred,)
274 if isinstance(schema.get("properties"), dict) or isinstance(
275 schema.get("patternProperties"), dict
276 ):
277 return ("object",)
278 if "items" in schema or "prefixItems" in schema:
279 return ("array",)
280 if any(key in schema for key in ("anyOf", "oneOf", "allOf")):
281 return ("union",)
282 return ()
284 @classmethod
285 def _collect_structural_tokens(cls, schema: dict) -> set[str]:
286 tokens: set[str] = set()
288 def walk(node: object, prefix: str) -> None:
289 if not isinstance(node, dict):
290 return
292 type_signature = cls._type_signature(node)
293 if type_signature:
294 tokens.add(f"{prefix}|type:{','.join(type_signature)}")
296 format_value = node.get("format")
297 if isinstance(format_value, str):
298 tokens.add(f"{prefix}|format:{format_value}")
300 if isinstance(node.get("enum"), list):
301 tokens.add(f"{prefix}|enum")
303 properties = node.get("properties")
304 if isinstance(properties, dict):
305 for name, child in sorted(properties.items()):
306 child_prefix = f"{prefix}/properties/{name}"
307 tokens.add(f"{prefix}|prop:{name}")
308 walk(child, child_prefix)
310 pattern_properties = node.get("patternProperties")
311 if isinstance(pattern_properties, dict):
312 for name, child in sorted(pattern_properties.items()):
313 child_prefix = f"{prefix}/patternProperties/{name}"
314 tokens.add(f"{prefix}|pattern:{name}")
315 walk(child, child_prefix)
317 items = node.get("items")
318 if isinstance(items, dict):
319 tokens.add(f"{prefix}|items")
320 walk(items, f"{prefix}/items")
321 elif isinstance(items, list):
322 for index, child in enumerate(items):
323 tokens.add(f"{prefix}|items:{index}")
324 walk(child, f"{prefix}/items/{index}")
326 for key in ("anyOf", "oneOf", "allOf", "prefixItems"):
327 variants = node.get(key)
328 if not isinstance(variants, list):
329 continue
330 tokens.add(f"{prefix}|{key}:{len(variants)}")
331 for child in variants:
332 walk(child, f"{prefix}/{key}/*")
334 for key in STRUCTURAL_CONTAINER_KEYS:
335 child = node.get(key)
336 if isinstance(child, dict):
337 tokens.add(f"{prefix}|{key}")
338 walk(child, f"{prefix}/{key}")
340 walk(schema, "#")
341 return tokens
343 @staticmethod
344 def _count_total_keys(tokens: Iterable[str]) -> int:
345 return sum(1 for token in tokens if "|prop:" in token or "|pattern:" in token)
    @classmethod
    def _build_groups(
        cls, candidates: list[SchemaCandidate], config: SchemaReferenceExtractionConfig
    ) -> list[CandidateGroup]:
        """Greedily cluster similar candidates and merge each cluster.

        Candidates are only compared with others of the same type signature.
        Within one signature the candidates are visited from the largest
        ``total_keys`` down; each not-yet-consumed candidate becomes the seed
        of a new cluster.

        Returns:
            Groups with at least ``config.min_occurrences`` members and a
            positive benefit, ordered by benefit, member count and size
            (all descending), then by the path of the first member.
        """
        by_type: dict[tuple[str, ...], list[SchemaCandidate]] = {}
        for candidate in candidates:
            by_type.setdefault(candidate.type_signature, []).append(candidate)

        groups: list[CandidateGroup] = []
        for type_signature_candidates in by_type.values():
            # Largest candidates first; shorter and then smaller paths break ties.
            ordered = sorted(
                type_signature_candidates,
                key=lambda item: (-item.total_keys, len(item.path), item.path),
            )
            # Indexes into ``ordered`` already assigned to a cluster. They stay
            # consumed even when that cluster is rejected further down.
            consumed: set[int] = set()

            for index, seed in enumerate(ordered):
                if index in consumed:
                    continue

                members = [seed]
                consumed.add(index)

                # Every free candidate that is similar enough to the seed.
                scored: list[tuple[float, int, SchemaCandidate]] = []
                for other_index, other in enumerate(ordered):
                    if other_index in consumed:
                        continue
                    score = config.similarity_metric(seed.tokens, other.tokens)
                    if score >= config.similarity_threshold:
                        scored.append((score, other_index, other))

                # Best matches join first.
                scored.sort(key=lambda item: (-item[0], -item[2].total_keys, item[2].path))

                for _, other_index, other in scored:
                    # A newcomer must be similar to every current member,
                    # not only to the seed.
                    if all(
                        config.similarity_metric(existing.tokens, other.tokens)
                        >= config.similarity_threshold
                        for existing in members
                    ):
                        members.append(other)
                        consumed.add(other_index)

                if len(members) < config.min_occurrences:
                    continue

                merged_schema = cls._merge_group([member.schema for member in members], config)
                merged_tokens = cls._collect_structural_tokens(merged_schema)
                merged_total_keys = cls._count_total_keys(merged_tokens)
                # Keys removed from the document minus keys added by the definition.
                benefit = sum(member.total_keys for member in members) - merged_total_keys
                if benefit <= 0:
                    continue

                groups.append(
                    CandidateGroup(
                        members=members,
                        merged_schema=merged_schema,
                        total_keys=merged_total_keys,
                        benefit=benefit,
                    )
                )

        groups.sort(
            key=lambda group: (
                -group.benefit,
                -len(group.members),
                -group.total_keys,
                group.members[0].path,
            )
        )
        return groups
418 @classmethod
419 def _merge_group(cls, schemas: list[dict], config: SchemaReferenceExtractionConfig) -> dict:
420 merge_strategy = config.merge_strategy or cls._default_merge_strategy
421 return merge_strategy(schemas, config)
423 @staticmethod
424 def _default_merge_strategy(
425 schemas: list[dict], config: SchemaReferenceExtractionConfig
426 ) -> dict:
427 converter = Converter(
428 pseudo_handler=config.merge_pseudo_handler,
429 base_of=config.merge_base_of,
430 )
431 for schema in schemas:
432 converter.add_schema(copy.deepcopy(schema))
433 for factory in config.merge_comparator_factories:
434 converter.register(factory())
435 if config.preserve_common_keywords:
436 converter.register(PreserveCommonKeywordsComparator())
437 return converter.run()
439 @classmethod
440 def _select_groups(
441 cls, groups: list[CandidateGroup], config: SchemaReferenceExtractionConfig
442 ) -> list[CandidateGroup]:
443 selected: list[CandidateGroup] = []
444 occupied_paths: list[SchemaPath] = []
446 for group in groups:
447 available_members = [
448 member
449 for member in group.members
450 if not any(cls._paths_overlap(member.path, occupied) for occupied in occupied_paths)
451 ]
452 if len(available_members) < config.min_occurrences:
453 continue
455 if len(available_members) != len(group.members):
456 merged_schema = cls._merge_group(
457 [member.schema for member in available_members], config
458 )
459 merged_total_keys = cls._count_total_keys(
460 cls._collect_structural_tokens(merged_schema)
461 )
462 benefit = sum(member.total_keys for member in available_members) - merged_total_keys
463 if benefit <= 0:
464 continue
465 group = CandidateGroup(
466 members=available_members,
467 merged_schema=merged_schema,
468 total_keys=merged_total_keys,
469 benefit=benefit,
470 )
472 selected.append(group)
473 occupied_paths.extend(member.path for member in group.members)
475 return selected
477 @staticmethod
478 def _paths_overlap(left: SchemaPath, right: SchemaPath) -> bool:
479 shortest = min(len(left), len(right))
480 return left[:shortest] == right[:shortest]
482 @classmethod
483 def _replace_at_path(cls, document: dict, path: SchemaPath, new_value: dict) -> None:
484 if not path:
485 document.clear()
486 document.update(new_value)
487 return
489 parent: object = document
490 for segment in path[:-1]:
491 if isinstance(segment, int):
492 if not isinstance(parent, list):
493 raise TypeError("Path points to a list index inside a non-list container")
494 parent = parent[segment]
495 else:
496 if not isinstance(parent, dict):
497 raise TypeError("Path points to a dict key inside a non-dict container")
498 parent = parent[segment]
500 last = path[-1]
501 if isinstance(last, int):
502 if not isinstance(parent, list):
503 raise TypeError("Path points to a list index inside a non-list container")
504 parent[last] = new_value
505 return
507 if not isinstance(parent, dict):
508 raise TypeError("Path points to a dict key inside a non-dict container")
509 parent[last] = new_value
511 @classmethod
512 def _default_name_factory(
513 cls,
514 index: int,
515 group: CandidateGroup,
516 config: SchemaReferenceExtractionConfig,
517 ) -> str:
518 meaningful_tail_parts: list[str] = []
519 for member in group.members:
520 meaningful = [
521 str(part)
522 for part in member.path
523 if isinstance(part, str)
524 and part not in MEANINGFUL_PATH_PARTS_BLACKLIST
525 and not part.startswith("$")
526 ]
527 if meaningful:
528 meaningful_tail_parts.append(meaningful[-1])
530 if meaningful_tail_parts:
531 name, count = Counter(meaningful_tail_parts).most_common(1)[0]
532 if count == len(group.members) or count >= 2:
533 normalized = cls._normalize_definition_name(name)
534 if normalized:
535 return normalized
537 root_type = (
538 group.members[0].type_signature[0] if group.members[0].type_signature else "schema"
539 )
540 return f"{cls._normalize_definition_name(root_type)}{index}"
542 @staticmethod
543 def _normalize_definition_name(value: str) -> str:
544 parts = [part for part in re.split(r"[^0-9A-Za-z]+", value) if part]
545 if not parts:
546 return "SharedSchema"
547 normalized = "".join(part[:1].upper() + part[1:] for part in parts)
548 if normalized[0].isdigit():
549 return f"Shared{normalized}"
550 return normalized
552 @staticmethod
553 def _ensure_unique_definition_name(defs: dict, base_name: str) -> str:
554 candidate = base_name
555 index = 2
556 while candidate in defs:
557 candidate = f"{base_name}{index}"
558 index += 1
559 return candidate