Skip to content

Commit ceaa179

Browse files
Account for aliases in alpha spec check
1 parent c31993c commit ceaa179

File tree

2 files changed

+188
-65
lines changed

2 files changed

+188
-65
lines changed

src/rapids_pre_commit_hooks/alpha_spec.py

Lines changed: 93 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def is_rapids_cuda_suffixed_package(name):
7979
)
8080

8181

82-
def check_package_spec(linter, args, node):
82+
def check_package_spec(linter, args, anchors, used_anchors, node):
8383
@total_ordering
8484
class SpecPriority:
8585
def __init__(self, spec):
@@ -111,42 +111,51 @@ def create_specifier_string(specifiers):
111111
if req.name in RAPIDS_ALPHA_SPEC_PACKAGES or is_rapids_cuda_suffixed_package(
112112
req.name
113113
):
114-
has_alpha_spec = any(str(s) == ALPHA_SPECIFIER for s in req.specifier)
115-
if args.mode == "development" and not has_alpha_spec:
116-
linter.add_warning(
117-
(node.start_mark.index, node.end_mark.index),
118-
f"add alpha spec for RAPIDS package {req.name}",
119-
).add_replacement(
120-
(node.start_mark.index, node.end_mark.index),
121-
str(
122-
req.name
123-
+ create_specifier_string(
124-
{str(s) for s in req.specifier} | {ALPHA_SPECIFIER}
125-
)
126-
),
127-
)
128-
elif args.mode == "release" and has_alpha_spec:
129-
linter.add_warning(
130-
(node.start_mark.index, node.end_mark.index),
131-
f"remove alpha spec for RAPIDS package {req.name}",
132-
).add_replacement(
133-
(node.start_mark.index, node.end_mark.index),
134-
str(
135-
req.name
136-
+ create_specifier_string(
137-
{str(s) for s in req.specifier} - {ALPHA_SPECIFIER}
138-
)
139-
),
140-
)
141-
142-
143-
def check_packages(linter, args, node):
114+
anchor = (
115+
keys[0]
116+
if (keys := [key for key, value in anchors.items() if value == node])
117+
else None
118+
)
119+
if anchor not in used_anchors:
120+
used_anchors.add(anchor)
121+
has_alpha_spec = any(str(s) == ALPHA_SPECIFIER for s in req.specifier)
122+
if args.mode == "development" and not has_alpha_spec:
123+
linter.add_warning(
124+
(node.start_mark.index, node.end_mark.index),
125+
f"add alpha spec for RAPIDS package {req.name}",
126+
).add_replacement(
127+
(node.start_mark.index, node.end_mark.index),
128+
str(
129+
(f"&{anchor} " if anchor else "")
130+
+ req.name
131+
+ create_specifier_string(
132+
{str(s) for s in req.specifier} | {ALPHA_SPECIFIER},
133+
)
134+
),
135+
)
136+
elif args.mode == "release" and has_alpha_spec:
137+
linter.add_warning(
138+
(node.start_mark.index, node.end_mark.index),
139+
f"remove alpha spec for RAPIDS package {req.name}",
140+
).add_replacement(
141+
(node.start_mark.index, node.end_mark.index),
142+
str(
143+
(f"&{anchor} " if anchor else "")
144+
+ req.name
145+
+ create_specifier_string(
146+
{str(s) for s in req.specifier} - {ALPHA_SPECIFIER},
147+
)
148+
),
149+
)
150+
151+
152+
def check_packages(linter, args, anchors, used_anchors, node):
144153
if node_has_type(node, "seq"):
145154
for package_spec in node.value:
146-
check_package_spec(linter, args, package_spec)
155+
check_package_spec(linter, args, anchors, used_anchors, package_spec)
147156

148157

