|
6 | 6 |
|
7 | 7 |
|
8 | 8 | class ArrayFeatureExtractorTranslator(Translator):
|
9 |
| - """Processes an ArgMax node and updates the variables with the output expression.""" |
10 |
| - |
11 |
| - |
12 |
| - def process(self): |
| 9 | + """Processes an ArrayFeatureExtractor node and updates the variables with the output expression. |
| 10 | + |
| 11 | + ArrayFeatureExtractor can be considered the opposite of :class:`ConcatTranslator`, as |
| 12 | + in most cases it will be used to pick one or more features out of a group of columns |
| 13 | + previously concatenated, or to pick a specific feature out of the result of an ArgMax operation. |
| 14 | +
|
| 15 | + The provided indices always refer to the **last** axis of the input tensor. |
| 16 | + If the input is a 2D tensor, the last axis is the column axis. So an index |
| 17 | + of ``0`` would mean the first column. If the input is a 1D tensor instead, the |
| 18 | + last axis is the row axis. So an index of ``0`` would mean the first row. |
| 19 | +
|
| 20 | + This could be confusing because axes are inverted between tensors and mustela column groups. |
| 21 | + In the case of tensors, axis=0 means row=0, while in the case of mustela |
| 22 | + column groups (by virtue of being a group of columns), axis=0 means |
| 23 | + the first column. |
| 24 | +
|
| 25 | + We have to consider that the indices we receive, in the case of column groups, |
| 26 | + are actually column indices, not row indices: as with a 2D tensor, |
| 27 | + the last index is the column index. In the case of a single column, |
| 28 | + instead, the index is the index of a row, as it would be with a 1D tensor. |
| 29 | + """ |
| 30 | + def process(self) -> None: |
| 31 | + """Performs the translation and sets the output variable.""" |
13 | 32 | # https://onnx.ai/onnx/operators/onnx_aionnxml_ArrayFeatureExtractor.html
|
14 | 33 |
|
15 |
| - # Given an array of features, grab only one of them |
16 |
| - # This probably is used to extract a single feature from a list of features |
17 |
| - # Previously made by Concat. |
18 |
| - # Or to pick the right feature from the result of ArgMax |
19 | 34 | data = self._variables.consume(self.inputs[0])
|
20 | 35 | indices = self._variables.consume(self.inputs[1])
|
21 | 36 |
|
22 |
| - data_keys = None |
23 |
| - if isinstance(data, dict): |
24 |
| - # This expects that dictionaries are sorted by insertion order |
25 |
| - # AND that all values of the dictionary are featues with dim_value: 1 |
26 |
| - # TODO: Implement a class for Concatenaed values |
27 |
| - # that implements support based on dimensions |
28 |
| - data_keys = list(data.keys()) |
29 |
| - data = list(data.values()) |
| 37 | + if not isinstance(data, dict): |
| 38 | + # TODO: Implement support for selecting rows from a 1D tensor |
| 39 | + raise NotImplementedError("ArrayFeatureExtractor only supports column groups as inputs") |
| 40 | + |
| 41 | + # This expects that dictionaries are sorted by insertion order |
| 42 | + # AND that all values of the dictionary are columns. |
| 43 | + data_keys = list(data.keys()) |
| 44 | + data = list(data.values()) |
30 | 45 |
|
31 | 46 | if isinstance(indices, (list, tuple)):
|
32 |
| - # We only work with dictionaries of faturename: feature |
33 |
| - # So when we are expected to output a list of features |
34 |
| - # we should output a dictionary of features as they are just sorted. |
| 47 | + if data_keys is None: |
| 48 | + raise ValueError("ArrayFeatureExtractor expects a group of columns as input when receiving a list of indices") |
| 49 | + if len(indices) > len(data_keys): |
| 50 | + raise ValueError("Indices requested are more than the available numer of columns.") |
| 51 | + # Pick only the columns that are in the list of indices. |
35 | 52 | result = {data_keys[i]: data[i] for i in indices}
|
36 |
| - elif isinstance(indices, int): |
37 |
| - result = data[indices] |
38 | 53 | elif isinstance(indices, ibis.expr.types.Column):
|
39 | 54 | # The indices that we need to pick are contained in
|
40 | 55 | # another column of the table.
|
|
0 commit comments