
Commit 8feeede

fix stats per episodes and aggregate stats and casting to tensor
1 parent 0b29fc3 commit 8feeede

5 files changed: +120, -78 lines

lerobot/common/datasets/compute_stats.py (16 additions, 19 deletions)
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import numpy as np
-import torch
 
 from lerobot.common.datasets.utils import load_image_as_numpy
 

@@ -31,7 +30,7 @@ def compute_episode_stats(episode_buffer: dict, features: dict, num_image_sample
             "max": np.max(data, axis=axes_to_reduce),
             "mean": np.mean(data, axis=axes_to_reduce),
             "std": np.std(data, axis=axes_to_reduce),
-            "count": data.shape[0],
+            "count": np.array([data.shape[0]]),
         }
     return stats
 

@@ -71,11 +70,11 @@ def compute_image_stats(image_paths: list[str], num_samples: int | None = None)
         "max": np.max(images, axis=axes_to_reduce, keepdims=True),
         "mean": np.mean(images, axis=axes_to_reduce, keepdims=True),
         "std": np.std(images, axis=axes_to_reduce, keepdims=True),
-        "count": len(images),
     }
     for key in image_stats:  # squeeze batch dim
         image_stats[key] = np.squeeze(image_stats[key], axis=0)
 
+    image_stats["count"] = np.array([len(images)])
     return image_stats
 
 

@@ -95,15 +94,15 @@ def _assert_type_and_shape(stats_list):
         for i in range(len(stats_list)):
             for fkey in stats_list[i]:
                 for k, v in stats_list[i][fkey].items():
-                    if not isinstance(v, torch.Tensor):
+                    if not isinstance(v, np.ndarray):
                         raise ValueError(
-                            f"Stats must be compared of torch tensors, but is {type(v)} instead."
+                            f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
                         )
                     if v.ndim == 0:
                         raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
-                    if k == "count" and v.shape != torch.Size([1]):
+                    if k == "count" and v.shape != (1,):
                         raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
-                    if "image" in k and v.shape != torch.Size([3, 1, 1]):
+                    if "image" in k and v.shape != (3, 1, 1):
                         raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
 
     _assert_type_and_shape(stats_list)
@@ -116,35 +115,33 @@ def _assert_type_and_shape(stats_list):
         stats_with_key = [stats[key] for stats in stats_list if key in stats]
 
         # Aggregate 'min' and 'max' using np.minimum and np.maximum
-        min_, argmin_ = torch.min(torch.stack([s["min"] for s in stats_with_key]), dim=0)
-        max_, argmax_ = torch.max(torch.stack([s["max"] for s in stats_with_key]), dim=0)
-        aggregated_stats[key]["min"] = min_
-        aggregated_stats[key]["max"] = max_
+        aggregated_stats[key]["min"] = np.min(np.stack([s["min"] for s in stats_with_key]), axis=0)
+        aggregated_stats[key]["max"] = np.max(np.stack([s["max"] for s in stats_with_key]), axis=0)
 
         # Extract means, variances (std^2), and counts
-        means = torch.stack([s["mean"] for s in stats_with_key])
-        variances = torch.stack([s["std"] ** 2 for s in stats_with_key])
-        counts = torch.stack([s["count"] for s in stats_with_key])
+        means = np.stack([s["mean"] for s in stats_with_key])
+        variances = np.stack([s["std"] ** 2 for s in stats_with_key])
+        counts = np.stack([s["count"] for s in stats_with_key])
 
         # Compute total counts
-        total_count = counts.sum(dim=0)
+        total_count = counts.sum(axis=0)
 
         # Prepare weighted mean by matching number of dimensions
         while counts.ndim < means.ndim:
-            counts = counts.unsqueeze(-1)
+            counts = np.expand_dims(counts, axis=-1)
 
         # Compute the weighted mean
         weighted_means = means * counts
-        total_mean = weighted_means.sum(dim=0) / total_count
+        total_mean = weighted_means.sum(axis=0) / total_count
 
         # Compute the variance using the parallel algorithm
         delta_means = means - total_mean
         weighted_variances = (variances + delta_means**2) * counts
-        total_variance = weighted_variances.sum(dim=0) / total_count
+        total_variance = weighted_variances.sum(axis=0) / total_count
 
         # Store the aggregated stats
         aggregated_stats[key]["mean"] = total_mean
-        aggregated_stats[key]["std"] = torch.sqrt(total_variance)
+        aggregated_stats[key]["std"] = np.sqrt(total_variance)
         aggregated_stats[key]["count"] = total_count
 
     return aggregated_stats
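
For reference, the "parallel algorithm" in the hunk above is the standard weighted-mean / parallel-variance combination (as in Chan et al.). A minimal numpy sketch, independent of the repo, checking that combining per-episode stats this way reproduces the stats of the concatenated episodes:

import numpy as np

# Two synthetic "episodes" of a 6-dim feature.
rng = np.random.default_rng(0)
a = rng.random((100, 6))
b = rng.random((50, 6))

# Per-episode stats, laid out like compute_episode_stats above.
stats_list = [
    {"mean": x.mean(axis=0), "std": x.std(axis=0), "count": np.array([len(x)])}
    for x in (a, b)
]

# Combine with the same formulas as aggregate_stats; the (2, 1) counts
# broadcast against the (2, 6) means, so no expand_dims is needed here.
means = np.stack([s["mean"] for s in stats_list])
variances = np.stack([s["std"] ** 2 for s in stats_list])
counts = np.stack([s["count"] for s in stats_list])
total_count = counts.sum(axis=0)
total_mean = (means * counts).sum(axis=0) / total_count
delta_means = means - total_mean
total_variance = ((variances + delta_means**2) * counts).sum(axis=0) / total_count

# The aggregate matches the stats of the concatenated data.
full = np.concatenate([a, b])
assert np.allclose(total_mean, full.mean(axis=0))
assert np.allclose(np.sqrt(total_variance), full.std(axis=0))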

lerobot/common/datasets/lerobot_dataset.py (28 additions, 30 deletions)
@@ -34,12 +34,7 @@
 from lerobot.common.datasets.utils import (
     DEFAULT_FEATURES,
     DEFAULT_IMAGE_PATH,
-    EPISODES_PATH,
-    EPISODES_STATS_PATH,
     INFO_PATH,
-    STATS_PATH,
-    TASKS_PATH,
-    append_jsonlines,
     backward_compatible_episodes_stats,
     check_delta_timestamps,
     check_timestamps_sync,
@@ -58,9 +53,13 @@
     load_info,
     load_stats,
     load_tasks,
-    serialize_dict,
+    write_episode,
+    write_episode_stats,
+    write_info,
     write_json,
     write_parquet,
+    write_stats,
+    write_task,
 )
 from lerobot.common.datasets.video_utils import (
     VideoFrame,
@@ -101,7 +100,7 @@ def __init__(
                 "'episodes_stats.jsonl' not found. Use global dataset stats for each episode instead.",
                 stacklevel=1,
             )
-            self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes.keys())
+            self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
 
     def pull_from_repo(
         self,
@@ -241,30 +240,26 @@ def save_episode(
         if task_index not in self.tasks:
             self.info["total_tasks"] += 1
             self.tasks[task_index] = task
-            task_dict = {
-                "task_index": task_index,
-                "task": task,
-            }
-            append_jsonlines(task_dict, self.root / TASKS_PATH)
+            write_task(task_index, task, self.root)
 
         chunk = self.get_episode_chunk(episode_index)
         if chunk >= self.total_chunks:
             self.info["total_chunks"] += 1
 
         self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
         self.info["total_videos"] += len(self.video_keys)
-        write_json(self.info, self.root / INFO_PATH)
+        write_info(self.info, self.root)
 
         episode_dict = {
             "episode_index": episode_index,
             "tasks": [task],
             "length": episode_length,
         }
-        self.episodes.append(episode_dict)
-        append_jsonlines(episode_dict, self.root / EPISODES_PATH)
+        self.episodes[episode_index] = episode_dict
+        write_episode(episode_dict, self.root)
 
-        self.episodes_stats.append(episode_stats)
-        append_jsonlines(episode_stats, self.root / EPISODES_STATS_PATH)
+        self.episodes_stats[episode_index] = episode_stats
+        write_episode_stats(episode_index, episode_stats, self.root)
 
     def write_video_info(self) -> None:
         """
@@ -323,7 +318,7 @@ def create(
         # TODO(aliberts, rcadene): implement sanity check for features
         features = {**features, **DEFAULT_FEATURES}
 
-        obj.tasks, obj.stats, obj.episodes, obj.episodes_stats = {}, {}, [], []
+        obj.tasks, obj.stats, obj.episodes, obj.episodes_stats = {}, {}, {}, {}
         obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
         if len(obj.video_keys) > 0 and not use_videos:
             raise ValueError()
@@ -664,8 +659,7 @@ def __getitem__(self, idx) -> dict:
 
         query_indices = None
         if self.delta_indices is not None:
-            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
-            query_indices, padding = self._get_query_indices(idx, current_ep_idx)
+            query_indices, padding = self._get_query_indices(idx, ep_idx)
             query_result = self._query_hf_dataset(query_indices)
             item = {**item, **padding}
             for key, val in query_result.items():
@@ -807,18 +801,20 @@ def _prepare_episode_buffer(self, episode_buffer: dict, task: str):
             raise ValueError()
 
         for key, ft in self.features.items():
+            # We add an extra dimension to index, frame_index, timestamp, episode_index, task_index
+            # to fit the shape `(1,)` defined in `self.features`
             if key == "index":
                 episode_buffer[key] = np.arange(
                     self.meta.total_frames, self.meta.total_frames + episode_length
-                )
+                )[:, np.newaxis]
+            elif key == "frame_index" or key == "timestamp":
+                episode_buffer[key] = np.array(episode_buffer[key])[:, np.newaxis]
             elif key == "episode_index":
-                episode_buffer[key] = np.full((episode_length,), episode_index)
+                episode_buffer[key] = np.full((episode_length, 1), episode_index)
             elif key == "task_index":
-                episode_buffer[key] = np.full((episode_length,), task_index)
+                episode_buffer[key] = np.full((episode_length, 1), task_index)
             elif ft["dtype"] in ["image", "video"]:
                 continue
-            elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
-                episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
             elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
                 episode_buffer[key] = np.stack(episode_buffer[key])
             else:
@@ -828,7 +824,7 @@ def _prepare_episode_buffer(self, episode_buffer: dict, task: str):
 
     def _compute_episode_stats(self, episode_buffer: dict):
         ep_stats = compute_episode_stats(episode_buffer, self.features)
-        return serialize_dict(ep_stats)
+        return ep_stats
 
     def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
         episode_dict = {key: episode_buffer[key] for key in self.hf_features}
@@ -926,9 +922,8 @@ def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = F
 
         if run_compute_stats:
             self.stop_image_writer()
-            self.meta.stats = aggregate_stats(self.meta.episodes_stats)
-            serialized_stats = serialize_dict(self.meta.stats)
-            write_json(serialized_stats, self.root / STATS_PATH)
+            self.meta.stats = aggregate_stats(list(self.meta.episodes_stats.values()))
+            write_stats(self.meta.stats, self.root)
             self.consolidated = True
         else:
             logging.warning(
@@ -1051,7 +1046,10 @@ def __init__(
 
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
-        self.stats = aggregate_stats(self._datasets)
+        # TODO(rcadene, aliberts): We should not perform this aggregation for datasets
+        # with multiple robots of different ranges. Instead we should have one normalization
+        # per robot.
+        self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
 
     @property
     def repo_id_to_index(self):
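
The _prepare_episode_buffer changes above give each scalar per-frame feature an explicit trailing dimension, so its array shape matches the `(1,)` feature shape declared in `self.features`. A small standalone sketch of the resulting shapes (all values hypothetical):

import numpy as np

episode_length, total_frames, episode_index, task_index = 3, 10, 7, 0

index = np.arange(total_frames, total_frames + episode_length)[:, np.newaxis]
frame_index = np.array([0, 1, 2])[:, np.newaxis]
episode_index_col = np.full((episode_length, 1), episode_index)
task_index_col = np.full((episode_length, 1), task_index)

# Each column is (episode_length, 1) rather than (episode_length,).
assert index.shape == frame_index.shape == episode_index_col.shape == (3, 1)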

lerobot/common/datasets/utils.py (39 additions, 7 deletions)
@@ -163,41 +163,73 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
         writer.write(data)
 
 
+def write_info(info: dict, local_dir: Path):
+    write_json(info, local_dir / INFO_PATH)
+
+
 def load_info(local_dir: Path) -> dict:
     info = load_json(local_dir / INFO_PATH)
     for ft in info["features"].values():
         ft["shape"] = tuple(ft["shape"])
     return info
 
 
+def write_stats(stats: dict, local_dir: Path):
+    serialized_stats = serialize_dict(stats)
+    write_json(serialized_stats, local_dir / STATS_PATH)
+
+
+def cast_stats_to_numpy(stats):
+    stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
+    return unflatten_dict(stats)
+
+
 def load_stats(local_dir: Path) -> dict:
     if not (local_dir / STATS_PATH).exists():
         return None
     stats = load_json(local_dir / STATS_PATH)
-    stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
-    return unflatten_dict(stats)
+    return cast_stats_to_numpy(stats)
+
+
+def write_task(task_index: int, task: dict, local_dir: Path):
+    task_dict = {
+        "task_index": task_index,
+        "task": task,
+    }
+    append_jsonlines(task_dict, local_dir / TASKS_PATH)
 
 
 def load_tasks(local_dir: Path) -> dict:
     tasks = load_jsonlines(local_dir / TASKS_PATH)
     return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
 
 
+def write_episode(episode: dict, local_dir: Path):
+    append_jsonlines(episode, local_dir / EPISODES_PATH)
+
+
 def load_episodes(local_dir: Path) -> dict:
     episodes = load_jsonlines(local_dir / EPISODES_PATH)
     return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
 
 
+def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
+    # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
+    # is a dictionary of stats and not an integer.
+    episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
+    append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
+
+
 def load_episodes_stats(local_dir: Path) -> dict:
-    episodes_tasks = load_jsonlines(local_dir / EPISODES_STATS_PATH)
+    episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
     return {
-        item["episode_index"]: item["stats"]
-        for item in sorted(episodes_tasks, key=lambda x: x["episode_index"])
+        item["episode_index"]: cast_stats_to_numpy(item["stats"])
+        for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
     }
 
 
 def backward_compatible_episodes_stats(stats, episodes: list[int]):
-    return {ep_idx: {"episode_index": ep_idx, "stats": stats} for ep_idx in episodes}
+    return {ep_idx: stats for ep_idx in episodes}
 
 
 def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
@@ -381,7 +413,7 @@ def create_empty_dataset_info(
 
 
 def get_episode_data_index(
-    episode_dicts: list[dict], episodes: list[int] | None = None
+    episode_dicts: dict[dict], episodes: list[int] | None = None
 ) -> dict[str, torch.Tensor]:
     episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
     if episodes is not None:
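
Taken together, the new writer helpers make per-episode stats round-trip through `episodes_stats.jsonl`: arrays are serialized on write (via serialize_dict) and cast back to numpy on load. A minimal sketch, assuming the helpers above are importable and `root` is a hypothetical scratch directory:

import numpy as np
from pathlib import Path

from lerobot.common.datasets.utils import load_episodes_stats, write_episode_stats

root = Path("/tmp/scratch_dataset")  # hypothetical location
episode_stats = {
    "action": {
        "min": np.zeros(6),
        "max": np.ones(6),
        "mean": np.full(6, 0.5),
        "std": np.full(6, 0.1),
        "count": np.array([50]),
    }
}
write_episode_stats(0, episode_stats, root)  # appends one serialized record

loaded = load_episodes_stats(root)
assert isinstance(loaded[0]["action"]["mean"], np.ndarray)  # cast_stats_to_numpy on load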

lerobot/common/policies/normalize.py (23 additions, 11 deletions)
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import numpy as np
 import torch
 from torch import Tensor, nn
 
@@ -22,7 +23,7 @@
 def create_stats_buffers(
     features: dict[str, PolicyFeature],
     norm_map: dict[str, NormalizationMode],
-    stats: dict[str, dict[str, Tensor]] | None = None,
+    stats: dict[str, dict[str, torch.Tensor]] | None = None,
 ) -> dict[str, dict[str, nn.ParameterDict]]:
     """
     Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
@@ -78,16 +79,27 @@ def create_stats_buffers(
         )
 
         if stats:
-            # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
-            # tensors anywhere (for example, when we use the same stats for normalization and
-            # unnormalization). See the logic here
-            # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
-            if norm_mode is NormalizationMode.MEAN_STD:
-                buffer["mean"].data = stats[key]["mean"].clone()
-                buffer["std"].data = stats[key]["std"].clone()
-            elif norm_mode is NormalizationMode.MIN_MAX:
-                buffer["min"].data = stats[key]["min"].clone()
-                buffer["max"].data = stats[key]["max"].clone()
+            if isinstance(stats[key]["mean"], np.ndarray):
+                if norm_mode is NormalizationMode.MEAN_STD:
+                    buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
+                    buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
+                elif norm_mode is NormalizationMode.MIN_MAX:
+                    buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
+                    buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
+            elif isinstance(stats[key]["mean"], torch.Tensor):
+                # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
+                # tensors anywhere (for example, when we use the same stats for normalization and
+                # unnormalization). See the logic here
+                # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
+                if norm_mode is NormalizationMode.MEAN_STD:
+                    buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
+                    buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
+                elif norm_mode is NormalizationMode.MIN_MAX:
+                    buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
+                    buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
+            else:
+                type_ = type(stats[key]["mean"])
+                raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
 
         stats_buffers[key] = buffer
     return stats_buffers
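
The explicit float32 cast in both branches matters: np.array on Python floats yields float64, and stats loaded from JSON would otherwise propagate float64 into buffers created as float32 parameters. A minimal sketch of the conversion used in the numpy branch:

import numpy as np
import torch

mean_np = np.array([0.1, 0.2, 0.3])  # float64 by default
mean_t = torch.from_numpy(mean_np).to(dtype=torch.float32)
assert mean_t.dtype == torch.float32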
