Commit c4c2ce0

Update pre-commits (#733)
1 parent 2cb0bf5 commit c4c2ce0

16 files changed, +69 -69 lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
@@ -14,25 +14,25 @@ repos:
       - id: end-of-file-fixer
       - id: trailing-whitespace
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.19.0
+    rev: v3.19.1
     hooks:
       - id: pyupgrade
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.2
+    rev: v0.9.6
     hooks:
       - id: ruff
         args: [--fix]
       - id: ruff-format
   - repo: https://github.com/python-poetry/poetry
-    rev: 1.8.0
+    rev: 1.8.5
     hooks:
       - id: poetry-check
       - id: poetry-lock
         args:
           - "--check"
           - "--no-update"
   - repo: https://github.com/gitleaks/gitleaks
-    rev: v8.21.2
+    rev: v8.23.3
     hooks:
       - id: gitleaks
   - repo: https://github.com/woodruffw/zizmor-pre-commit
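The revision bumps above are the kind of change `pre-commit autoupdate` generates; running `pre-commit run --all-files` afterwards applies the refreshed hooks to the whole repository (both are standard pre-commit commands, not something introduced by this commit). The Python reformatting in the files below is consistent with the ruff-format jump from v0.8.2 to v0.9.6 (assert messages moved into their own parenthesized block, implicit string concatenations joined, operator spacing normalized inside f-strings); that reading is inferred from the diff rather than stated in the commit message.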

lerobot/common/datasets/factory.py

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
         )
         logging.info(
             "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
-            f"{pformat(dataset.repo_id_to_index , indent=2)}"
+            f"{pformat(dataset.repo_id_to_index, indent=2)}"
         )

     if cfg.dataset.use_imagenet_stats:

lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
         # However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
         # are too far appart.
         direction="nearest",
-        tolerance=pd.Timedelta(f"{1/fps} seconds"),
+        tolerance=pd.Timedelta(f"{1 / fps} seconds"),
     )
     # Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
     df = df[df["episode_index"] != -1]
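The `1/fps` to `1 / fps` change above (and the similar ones further down) reflects the formatter now normalizing operator spacing inside f-string replacement fields. A minimal illustration, not taken from the repository:

fps = 30
# Before the update the source read f"{1/fps} seconds"; the formatter now
# rewrites the expression inside the braces with standard operator spacing.
tolerance_str = f"{1 / fps} seconds"
print(tolerance_str)  # prints the tolerance string, e.g. "0.0333... seconds"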

lerobot/common/policies/act/modeling_act.py

Lines changed: 3 additions & 3 deletions
@@ -409,9 +409,9 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
         latent dimension.
         """
         if self.config.use_vae and self.training:
-            assert (
-                "action" in batch
-            ), "actions must be provided when using the variational objective in training mode."
+            assert "action" in batch, (
+                "actions must be provided when using the variational objective in training mode."
+            )

         batch_size = (
             batch["observation.images"]
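This hunk shows the assert layout that repeats through the rest of the commit: the condition stays on the `assert` line and only the long message is wrapped in parentheses. A self-contained sketch of the pattern (illustrative, not repository code):

def check_batch(batch: dict) -> None:
    # Old layout broke the condition across lines:
    # assert (
    #     "action" in batch
    # ), "actions must be provided when using the variational objective in training mode."
    assert "action" in batch, (
        "actions must be provided when using the variational objective in training mode."
    )


if __name__ == "__main__":
    check_batch({"action": [0.0, 1.0]})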

lerobot/common/policies/diffusion/configuration_diffusion.py

Lines changed: 1 addition & 1 deletion
@@ -221,7 +221,7 @@ def validate_features(self) -> None:
         for key, image_ft in self.image_features.items():
             if image_ft.shape != first_image_ft.shape:
                 raise ValueError(
-                    f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
+                    f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                 )

     @property
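Here the change is the formatter joining adjacent (implicitly concatenated) string literals that fit on one line. A small standalone example of the equivalence, with placeholder key names (not the repository's actual feature keys):

key, first_image_key = "observation.images.cam_a", "observation.images.cam_b"  # hypothetical keys
old_msg = f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
new_msg = f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
assert old_msg == new_msg  # adjacent literals concatenate to the identical string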

lerobot/common/policies/tdmpc/modeling_tdmpc.py

Lines changed: 3 additions & 3 deletions
@@ -594,9 +594,9 @@ def _apply_fn(m):

         self.apply(_apply_fn)
         for m in [self._reward, *self._Qs]:
-            assert isinstance(
-                m[-1], nn.Linear
-            ), "Sanity check. The last linear layer needs 0 initialization on weights."
+            assert isinstance(m[-1], nn.Linear), (
+                "Sanity check. The last linear layer needs 0 initialization on weights."
+            )
             nn.init.zeros_(m[-1].weight)
             nn.init.zeros_(m[-1].bias)  # this has already been done, but keep this line here for good measure

lerobot/common/policies/vqbet/configuration_vqbet.py

Lines changed: 1 addition & 1 deletion
@@ -184,7 +184,7 @@ def validate_features(self) -> None:
         for key, image_ft in self.image_features.items():
             if image_ft.shape != first_image_ft.shape:
                 raise ValueError(
-                    f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
+                    f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                 )

     @property

lerobot/common/policies/vqbet/vqbet_utils.py

Lines changed: 22 additions & 22 deletions
@@ -203,9 +203,9 @@ def __init__(self, config: VQBeTConfig):
     def forward(self, input, targets=None):
         device = input.device
         b, t, d = input.size()
-        assert (
-            t <= self.config.gpt_block_size
-        ), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
+        assert t <= self.config.gpt_block_size, (
+            f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
+        )

         # positional encodings that are added to the input embeddings
         pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)
@@ -273,10 +273,10 @@ def configure_parameters(self):
         assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
             str(inter_params)
         )
-        assert (
-            len(param_dict.keys() - union_params) == 0
-        ), "parameters {} were not separated into either decay/no_decay set!".format(
-            str(param_dict.keys() - union_params),
+        assert len(param_dict.keys() - union_params) == 0, (
+            "parameters {} were not separated into either decay/no_decay set!".format(
+                str(param_dict.keys() - union_params),
+            )
         )

         decay = [param_dict[pn] for pn in sorted(decay)]
@@ -419,9 +419,9 @@ def get_codebook_vector_from_indices(self, indices):
         # and the network should be able to reconstruct

         if quantize_dim < self.num_quantizers:
-            assert (
-                self.quantize_dropout > 0.0
-            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+            assert self.quantize_dropout > 0.0, (
+                "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+            )
             indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)

         # get ready for gathering
@@ -472,9 +472,9 @@ def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=
         all_indices = []

         if return_loss:
-            assert not torch.any(
-                indices == -1
-            ), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+            assert not torch.any(indices == -1), (
+                "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+            )
             ce_losses = []

         should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
@@ -887,9 +887,9 @@ def calculate_ce_loss(codes):
         # only calculate orthogonal loss for the activated codes for this batch

         if self.orthogonal_reg_active_codes_only:
-            assert not (
-                is_multiheaded and self.separate_codebook_per_head
-            ), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
+            assert not (is_multiheaded and self.separate_codebook_per_head), (
+                "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
+            )
             unique_code_ids = torch.unique(embed_ind)
             codebook = codebook[:, unique_code_ids]

@@ -999,9 +999,9 @@ def gumbel_sample(
     ind = sampling_logits.argmax(dim=dim)
     one_hot = F.one_hot(ind, size).type(dtype)

-    assert not (
-        reinmax and not straight_through
-    ), "reinmax can only be turned on if using straight through gumbel softmax"
+    assert not (reinmax and not straight_through), (
+        "reinmax can only be turned on if using straight through gumbel softmax"
+    )

     if not straight_through or temperature <= 0.0 or not training:
         return ind, one_hot
@@ -1209,9 +1209,9 @@ def __init__(
         self.gumbel_sample = gumbel_sample
         self.sample_codebook_temp = sample_codebook_temp

-        assert not (
-            use_ddp and num_codebooks > 1 and kmeans_init
-        ), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
+        assert not (use_ddp and num_codebooks > 1 and kmeans_init), (
+            "kmeans init is not compatible with multiple codebooks in distributed environment for now"
+        )

         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop

lerobot/common/robot_devices/control_utils.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f

     def log_dt(shortname, dt_val_s):
         nonlocal log_items, fps
-        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
+        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
         if fps is not None:
             actual_fps = 1 / dt_val_s
             if actual_fps < fps - 1:

lerobot/common/utils/io_utils.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def _deserialize(target, source):
     # Check that they have exactly the same set of keys.
     if target.keys() != source.keys():
         raise ValueError(
-            f"Dictionary keys do not match.\n" f"Expected: {target.keys()}, got: {source.keys()}"
+            f"Dictionary keys do not match.\nExpected: {target.keys()}, got: {source.keys()}"
         )

     # Recursively update each key.

lerobot/scripts/visualize_dataset.py

Lines changed: 3 additions & 3 deletions
@@ -111,9 +111,9 @@ def visualize_dataset(
     output_dir: Path | None = None,
 ) -> Path | None:
     if save:
-        assert (
-            output_dir is not None
-        ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+        assert output_dir is not None, (
+            "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+        )

     repo_id = dataset.repo_id

tests/scripts/save_dataset_to_safetensors.py

Lines changed: 4 additions & 4 deletions
@@ -49,17 +49,17 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
     # save 2 first frames of first episode
     i = dataset.episode_data_index["from"][0].item()
     save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
-    save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
+    save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")

     # save 2 frames at the middle of first episode
     i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
     save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
-    save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
+    save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")

     # save 2 last frames of first episode
     i = dataset.episode_data_index["to"][0].item()
-    save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
-    save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
+    save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors")
+    save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors")

     # TODO(rcadene): Enable testing on second and last episode
     # We currently cant because our test dataset only contains the first episode

tests/test_datasets.py

Lines changed: 3 additions & 3 deletions
@@ -336,9 +336,9 @@ def load_and_compare(i):
         assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"

         for key in new_frame:
-            assert torch.isclose(
-                new_frame[key], old_frame[key]
-            ).all(), f"{key=} for index={i} does not contain the same value"
+            assert torch.isclose(new_frame[key], old_frame[key]).all(), (
+                f"{key=} for index={i} does not contain the same value"
+            )

     # test2 first frames of first episode
     i = dataset.episode_data_index["from"][0].item()

tests/test_image_transforms.py

Lines changed: 9 additions & 9 deletions
@@ -343,13 +343,13 @@ def test_save_all_transforms(img_tensor_factory, tmp_path):
     # Check if the combined transforms directory exists and contains the right files
     combined_transforms_dir = tmp_path / "all"
     assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
-    assert any(
-        combined_transforms_dir.iterdir()
-    ), "No transformed images found in combined transforms directory."
+    assert any(combined_transforms_dir.iterdir()), (
+        "No transformed images found in combined transforms directory."
+    )
     for i in range(1, n_examples + 1):
-        assert (
-            combined_transforms_dir / f"{i}.png"
-        ).exists(), f"Combined transform image {i}.png was not found."
+        assert (combined_transforms_dir / f"{i}.png").exists(), (
+            f"Combined transform image {i}.png was not found."
+        )


 def test_save_each_transform(img_tensor_factory, tmp_path):
@@ -369,6 +369,6 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
     # Check for specific files within each transform directory
     expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
     for file_name in expected_files:
-        assert (
-            transform_dir / file_name
-        ).exists(), f"{file_name} was not found in {transform} directory."
+        assert (transform_dir / file_name).exists(), (
+            f"{file_name} was not found in {transform} directory."
+        )

tests/test_online_buffer.py

Lines changed: 6 additions & 6 deletions
@@ -132,9 +132,9 @@ def test_fifo():
     buffer.add_data(new_data)
     n_more_episodes = 2
     # Developer sanity check (in case someone changes the global `buffer_capacity`).
-    assert (
-        n_episodes + n_more_episodes
-    ) * n_frames_per_episode > buffer_capacity, "Something went wrong with the test code."
+    assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, (
+        "Something went wrong with the test code."
+    )
     more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
     buffer.add_data(more_new_data)
     assert len(buffer) == buffer_capacity, "The buffer should be full."
@@ -203,9 +203,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
     item = buffer[2]
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
-    assert torch.equal(
-        is_pad, torch.tensor([True, False, False, True, True])
-    ), "Padding does not match expected values"
+    assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
+        "Padding does not match expected values"
+    )


 # Arbitrarily set small dataset sizes, making sure to have uneven sizes.

tests/test_policies.py

Lines changed: 6 additions & 6 deletions
@@ -193,12 +193,12 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
         observation_ = deepcopy(observation)
         with torch.inference_mode():
             action = policy.select_action(observation).cpu().numpy()
-        assert set(observation) == set(
-            observation_
-        ), "Observation batch keys are not the same after a forward pass."
-        assert all(
-            torch.equal(observation[k], observation_[k]) for k in observation
-        ), "Observation batch values are not the same after a forward pass."
+        assert set(observation) == set(observation_), (
+            "Observation batch keys are not the same after a forward pass."
+        )
+        assert all(torch.equal(observation[k], observation_[k]) for k in observation), (
+            "Observation batch values are not the same after a forward pass."
+        )

         # Test step through policy
         env.step(action)
