10
10
11
11
12
12
def to_strict_json_schema (model : type [pydantic .BaseModel ]) -> dict [str , Any ]:
13
- return _ensure_strict_json_schema (model_json_schema (model ), path = ())
13
+ schema = model_json_schema (model )
14
+ return _ensure_strict_json_schema (schema , path = (), root = schema )
14
15
15
16
16
17
def _ensure_strict_json_schema (
17
18
json_schema : object ,
19
+ * ,
18
20
path : tuple [str , ...],
21
+ root : dict [str , object ],
19
22
) -> dict [str , Any ]:
20
23
"""Mutates the given JSON schema to ensure it conforms to the `strict` standard
21
24
that the API expects.
22
25
"""
23
26
if not is_dict (json_schema ):
24
27
raise TypeError (f"Expected { json_schema } to be a dictionary; path={ path } " )
25
28
29
+ defs = json_schema .get ("$defs" )
30
+ if is_dict (defs ):
31
+ for def_name , def_schema in defs .items ():
32
+ _ensure_strict_json_schema (def_schema , path = (* path , "$defs" , def_name ), root = root )
33
+
34
+ definitions = json_schema .get ("definitions" )
35
+ if is_dict (definitions ):
36
+ for definition_name , definition_schema in definitions .items ():
37
+ _ensure_strict_json_schema (definition_schema , path = (* path , "definitions" , definition_name ), root = root )
38
+
26
39
typ = json_schema .get ("type" )
27
40
if typ == "object" and "additionalProperties" not in json_schema :
28
41
json_schema ["additionalProperties" ] = False
@@ -33,48 +46,80 @@ def _ensure_strict_json_schema(
33
46
if is_dict (properties ):
34
47
json_schema ["required" ] = [prop for prop in properties .keys ()]
35
48
json_schema ["properties" ] = {
36
- key : _ensure_strict_json_schema (prop_schema , path = (* path , "properties" , key ))
49
+ key : _ensure_strict_json_schema (prop_schema , path = (* path , "properties" , key ), root = root )
37
50
for key , prop_schema in properties .items ()
38
51
}
39
52
40
53
# arrays
41
54
# { 'type': 'array', 'items': {...} }
42
55
items = json_schema .get ("items" )
43
56
if is_dict (items ):
44
- json_schema ["items" ] = _ensure_strict_json_schema (items , path = (* path , "items" ))
57
+ json_schema ["items" ] = _ensure_strict_json_schema (items , path = (* path , "items" ), root = root )
45
58
46
59
# unions
47
60
any_of = json_schema .get ("anyOf" )
48
61
if is_list (any_of ):
49
62
json_schema ["anyOf" ] = [
50
- _ensure_strict_json_schema (variant , path = (* path , "anyOf" , str (i ))) for i , variant in enumerate (any_of )
63
+ _ensure_strict_json_schema (variant , path = (* path , "anyOf" , str (i )), root = root )
64
+ for i , variant in enumerate (any_of )
51
65
]
52
66
53
67
# intersections
54
68
all_of = json_schema .get ("allOf" )
55
69
if is_list (all_of ):
56
70
if len (all_of ) == 1 :
57
- json_schema .update (_ensure_strict_json_schema (all_of [0 ], path = (* path , "allOf" , "0" )))
71
+ json_schema .update (_ensure_strict_json_schema (all_of [0 ], path = (* path , "allOf" , "0" ), root = root ))
58
72
json_schema .pop ("allOf" )
59
73
else :
60
74
json_schema ["allOf" ] = [
61
- _ensure_strict_json_schema (entry , path = (* path , "allOf" , str (i ))) for i , entry in enumerate (all_of )
75
+ _ensure_strict_json_schema (entry , path = (* path , "allOf" , str (i )), root = root )
76
+ for i , entry in enumerate (all_of )
62
77
]
63
78
64
- defs = json_schema .get ("$defs" )
65
- if is_dict (defs ):
66
- for def_name , def_schema in defs .items ():
67
- _ensure_strict_json_schema (def_schema , path = (* path , "$defs" , def_name ))
79
+ # we can't use `$ref`s if there are also other properties defined, e.g.
80
+ # `{"$ref": "...", "description": "my description"}`
81
+ #
82
+ # so we unravel the ref
83
+ # `{"type": "string", "description": "my description"}`
84
+ ref = json_schema .get ("$ref" )
85
+ if ref and has_more_than_n_keys (json_schema , 1 ):
86
+ assert isinstance (ref , str ), f"Received non-string $ref - { ref } "
68
87
69
- definitions = json_schema .get ("definitions" )
70
- if is_dict (definitions ):
71
- for definition_name , definition_schema in definitions .items ():
72
- _ensure_strict_json_schema (definition_schema , path = (* path , "definitions" , definition_name ))
88
+ resolved = resolve_ref (root = root , ref = ref )
89
+ if not is_dict (resolved ):
90
+ raise ValueError (f"Expected `$ref: { ref } ` to resolved to a dictionary but got { resolved } " )
91
+
92
+ # properties from the json schema take priority over the ones on the `$ref`
93
+ json_schema .update ({** resolved , ** json_schema })
94
+ json_schema .pop ("$ref" )
73
95
74
96
return json_schema
75
97
76
98
99
+ def resolve_ref (* , root : dict [str , object ], ref : str ) -> object :
100
+ if not ref .startswith ("#/" ):
101
+ raise ValueError (f"Unexpected $ref format { ref !r} ; Does not start with #/" )
102
+
103
+ path = ref [2 :].split ("/" )
104
+ resolved = root
105
+ for key in path :
106
+ value = resolved [key ]
107
+ assert is_dict (value ), f"encountered non-dictionary entry while resolving { ref } - { resolved } "
108
+ resolved = value
109
+
110
+ return resolved
111
+
112
+
77
113
def is_dict (obj : object ) -> TypeGuard [dict [str , object ]]:
78
114
# just pretend that we know there are only `str` keys
79
115
# as that check is not worth the performance cost
80
116
return _is_dict (obj )
117
+
118
+
119
+ def has_more_than_n_keys (obj : dict [str , object ], n : int ) -> bool :
120
+ i = 0
121
+ for _ in obj .keys ():
122
+ i += 1
123
+ if i > n :
124
+ return True
125
+ return False
0 commit comments