Skip to content

Commit 1ee1acf

Browse files
authored
Comply with torchvision 0.21 custom transforms (#665)
1 parent c4d912a commit 1ee1acf

File tree

3 files changed

+2761
-2598
lines changed

3 files changed

+2761
-2598
lines changed

lerobot/common/datasets/transforms.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
import collections
17-
from typing import Any, Callable, Dict, Sequence
17+
from typing import Any, Callable, Sequence
1818

1919
import torch
2020
from torchvision.transforms import v2
@@ -129,11 +129,12 @@ def _check_input(self, sharpness):
129129

130130
return float(sharpness[0]), float(sharpness[1])
131131

132-
def _generate_value(self, left: float, right: float) -> float:
133-
return torch.empty(1).uniform_(left, right).item()
132+
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
133+
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
134+
return {"sharpness_factor": sharpness_factor}
134135

135-
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
136-
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
136+
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
137+
sharpness_factor = params["sharpness_factor"]
137138
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
138139

139140

0 commit comments

Comments
 (0)