diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 58eca32066..1e3249a157 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,17 +14,17 @@ repos:
       - id: end-of-file-fixer
       - id: trailing-whitespace
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.19.0
+    rev: v3.19.1
     hooks:
       - id: pyupgrade
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.2
+    rev: v0.9.6
     hooks:
       - id: ruff
         args: [--fix]
       - id: ruff-format
   - repo: https://github.com/python-poetry/poetry
-    rev: 1.8.0
+    rev: 1.8.5
     hooks:
       - id: poetry-check
       - id: poetry-lock
@@ -32,6 +32,6 @@ repos:
           - "--check"
           - "--no-update"
   - repo: https://github.com/gitleaks/gitleaks
-    rev: v8.21.2
+    rev: v8.23.3
     hooks:
       - id: gitleaks
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 58ff400e46..95ba76b8c8 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -104,7 +104,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
         )
         logging.info(
             "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
-            f"{pformat(dataset.repo_id_to_index , indent=2)}"
+            f"{pformat(dataset.repo_id_to_index, indent=2)}"
         )
 
     if cfg.dataset.use_imagenet_stats:
diff --git a/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py
index 95f9c00712..4968e0020b 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py
@@ -72,7 +72,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
             # However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
             # are too far appart.
             direction="nearest",
-            tolerance=pd.Timedelta(f"{1/fps} seconds"),
+            tolerance=pd.Timedelta(f"{1 / fps} seconds"),
         )
         # Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
         df = df[df["episode_index"] != -1]
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 9a1036c3a8..f2b16a1eb5 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -409,9 +409,9 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
             latent dimension.
         """
         if self.config.use_vae and self.training:
-            assert (
-                "action" in batch
-            ), "actions must be provided when using the variational objective in training mode."
+            assert "action" in batch, (
+                "actions must be provided when using the variational objective in training mode."
+            )
 
         batch_size = (
             batch["observation.images"]
diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
index 31d5dc8b09..d571e15237 100644
--- a/lerobot/common/policies/diffusion/configuration_diffusion.py
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -221,7 +221,7 @@ def validate_features(self) -> None:
         for key, image_ft in self.image_features.items():
             if image_ft.shape != first_image_ft.shape:
                 raise ValueError(
-                    f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
+                    f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                 )
 
     @property
diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
index c4f90b8dd1..0940f19860 100644
--- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -594,9 +594,9 @@ def _apply_fn(m):
         self.apply(_apply_fn)
 
         for m in [self._reward, *self._Qs]:
-            assert isinstance(
-                m[-1], nn.Linear
-            ), "Sanity check. The last linear layer needs 0 initialization on weights."
+            assert isinstance(m[-1], nn.Linear), (
+                "Sanity check. The last linear layer needs 0 initialization on weights."
+            )
             nn.init.zeros_(m[-1].weight)
             nn.init.zeros_(m[-1].bias)  # this has already been done, but keep this line here for good measure
 
diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py
index 47007e8231..59389d6e75 100644
--- a/lerobot/common/policies/vqbet/configuration_vqbet.py
+++ b/lerobot/common/policies/vqbet/configuration_vqbet.py
@@ -184,7 +184,7 @@ def validate_features(self) -> None:
         for key, image_ft in self.image_features.items():
             if image_ft.shape != first_image_ft.shape:
                 raise ValueError(
-                    f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
+                    f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                 )
 
     @property
diff --git a/lerobot/common/policies/vqbet/vqbet_utils.py b/lerobot/common/policies/vqbet/vqbet_utils.py
index 90a2cfda37..a2bd2df3dc 100644
--- a/lerobot/common/policies/vqbet/vqbet_utils.py
+++ b/lerobot/common/policies/vqbet/vqbet_utils.py
@@ -203,9 +203,9 @@ def __init__(self, config: VQBeTConfig):
     def forward(self, input, targets=None):
         device = input.device
         b, t, d = input.size()
-        assert (
-            t <= self.config.gpt_block_size
-        ), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
+        assert t <= self.config.gpt_block_size, (
+            f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
+        )
 
         # positional encodings that are added to the input embeddings
         pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)
@@ -273,10 +273,10 @@ def configure_parameters(self):
         assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
             str(inter_params)
         )
-        assert (
-            len(param_dict.keys() - union_params) == 0
-        ), "parameters {} were not separated into either decay/no_decay set!".format(
-            str(param_dict.keys() - union_params),
-        )
+        assert len(param_dict.keys() - union_params) == 0, (
+            "parameters {} were not separated into either decay/no_decay set!".format(
+                str(param_dict.keys() - union_params),
+            )
+        )
 
         decay = [param_dict[pn] for pn in sorted(decay)]
@@ -419,9 +419,9 @@ def get_codebook_vector_from_indices(self, indices):
         # and the network should be able to reconstruct
 
         if quantize_dim < self.num_quantizers:
-            assert (
-                self.quantize_dropout > 0.0
-            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+            assert self.quantize_dropout > 0.0, (
+                "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+            )
             indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
 
         # get ready for gathering
@@ -472,9 +472,9 @@ def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=
         all_indices = []
 
         if return_loss:
-            assert not torch.any(
-                indices == -1
-            ), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+            assert not torch.any(indices == -1), (
+                "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+            )
             ce_losses = []
 
         should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
@@ -887,9 +887,9 @@ def calculate_ce_loss(codes):
 
                 # only calculate orthogonal loss for the activated codes for this batch
                 if self.orthogonal_reg_active_codes_only:
-                    assert not (
-                        is_multiheaded and self.separate_codebook_per_head
-                    ), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
+                    assert not (is_multiheaded and self.separate_codebook_per_head), (
+                        "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
+                    )
                     unique_code_ids = torch.unique(embed_ind)
                     codebook = codebook[:, unique_code_ids]
 
@@ -999,9 +999,9 @@ def gumbel_sample(
     ind = sampling_logits.argmax(dim=dim)
     one_hot = F.one_hot(ind, size).type(dtype)
 
-    assert not (
-        reinmax and not straight_through
-    ), "reinmax can only be turned on if using straight through gumbel softmax"
+    assert not (reinmax and not straight_through), (
+        "reinmax can only be turned on if using straight through gumbel softmax"
+    )
 
     if not straight_through or temperature <= 0.0 or not training:
         return ind, one_hot
@@ -1209,9 +1209,9 @@ def __init__(
         self.gumbel_sample = gumbel_sample
         self.sample_codebook_temp = sample_codebook_temp
 
-        assert not (
-            use_ddp and num_codebooks > 1 and kmeans_init
-        ), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
+        assert not (use_ddp and num_codebooks > 1 and kmeans_init), (
+            "kmeans init is not compatible with multiple codebooks in distributed environment for now"
+        )
 
         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py
index 9368b89d64..7264f07813 100644
--- a/lerobot/common/robot_devices/control_utils.py
+++ b/lerobot/common/robot_devices/control_utils.py
@@ -33,7 +33,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
 
     def log_dt(shortname, dt_val_s):
         nonlocal log_items, fps
-        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
+        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
         if fps is not None:
             actual_fps = 1 / dt_val_s
             if actual_fps < fps - 1:
diff --git a/lerobot/common/utils/io_utils.py b/lerobot/common/utils/io_utils.py
index 3fc405f7ef..da0be1c771 100644
--- a/lerobot/common/utils/io_utils.py
+++ b/lerobot/common/utils/io_utils.py
@@ -58,7 +58,7 @@ def _deserialize(target, source):
             # Check that they have exactly the same set of keys.
             if target.keys() != source.keys():
                 raise ValueError(
-                    f"Dictionary keys do not match.\n" f"Expected: {target.keys()}, got: {source.keys()}"
+                    f"Dictionary keys do not match.\nExpected: {target.keys()}, got: {source.keys()}"
                 )
 
             # Recursively update each key.
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index ca17640723..626b0bde0e 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -111,9 +111,9 @@ def visualize_dataset(
     output_dir: Path | None = None,
 ) -> Path | None:
     if save:
-        assert (
-            output_dir is not None
-        ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+        assert output_dir is not None, (
+            "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+        )
 
     repo_id = dataset.repo_id
 
diff --git a/tests/scripts/save_dataset_to_safetensors.py b/tests/scripts/save_dataset_to_safetensors.py
index 84c8f169ac..3b77348cb4 100644
--- a/tests/scripts/save_dataset_to_safetensors.py
+++ b/tests/scripts/save_dataset_to_safetensors.py
@@ -49,17 +49,17 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
     # save 2 first frames of first episode
     i = dataset.episode_data_index["from"][0].item()
     save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
-    save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
+    save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
 
     # save 2 frames at the middle of first episode
     i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
     save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
-    save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
+    save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
 
     # save 2 last frames of first episode
     i = dataset.episode_data_index["to"][0].item()
-    save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
-    save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
+    save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors")
+    save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors")
 
     # TODO(rcadene): Enable testing on second and last episode
     # We currently cant because our test dataset only contains the first episode
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 2945df4109..8664d33e5e 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -336,9 +336,9 @@ def load_and_compare(i):
         assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
 
         for key in new_frame:
-            assert torch.isclose(
-                new_frame[key], old_frame[key]
-            ).all(), f"{key=} for index={i} does not contain the same value"
+            assert torch.isclose(new_frame[key], old_frame[key]).all(), (
+                f"{key=} for index={i} does not contain the same value"
+            )
 
     # test2 first frames of first episode
     i = dataset.episode_data_index["from"][0].item()
diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py
index c118018a73..19bd77df7e 100644
--- a/tests/test_image_transforms.py
+++ b/tests/test_image_transforms.py
@@ -343,13 +343,13 @@ def test_save_all_transforms(img_tensor_factory, tmp_path):
     # Check if the combined transforms directory exists and contains the right files
     combined_transforms_dir = tmp_path / "all"
     assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
-    assert any(
-        combined_transforms_dir.iterdir()
-    ), "No transformed images found in combined transforms directory."
+    assert any(combined_transforms_dir.iterdir()), (
+        "No transformed images found in combined transforms directory."
+    )
     for i in range(1, n_examples + 1):
-        assert (
-            combined_transforms_dir / f"{i}.png"
-        ).exists(), f"Combined transform image {i}.png was not found."
+        assert (combined_transforms_dir / f"{i}.png").exists(), (
+            f"Combined transform image {i}.png was not found."
+        )
 
 
 def test_save_each_transform(img_tensor_factory, tmp_path):
@@ -369,6 +369,6 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
         # Check for specific files within each transform directory
         expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
         for file_name in expected_files:
-            assert (
-                transform_dir / file_name
-            ).exists(), f"{file_name} was not found in {transform} directory."
+            assert (transform_dir / file_name).exists(), (
+                f"{file_name} was not found in {transform} directory."
+            )
diff --git a/tests/test_online_buffer.py b/tests/test_online_buffer.py
index 092cd3d085..db53808d2a 100644
--- a/tests/test_online_buffer.py
+++ b/tests/test_online_buffer.py
@@ -132,9 +132,9 @@ def test_fifo():
     buffer.add_data(new_data)
     n_more_episodes = 2
    # Developer sanity check (in case someone changes the global `buffer_capacity`).
-    assert (
-        n_episodes + n_more_episodes
-    ) * n_frames_per_episode > buffer_capacity, "Something went wrong with the test code."
+    assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, (
+        "Something went wrong with the test code."
+    )
     more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
     buffer.add_data(more_new_data)
     assert len(buffer) == buffer_capacity, "The buffer should be full."
@@ -203,9 +203,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
     item = buffer[2]
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
-    assert torch.equal(
-        is_pad, torch.tensor([True, False, False, True, True])
-    ), "Padding does not match expected values"
+    assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
+        "Padding does not match expected values"
+    )
 
 
 # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 4374157de5..27cf49f883 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -193,12 +193,12 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
         observation_ = deepcopy(observation)
         with torch.inference_mode():
             action = policy.select_action(observation).cpu().numpy()
-        assert set(observation) == set(
-            observation_
-        ), "Observation batch keys are not the same after a forward pass."
-        assert all(
-            torch.equal(observation[k], observation_[k]) for k in observation
-        ), "Observation batch values are not the same after a forward pass."
+        assert set(observation) == set(observation_), (
+            "Observation batch keys are not the same after a forward pass."
+        )
+        assert all(torch.equal(observation[k], observation_[k]) for k in observation), (
+            "Observation batch values are not the same after a forward pass."
+        )
 
         # Test step through policy
         env.step(action)