Skip to content

Commit 055a6f6

Browse files
authored
add root argument to the dataset visualizer to visualize local datasets (#249)
1 parent e54d6ea commit 055a6f6

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

lerobot/scripts/visualize_dataset.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,15 @@ def visualize_dataset(
106106
ws_port: int = 9087,
107107
save: bool = False,
108108
output_dir: Path | None = None,
109+
root: Path | None = None,
109110
) -> Path | None:
110111
if save:
111112
assert (
112113
output_dir is not None
113114
), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
114115

115116
logging.info("Loading dataset")
116-
dataset = LeRobotDataset(repo_id)
117+
dataset = LeRobotDataset(repo_id, root=root)
117118

118119
logging.info("Loading dataloader")
119120
episode_sampler = EpisodeSampler(dataset, episode_index)
@@ -256,6 +257,12 @@ def main():
256257
help="Directory path to write a .rrd file when `--save 1` is set.",
257258
)
258259

260+
parser.add_argument(
261+
"--root",
262+
type=str,
263+
help="Root directory for a dataset stored on a local machine.",
264+
)
265+
259266
args = parser.parse_args()
260267
visualize_dataset(**vars(args))
261268

tests/test_visualize_dataset.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16+
from pathlib import Path
17+
1618
import pytest
1719

1820
from lerobot.scripts.visualize_dataset import visualize_dataset
@@ -31,3 +33,20 @@ def test_visualize_dataset(tmpdir, repo_id):
3133
output_dir=tmpdir,
3234
)
3335
assert rrd_path.exists()
36+
37+
38+
@pytest.mark.parametrize(
39+
"repo_id",
40+
["lerobot/pusht"],
41+
)
42+
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
43+
def test_visualize_local_dataset(tmpdir, repo_id, root):
44+
rrd_path = visualize_dataset(
45+
repo_id,
46+
episode_index=0,
47+
batch_size=32,
48+
save=True,
49+
output_dir=tmpdir,
50+
root=root,
51+
)
52+
assert rrd_path.exists()

0 commit comments

Comments
 (0)