Skip to content

Update pre-commits #733

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,24 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.0
rev: v3.19.1
hooks:
- id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.2
rev: v0.9.6
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/python-poetry/poetry
rev: 1.8.0
rev: 1.8.5
hooks:
- id: poetry-check
- id: poetry-lock
args:
- "--check"
- "--no-update"
- repo: https://github.com/gitleaks/gitleaks
rev: v8.21.2
rev: v8.23.3
hooks:
- id: gitleaks
2 changes: 1 addition & 1 deletion lerobot/common/datasets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
)
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(dataset.repo_id_to_index , indent=2)}"
f"{pformat(dataset.repo_id_to_index, indent=2)}"
)

if cfg.dataset.use_imagenet_stats:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
# are too far apart.
direction="nearest",
tolerance=pd.Timedelta(f"{1/fps} seconds"),
tolerance=pd.Timedelta(f"{1 / fps} seconds"),
)
# Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
df = df[df["episode_index"] != -1]
Expand Down
6 changes: 3 additions & 3 deletions lerobot/common/policies/act/modeling_act.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,9 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
latent dimension.
"""
if self.config.use_vae and self.training:
assert (
"action" in batch
), "actions must be provided when using the variational objective in training mode."
assert "action" in batch, (
"actions must be provided when using the variational objective in training mode."
)

batch_size = (
batch["observation.images"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def validate_features(self) -> None:
for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape:
raise ValueError(
f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
)

@property
Expand Down
6 changes: 3 additions & 3 deletions lerobot/common/policies/tdmpc/modeling_tdmpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,9 @@ def _apply_fn(m):

self.apply(_apply_fn)
for m in [self._reward, *self._Qs]:
assert isinstance(
m[-1], nn.Linear
), "Sanity check. The last linear layer needs 0 initialization on weights."
assert isinstance(m[-1], nn.Linear), (
"Sanity check. The last linear layer needs 0 initialization on weights."
)
nn.init.zeros_(m[-1].weight)
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure

Expand Down
2 changes: 1 addition & 1 deletion lerobot/common/policies/vqbet/configuration_vqbet.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def validate_features(self) -> None:
for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape:
raise ValueError(
f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
)

@property
Expand Down
44 changes: 22 additions & 22 deletions lerobot/common/policies/vqbet/vqbet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ def __init__(self, config: VQBeTConfig):
def forward(self, input, targets=None):
device = input.device
b, t, d = input.size()
assert (
t <= self.config.gpt_block_size
), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
assert t <= self.config.gpt_block_size, (
f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
)

# positional encodings that are added to the input embeddings
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
Expand Down Expand Up @@ -273,10 +273,10 @@ def configure_parameters(self):
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
)
assert (
len(param_dict.keys() - union_params) == 0
), "parameters {} were not separated into either decay/no_decay set!".format(
str(param_dict.keys() - union_params),
assert len(param_dict.keys() - union_params) == 0, (
"parameters {} were not separated into either decay/no_decay set!".format(
str(param_dict.keys() - union_params),
)
)

decay = [param_dict[pn] for pn in sorted(decay)]
Expand Down Expand Up @@ -419,9 +419,9 @@ def get_codebook_vector_from_indices(self, indices):
# and the network should be able to reconstruct

if quantize_dim < self.num_quantizers:
assert (
self.quantize_dropout > 0.0
), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
assert self.quantize_dropout > 0.0, (
"quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
)
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)

# get ready for gathering
Expand Down Expand Up @@ -472,9 +472,9 @@ def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=
all_indices = []

if return_loss:
assert not torch.any(
indices == -1
), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
assert not torch.any(indices == -1), (
"some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
)
ce_losses = []

should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
Expand Down Expand Up @@ -887,9 +887,9 @@ def calculate_ce_loss(codes):
# only calculate orthogonal loss for the activated codes for this batch

if self.orthogonal_reg_active_codes_only:
assert not (
is_multiheaded and self.separate_codebook_per_head
), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
assert not (is_multiheaded and self.separate_codebook_per_head), (
"orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
)
unique_code_ids = torch.unique(embed_ind)
codebook = codebook[:, unique_code_ids]

Expand Down Expand Up @@ -999,9 +999,9 @@ def gumbel_sample(
ind = sampling_logits.argmax(dim=dim)
one_hot = F.one_hot(ind, size).type(dtype)

assert not (
reinmax and not straight_through
), "reinmax can only be turned on if using straight through gumbel softmax"
assert not (reinmax and not straight_through), (
"reinmax can only be turned on if using straight through gumbel softmax"
)

if not straight_through or temperature <= 0.0 or not training:
return ind, one_hot
Expand Down Expand Up @@ -1209,9 +1209,9 @@ def __init__(
self.gumbel_sample = gumbel_sample
self.sample_codebook_temp = sample_codebook_temp

assert not (
use_ddp and num_codebooks > 1 and kmeans_init
), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
assert not (use_ddp and num_codebooks > 1 and kmeans_init), (
"kmeans init is not compatible with multiple codebooks in distributed environment for now"
)

self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
Expand Down
2 changes: 1 addition & 1 deletion lerobot/common/robot_devices/control_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f

def log_dt(shortname, dt_val_s):
nonlocal log_items, fps
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
if fps is not None:
actual_fps = 1 / dt_val_s
if actual_fps < fps - 1:
Expand Down
2 changes: 1 addition & 1 deletion lerobot/common/utils/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _deserialize(target, source):
# Check that they have exactly the same set of keys.
if target.keys() != source.keys():
raise ValueError(
f"Dictionary keys do not match.\n" f"Expected: {target.keys()}, got: {source.keys()}"
f"Dictionary keys do not match.\nExpected: {target.keys()}, got: {source.keys()}"
)

# Recursively update each key.
Expand Down
6 changes: 3 additions & 3 deletions lerobot/scripts/visualize_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def visualize_dataset(
output_dir: Path | None = None,
) -> Path | None:
if save:
assert (
output_dir is not None
), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
assert output_dir is not None, (
"Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
)

repo_id = dataset.repo_id

Expand Down
8 changes: 4 additions & 4 deletions tests/scripts/save_dataset_to_safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
# save 2 first frames of first episode
i = dataset.episode_data_index["from"][0].item()
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")

# save 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")

# save 2 last frames of first episode
i = dataset.episode_data_index["to"][0].item()
save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors")
save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors")

# TODO(rcadene): Enable testing on second and last episode
# We currently can't because our test dataset only contains the first episode
Expand Down
6 changes: 3 additions & 3 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,9 @@ def load_and_compare(i):
assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"

for key in new_frame:
assert torch.isclose(
new_frame[key], old_frame[key]
).all(), f"{key=} for index={i} does not contain the same value"
assert torch.isclose(new_frame[key], old_frame[key]).all(), (
f"{key=} for index={i} does not contain the same value"
)

# test2 first frames of first episode
i = dataset.episode_data_index["from"][0].item()
Expand Down
18 changes: 9 additions & 9 deletions tests/test_image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,13 +343,13 @@ def test_save_all_transforms(img_tensor_factory, tmp_path):
# Check if the combined transforms directory exists and contains the right files
combined_transforms_dir = tmp_path / "all"
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
assert any(
combined_transforms_dir.iterdir()
), "No transformed images found in combined transforms directory."
assert any(combined_transforms_dir.iterdir()), (
"No transformed images found in combined transforms directory."
)
for i in range(1, n_examples + 1):
assert (
combined_transforms_dir / f"{i}.png"
).exists(), f"Combined transform image {i}.png was not found."
assert (combined_transforms_dir / f"{i}.png").exists(), (
f"Combined transform image {i}.png was not found."
)


def test_save_each_transform(img_tensor_factory, tmp_path):
Expand All @@ -369,6 +369,6 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
# Check for specific files within each transform directory
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
for file_name in expected_files:
assert (
transform_dir / file_name
).exists(), f"{file_name} was not found in {transform} directory."
assert (transform_dir / file_name).exists(), (
f"{file_name} was not found in {transform} directory."
)
12 changes: 6 additions & 6 deletions tests/test_online_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ def test_fifo():
buffer.add_data(new_data)
n_more_episodes = 2
# Developer sanity check (in case someone changes the global `buffer_capacity`).
assert (
n_episodes + n_more_episodes
) * n_frames_per_episode > buffer_capacity, "Something went wrong with the test code."
assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, (
"Something went wrong with the test code."
)
more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
buffer.add_data(more_new_data)
assert len(buffer) == buffer_capacity, "The buffer should be full."
Expand Down Expand Up @@ -203,9 +203,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
item = buffer[2]
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(
is_pad, torch.tensor([True, False, False, True, True])
), "Padding does not match expected values"
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
"Padding does not match expected values"
)


# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
Expand Down
12 changes: 6 additions & 6 deletions tests/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,12 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
observation_ = deepcopy(observation)
with torch.inference_mode():
action = policy.select_action(observation).cpu().numpy()
assert set(observation) == set(
observation_
), "Observation batch keys are not the same after a forward pass."
assert all(
torch.equal(observation[k], observation_[k]) for k in observation
), "Observation batch values are not the same after a forward pass."
assert set(observation) == set(observation_), (
"Observation batch keys are not the same after a forward pass."
)
assert all(torch.equal(observation[k], observation_[k]) for k in observation), (
"Observation batch values are not the same after a forward pass."
)

# Test step through policy
env.step(action)
Expand Down