149-
def check_common(linter, args, node):
158+
def check_common(linter, args, anchors, used_anchors, node):
150159
if node_has_type(node, "seq"):
151160
for dependency_set in node.value:
152161
if node_has_type(dependency_set, "map"):
@@ -155,10 +164,12 @@ def check_common(linter, args, node):
155164
node_has_type(dependency_set_key, "str")
156165
and dependency_set_key.value == "packages"
157166
):
158-
check_packages(linter, args, dependency_set_value)
167+
check_packages(
168+
linter, args, anchors, used_anchors, dependency_set_value
169+
)
159170

160171

161-
def check_matrices(linter, args, node):
172+
def check_matrices(linter, args, anchors, used_anchors, node):
162173
if node_has_type(node, "seq"):
163174
for item in node.value:
164175
if node_has_type(item, "map"):
@@ -167,10 +178,12 @@ def check_matrices(linter, args, node):
167178
node_has_type(matrix_key, "str")
168179
and matrix_key.value == "packages"
169180
):
170-
check_packages(linter, args, matrix_value)
181+
check_packages(
182+
linter, args, anchors, used_anchors, matrix_value
183+
)
171184

172185

173-
def check_specific(linter, args, node):
186+
def check_specific(linter, args, anchors, used_anchors, node):
174187
if node_has_type(node, "seq"):
175188
for matrix_matcher in node.value:
176189
if node_has_type(matrix_matcher, "map"):
@@ -179,30 +192,66 @@ def check_specific(linter, args, node):
179192
node_has_type(matrix_matcher_key, "str")
180193
and matrix_matcher_key.value == "matrices"
181194
):
182-
check_matrices(linter, args, matrix_matcher_value)
195+
check_matrices(
196+
linter, args, anchors, used_anchors, matrix_matcher_value
197+
)
183198

184199

185-
def check_dependencies(linter, args, node):
200+
def check_dependencies(linter, args, anchors, used_anchors, node):
186201
if node_has_type(node, "map"):
187202
for _, dependencies_value in node.value:
188203
if node_has_type(dependencies_value, "map"):
189204
for dependency_key, dependency_value in dependencies_value.value:
190205
if node_has_type(dependency_key, "str"):
191206
if dependency_key.value == "common":
192-
check_common(linter, args, dependency_value)
207+
check_common(
208+
linter, args, anchors, used_anchors, dependency_value
209+
)
193210
elif dependency_key.value == "specific":
194-
check_specific(linter, args, dependency_value)
211+
check_specific(
212+
linter, args, anchors, used_anchors, dependency_value
213+
)
195214

196215

197-
def check_root(linter, args, node):
216+
def check_root(linter, args, anchors, used_anchors, node):
198217
if node_has_type(node, "map"):
199218
for root_key, root_value in node.value:
200219
if node_has_type(root_key, "str") and root_key.value == "dependencies":
201-
check_dependencies(linter, args, root_value)
220+
check_dependencies(linter, args, anchors, used_anchors, root_value)
221+
222+
223+
class AnchorPreservingLoader(yaml.SafeLoader):
224+
"""A SafeLoader that preserves the anchors for later reference. The anchors can
225+
be found in the document_anchors member, which is a list of dictionaries, one
226+
dictionary for each parsed document.
227+
"""
228+
229+
def __init__(self, stream):
230+
super().__init__(stream)
231+
self.document_anchors = []
232+
233+
def compose_document(self):
234+
# Drop the DOCUMENT-START event.
235+
self.get_event()
236+
237+
# Compose the root node.
238+
node = self.compose_node(None, None)
239+
240+
# Drop the DOCUMENT-END event.
241+
self.get_event()
242+
243+
self.document_anchors.append(self.anchors)
244+
self.anchors = {}
245+
return node
202246

203247

204248
def check_alpha_spec(linter, args):
205-
check_root(linter, args, yaml.compose(linter.content))
249+
loader = yaml.SafeLoader(linter.content)
250+
try:
251+
root = loader.get_single_node()
252+
finally:
253+
loader.dispose()
254+
check_root(linter, args, loader.document_anchors[0], set(), root)
206255

207256

208257
def main():

0 commit comments

Comments
 (0)