|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
16 | 16 | import collections
|
17 |
| -from typing import Any, Callable, Dict, Sequence |
| 17 | +from typing import Any, Callable, Sequence |
18 | 18 |
|
19 | 19 | import torch
|
20 | 20 | from torchvision.transforms import v2
|
@@ -129,11 +129,12 @@ def _check_input(self, sharpness):
|
129 | 129 |
|
130 | 130 | return float(sharpness[0]), float(sharpness[1])
|
131 | 131 |
|
132 |
| - def _generate_value(self, left: float, right: float) -> float: |
133 |
| - return torch.empty(1).uniform_(left, right).item() |
| 132 | + def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: |
| 133 | + sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item() |
| 134 | + return {"sharpness_factor": sharpness_factor} |
134 | 135 |
|
135 |
| - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: |
136 |
| - sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1]) |
| 136 | + def transform(self, inpt: Any, params: dict[str, Any]) -> Any: |
| 137 | + sharpness_factor = params["sharpness_factor"] |
137 | 138 | return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
138 | 139 |
|
139 | 140 |
|
|
0 commit comments