@@ -28,9 +28,17 @@ def __init__(self, all_entry_points):
28
28
super ().__init__ (msg )
29
29
30
30
31
- def default_select (identifier , all_entry_points ): # pylint: disable=inconsistent-return-statements
31
+ class AmbiguousPluginOverrideError (AmbiguousPluginError ):
32
+ """Raised when a class name produces more than one override for an entry_point."""
33
+
34
+
35
+ def _default_select_no_override (identifier , all_entry_points ): # pylint: disable=inconsistent-return-statements
32
36
"""
33
- Raise an exception when we have ambiguous entry points.
37
+ Selects plugin for the given identifier, raising on error:
38
+
39
+ Raises:
40
+ - PluginMissingError when we don't have an entry point.
41
+ - AmbiguousPluginError when we have ambiguous entry points.
34
42
"""
35
43
36
44
if len (all_entry_points ) == 0 :
@@ -41,6 +49,37 @@ def default_select(identifier, all_entry_points): # pylint: disable=inconsisten
41
49
raise AmbiguousPluginError (all_entry_points )
42
50
43
51
52
+ def default_select (identifier , all_entry_points ):
53
+ """
54
+ Selects plugin for the given identifier with the ability for a Plugin to override
55
+ the default entry point.
56
+
57
+ Raises:
58
+ - PluginMissingError when we don't have an entry point or entry point to override.
59
+ - AmbiguousPluginError when we have ambiguous entry points.
60
+ """
61
+
62
+ # Split entry points into overrides and non-overrides
63
+ overrides = []
64
+ block_entry_points = []
65
+
66
+ for block_entry_point in all_entry_points :
67
+ if block_entry_point .group .endswith ('.overrides' ):
68
+ overrides .append (block_entry_point )
69
+ else :
70
+ block_entry_points .append (block_entry_point )
71
+
72
+ # Get the default entry point
73
+ default_plugin = _default_select_no_override (identifier , block_entry_points )
74
+
75
+ # If we have an unambiguous override, that gets priority. Otherwise, return default.
76
+ if len (overrides ) == 1 :
77
+ return overrides [0 ]
78
+ elif len (overrides ) > 1 :
79
+ raise AmbiguousPluginOverrideError (overrides )
80
+ return default_plugin
81
+
82
+
44
83
class Plugin :
45
84
"""Base class for a system that uses entry_points to load plugins.
46
85
@@ -75,12 +114,20 @@ def _load_class_entry_point(cls, entry_point):
75
114
def load_class (cls , identifier , default = None , select = None ):
76
115
"""Load a single class specified by identifier.
77
116
78
- If `identifier` specifies more than a single class, and `select` is not None,
79
- then call `select` on the list of entry_points. Otherwise, choose
80
- the first one and log a warning.
117
+ By default, this returns the class mapped to `identifier` from entry_points
118
+ matching `{cls.entry_points}.overrides` or `{cls.entry_points}`, in that order.
81
119
82
- If `default` is provided, return it if no entry_point matching
83
- `identifier` is found. Otherwise, will raise a PluginMissingError
120
+ If multiple classes are found for either `{cls.entry_points}.overrides` or
121
+ `{cls.entry_points}`, it will raise an `AmbiguousPluginError`.
122
+
123
+ If no classes are found for `{cls.entry_points}`, it will raise a `PluginMissingError`.
124
+
125
+ Args:
126
+ - identifier: The class to match on.
127
+
128
+ Kwargs:
129
+ - default: A class to return if no entry_point matching `identifier` is found.
130
+ - select: A function to override our default_select functionality.
84
131
85
132
If `select` is provided, it should be a callable of the form::
86
133
@@ -100,7 +147,11 @@ def select(identifier, all_entry_points):
100
147
if select is None :
101
148
select = default_select
102
149
103
- all_entry_points = list (importlib .metadata .entry_points (group = cls .entry_point , name = identifier ))
150
+ all_entry_points = [
151
+ * importlib .metadata .entry_points (group = f'{ cls .entry_point } .overrides' , name = identifier ),
152
+ * importlib .metadata .entry_points (group = cls .entry_point , name = identifier )
153
+ ]
154
+
104
155
for extra_identifier , extra_entry_point in iter (cls .extra_entry_points ):
105
156
if identifier == extra_identifier :
106
157
all_entry_points .append (extra_entry_point )
@@ -146,7 +197,7 @@ def load_classes(cls, fail_silently=True):
146
197
raise
147
198
148
199
@classmethod
149
- def register_temp_plugin (cls , class_ , identifier = None , dist = 'xblock' ):
200
+ def register_temp_plugin (cls , class_ , identifier = None , dist = 'xblock' , group = 'xblock.v1' ):
150
201
"""Decorate a function to run with a temporary plugin available.
151
202
152
203
Use it like this in tests::
@@ -164,6 +215,7 @@ def test_the_thing():
164
215
entry_point = Mock (
165
216
dist = Mock (key = dist ),
166
217
load = Mock (return_value = class_ ),
218
+ group = group
167
219
)
168
220
entry_point .name = identifier
169
221
0 commit comments