Skip to content

Commit e81c36c

Browse files
authored
Fix dataset version tags (#790)
1 parent ed83cbd commit e81c36c

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

lerobot/common/datasets/lerobot_dataset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16+
import contextlib
1617
import logging
1718
import shutil
1819
from pathlib import Path
@@ -27,6 +28,7 @@
2728
from datasets import concatenate_datasets, load_dataset
2829
from huggingface_hub import HfApi, snapshot_download
2930
from huggingface_hub.constants import REPOCARD_NAME
31+
from huggingface_hub.errors import RevisionNotFoundError
3032

3133
from lerobot.common.constants import HF_LEROBOT_HOME
3234
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
@@ -517,6 +519,7 @@ def push_to_hub(
517519
branch: str | None = None,
518520
tags: list | None = None,
519521
license: str | None = "apache-2.0",
522+
tag_version: bool = True,
520523
push_videos: bool = True,
521524
private: bool = False,
522525
allow_patterns: list[str] | str | None = None,
@@ -562,6 +565,11 @@ def push_to_hub(
562565
)
563566
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
564567

568+
if tag_version:
569+
with contextlib.suppress(RevisionNotFoundError):
570+
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
571+
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
572+
565573
def pull_from_repo(
566574
self,
567575
allow_patterns: list[str] | str | None = None,

lerobot/common/datasets/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import torch
3232
from datasets.table import embed_table_storage
3333
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
34+
from huggingface_hub.errors import RevisionNotFoundError
3435
from PIL import Image as PILImage
3536
from torchvision import transforms
3637

@@ -325,6 +326,19 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
325326
)
326327
hub_versions = get_repo_versions(repo_id)
327328

329+
if not hub_versions:
330+
raise RevisionNotFoundError(
331+
f"""Your dataset must be tagged with a codebase version.
332+
Assuming _version_ is the codebase_version value in the info.json, you can run this:
333+
```python
334+
from huggingface_hub import HfApi
335+
336+
hub_api = HfApi()
337+
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
338+
```
339+
"""
340+
)
341+
328342
if target_version in hub_versions:
329343
return f"v{target_version}"
330344

lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def convert_dataset(
5757
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
5858
write_info(dataset.meta.info, dataset.root)
5959

60-
dataset.push_to_hub(branch=branch, allow_patterns="meta/")
60+
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
6161

6262
# delete old stats.json file
6363
if (dataset.root / STATS_PATH).is_file:

0 commit comments

Comments
 (0)