I'm following all the documentation steps to fine-tune pi0_base on the Fractal dataset. Computing the stats (the previous step in the documentation) worked without any problem, but when I run the train script I get the following error:
Traceback (most recent call last):
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 446, in wrapped_fn_impl
param_fn(*args, **kwargs)
File "<@beartype(openpi.training.utils.TrainState) at 0x7f2010728ae0>", line 135, in TrainState
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 295, in __instancecheck__
return cls.__is_instance_beartype__(obj)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefabc.py", line 112, in __is_instance_beartype__
return isinstance(obj, cls.__type_beartype__) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 451, in __type_beartype__
referent = import_module_attr(
^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_util/module/utilmodimport.py", line 295, in import_module_attr
raise exception_cls(exception_message)
beartype.roar.BeartypeCallHintForwardRefException: Forward reference "ArrayTree" unimportable from module "openpi.training.utils".
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 811, in _get_problem_arg
fn(*args, **kwargs)
File "<@beartype(openpi.training.utils.check_single_arg) at 0x7f201072a660>", line 53, in check_single_arg
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 295, in __instancecheck__
return cls.__is_instance_beartype__(obj)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefabc.py", line 112, in __is_instance_beartype__
return isinstance(obj, cls.__type_beartype__) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 451, in __type_beartype__
referent = import_module_attr(
^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_util/module/utilmodimport.py", line 295, in import_module_attr
raise exception_cls(exception_message)
beartype.roar.BeartypeCallHintForwardRefException: Forward reference "ArrayTree" unimportable from module "openpi.training.utils".
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 451, in wrapped_fn_impl
argmsg = _get_problem_arg(
^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 814, in _get_problem_arg
raise TypeCheckError(
jaxtyping.TypeCheckError:
The problem arose whilst typechecking parameter 'opt_state'.
Here is my config file:
TrainConfig(
    name="pi0_fractal",
    # Here is an example of loading a pi0 model for LoRA fine-tuning.
    model=pi0.Pi0Config(action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m"),
    data=LeRobotFractalDataConfig(
        repo_id="IPEC-COMMUNITY/fractal20220817_data_lerobot",
        base_config=DataConfig(
            local_files_only=False,  # Set to True for local-only datasets.
            prompt_from_task=True,
        ),
    ),
    batch_size=8,
    num_workers=64,
    weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
    # The freeze filter defines which parameters should be frozen during training.
    # We have a convenience function in the model config that returns the default freeze filter
    # for the given model config for LoRA finetuning. Just make sure it matches the model config
    # you chose above.
    freeze_filter=pi0.Pi0Config(
        paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m"
    ).get_freeze_filter(),
    # Turn off EMA for LoRA finetuning.
    ema_decay=None,
),
Any help is welcome!!
Thanks
Hi, I'm new to OpenPi and ran into the same error while trying to fine-tune the pi0_base model on an HSR dataset.
Since the error seemed to be caused by type checking of "opt_state", I tried modifying the type definition of the TrainState class in openpi/src/openpi/training/utils.py (starting at line 15) from:
opt_state: optax.OptState
to:
opt_state: Any
This change resolved the error, and I was able to start training successfully.
However, I'm a bit concerned that removing the type check might cause issues later. I’d appreciate it if someone more experienced with OpenPi could confirm whether this workaround is acceptable or if there's a better fix.
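For anyone wondering why swapping the annotation for Any makes the error go away, here is a minimal stdlib-only sketch of the general pattern. It is an analogy, not openpi's code and not how beartype resolves references internally: StrictState, RelaxedState and resolve_hints are hypothetical names, and typing.get_type_hints stands in for the runtime checker. The point is that a string annotation naming a type that cannot be found in the defining module fails when it is resolved, while Any has nothing to resolve (and therefore checks nothing).

```python
import dataclasses
from typing import Any, get_type_hints


@dataclasses.dataclass
class StrictState:
    """Hypothetical stand-in for a state class with an unresolvable annotation."""

    step: int
    # Forward reference by name; "ArrayTree" is deliberately not defined here.
    opt_state: "ArrayTree"


@dataclasses.dataclass
class RelaxedState:
    """Same fields, but opt_state is annotated as Any (the workaround)."""

    step: int
    opt_state: Any


def resolve_hints(cls):
    """Return the resolved annotations, or the NameError raised while resolving."""
    try:
        return get_type_hints(cls)
    except NameError as exc:
        return exc


strict_result = resolve_hints(StrictState)
relaxed_result = resolve_hints(RelaxedState)

print(type(strict_result).__name__)
print(relaxed_result["opt_state"])
```

The trade-off is the one raised above: with Any, a wrongly shaped or wrongly typed opt_state would no longer be caught at this boundary, so this reads as a workaround rather than a fix for whatever makes the name unresolvable in the first place.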