@@ -87,7 +87,7 @@ def __init__(
         self.pull_from_repo(allow_patterns="meta/")
         self.info = load_info(self.root)
         self.stats = load_stats(self.root)
-        self.tasks = load_tasks(self.root)
+        self.tasks, self.task_to_task_index = load_tasks(self.root)
         self.episodes = load_episodes(self.root)

     def pull_from_repo(
@@ -202,31 +202,35 @@ def chunks_size(self) -> int:
         """Max number of episodes per chunk."""
         return self.info["chunks_size"]

-    @property
-    def task_to_task_index(self) -> dict:
-        return {task: task_idx for task_idx, task in self.tasks.items()}
-
-    def get_task_index(self, task: str) -> int:
+    def get_task_index(self, task: str) -> int | None:
         """
         Given a task in natural language, returns its task_index if the task already exists in the dataset,
-        otherwise creates a new task_index.
+        otherwise returns None.
+        """
+        return self.task_to_task_index.get(task, None)
+
+    def add_task(self, task: str):
         """
-        task_index = self.task_to_task_index.get(task, None)
-        return task_index if task_index is not None else self.total_tasks
+        Given a task in natural language, add it to the dictionary of tasks.
+        """
+        if task in self.task_to_task_index:
+            raise ValueError(f"The task '{task}' already exists and can't be added twice.")
+
+        task_index = self.info["total_tasks"]
+        self.task_to_task_index[task] = task_index
+        self.tasks[task_index] = task
+        self.info["total_tasks"] += 1
+
+        task_dict = {
+            "task_index": task_index,
+            "task": task,
+        }
+        append_jsonlines(task_dict, self.root / TASKS_PATH)

-    def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
+    def save_episode(self, episode_index: int, episode_length: int, episode_tasks: list[str]) -> None:
         self.info["total_episodes"] += 1
         self.info["total_frames"] += episode_length

-        if task_index not in self.tasks:
-            self.info["total_tasks"] += 1
-            self.tasks[task_index] = task
-            task_dict = {
-                "task_index": task_index,
-                "task": task,
-            }
-            append_jsonlines(task_dict, self.root / TASKS_PATH)
-
         chunk = self.get_episode_chunk(episode_index)
         if chunk >= self.total_chunks:
             self.info["total_chunks"] += 1
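The hunk above splits lookup from registration. A minimal sketch of the resulting check-then-add pattern at a call site, assuming `meta` is the metadata object from this diff and the task strings are purely illustrative:

    for task in ["pick the cube", "place it in the bin", "pick the cube"]:
        # get_task_index only looks up now; it no longer allocates a new index
        if meta.get_task_index(task) is None:
            # add_task assigns the next free index, updates meta.tasks and
            # meta.info["total_tasks"], and appends the record under meta.root / TASKS_PATH
            meta.add_task(task)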
@@ -237,7 +241,7 @@ def save_episode(self, episode_index: int, episode_length: int, task: str, task_

         episode_dict = {
             "episode_index": episode_index,
-            "tasks": [task],
+            "tasks": episode_tasks,
             "length": episode_length,
         }
         self.episodes.append(episode_dict)
@@ -313,7 +317,8 @@ def create(

         features = {**features, **DEFAULT_FEATURES}

-        obj.tasks, obj.stats, obj.episodes = {}, {}, []
+        obj.tasks, obj.task_to_task_index = {}, {}
+        obj.stats, obj.episodes = {}, []
         obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
         if len(obj.video_keys) > 0 and not use_videos:
             raise ValueError()
@@ -691,10 +696,13 @@ def __repr__(self):

     def create_episode_buffer(self, episode_index: int | None = None) -> dict:
         current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
-        return {
-            "size": 0,
-            **{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
-        }
+        ep_buffer = {}
+        # size and task are special cases that are not in self.features
+        ep_buffer["size"] = 0
+        ep_buffer["task"] = []
+        for key in self.features:
+            ep_buffer[key] = current_ep_idx if key == "episode_index" else []
+        return ep_buffer

     def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
         fpath = DEFAULT_IMAGE_PATH.format(
@@ -718,6 +726,8 @@ def add_frame(self, frame: dict) -> None:
         """
         # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
         # check the dtype and shape matches, etc.
+        if "task" not in frame:
+            raise ValueError("The mandatory feature 'task' wasn't found in `frame` dictionary.")

         if self.episode_buffer is None:
             self.episode_buffer = self.create_episode_buffer()
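With this guard, every frame handed to add_frame must carry its natural-language task alongside the regular features. A minimal sketch of a valid frame, where the feature names and shapes ("observation.state", "action", 6 values each) are assumptions for illustration only:

    frame = {
        "observation.state": torch.zeros(6),  # hypothetical state feature
        "action": torch.zeros(6),             # hypothetical action feature
        "task": "pick the cube",              # mandatory task string, buffered per frame
    }
    dataset.add_frame(frame)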
@@ -728,24 +738,31 @@ def add_frame(self, frame: dict) -> None:
         self.episode_buffer["timestamp"].append(timestamp)

         for key in frame:
+            if key == "task":
+                # Note: we associate the task in natural language to its task index during `save_episode`
+                self.episode_buffer["task"].append(frame["task"])
+                continue
+
             if key not in self.features:
-                raise ValueError(key)
+                raise ValueError(
+                    f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
+                )

-            if self.features[key]["dtype"] not in ["image", "video"]:
-                item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
-                self.episode_buffer[key].append(item)
-            elif self.features[key]["dtype"] in ["image", "video"]:
+            if self.features[key]["dtype"] in ["image", "video"]:
                 img_path = self._get_image_file_path(
                     episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
                 )
                 if frame_index == 0:
                     img_path.parent.mkdir(parents=True, exist_ok=True)
                 self._save_image(frame[key], img_path)
                 self.episode_buffer[key].append(str(img_path))
+            else:
+                item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
+                self.episode_buffer[key].append(item)

         self.episode_buffer["size"] += 1

-    def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
+    def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None:
         """
         This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
         disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
@@ -758,7 +775,11 @@ def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict
         if not episode_data:
             episode_buffer = self.episode_buffer

+        # size and task are special cases that won't be added to hf_dataset
         episode_length = episode_buffer.pop("size")
+        tasks = episode_buffer.pop("task")
+        episode_tasks = list(set(tasks))
+
         episode_index = episode_buffer["episode_index"]
         if episode_index != self.meta.total_episodes:
             # TODO(aliberts): Add option to use existing episode_index
@@ -772,21 +793,27 @@ def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict
                 "You must add one or several frames with `add_frame` before calling `add_episode`."
             )

-        task_index = self.meta.get_task_index(task)
-
         if not set(episode_buffer.keys()) == set(self.features):
-            raise ValueError()
+            raise ValueError(
+                f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'"
+            )
+
+        episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
+        episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
+
+        # Add new tasks to the tasks dictionary
+        for task in episode_tasks:
+            task_index = self.meta.get_task_index(task)
+            if task_index is None:
+                self.meta.add_task(task)
+
+        # Given tasks in natural language, find their corresponding task indices
+        episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])

         for key, ft in self.features.items():
-            if key == "index":
-                episode_buffer[key] = np.arange(
-                    self.meta.total_frames, self.meta.total_frames + episode_length
-                )
-            elif key == "episode_index":
-                episode_buffer[key] = np.full((episode_length,), episode_index)
-            elif key == "task_index":
-                episode_buffer[key] = np.full((episode_length,), task_index)
-            elif ft["dtype"] in ["image", "video"]:
+            # index, episode_index, task_index are already processed above, and image and video
+            # are processed separately by storing image path and frame info as meta data
+            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
                 continue
             elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
                 episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
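To make the task-index handling above concrete, a small worked sketch (task strings and indices invented for illustration): the per-frame `tasks` list keeps duplicates, while `episode_tasks` is the deduplicated list used for registration and episode metadata.

    tasks = ["pick", "pick", "place"]            # one entry per frame, duplicates kept
    episode_tasks = list(set(tasks))             # deduplicated, e.g. ["pick", "place"]
    lookup = {"pick": 0, "place": 1}             # stand-in for self.meta.get_task_index
    task_index = np.array([lookup[t] for t in tasks])  # array([0, 0, 1]), one index per frame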
@@ -798,7 +825,7 @@ def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict
         self._wait_image_writer()
         self._save_episode_table(episode_buffer, episode_index)

-        self.meta.save_episode(episode_index, episode_length, task, task_index)
+        self.meta.save_episode(episode_index, episode_length, episode_tasks)

         if encode_videos and len(self.meta.video_keys) > 0:
             video_paths = self.encode_episode_videos(episode_index)
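Taken together, the recording flow implied by this diff might look like the sketch below; the loop variables and the `episode_stream` iterable are assumptions, while the method names come from the changes above.

    # Per-frame task strings go into add_frame; save_episode() no longer takes a task argument.
    for obs, action, task in episode_stream:
        dataset.add_frame({**obs, "action": action, "task": task})

    # Tasks are deduplicated, new ones are registered via meta.add_task, and the
    # episode's task list is written through meta.save_episode.
    dataset.save_episode(encode_videos=True)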