Cannot finetune pi0_base model #489

Open
pablovalle opened this issue May 21, 2025 · 1 comment

@pablovalle

pablovalle commented May 21, 2025

Hi,

I'm following all the documentation steps to fine-tune pi0_base on the Fractal dataset. I managed to compute the stats (the previous step in the documentation) without any problem, but when I run the train script I get the following error:

Traceback (most recent call last):
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 446, in wrapped_fn_impl
    param_fn(*args, **kwargs)
  File "<@beartype(openpi.training.utils.TrainState) at 0x7f2010728ae0>", line 135, in TrainState
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 295, in __instancecheck__
    return cls.__is_instance_beartype__(obj)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefabc.py", line 112, in __is_instance_beartype__
    return isinstance(obj, cls.__type_beartype__)  # type: ignore[arg-type]
                           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 451, in __type_beartype__
    referent = import_module_attr(
               ^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_util/module/utilmodimport.py", line 295, in import_module_attr
    raise exception_cls(exception_message)
beartype.roar.BeartypeCallHintForwardRefException: Forward reference "ArrayTree" unimportable from module "openpi.training.utils".

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 811, in _get_problem_arg
    fn(*args, **kwargs)
  File "<@beartype(openpi.training.utils.check_single_arg) at 0x7f201072a660>", line 53, in check_single_arg
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 295, in __instancecheck__
    return cls.__is_instance_beartype__(obj)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefabc.py", line 112, in __is_instance_beartype__
    return isinstance(obj, cls.__type_beartype__)  # type: ignore[arg-type]
                           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 451, in __type_beartype__
    referent = import_module_attr(
               ^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_util/module/utilmodimport.py", line 295, in import_module_attr
    raise exception_cls(exception_message)
beartype.roar.BeartypeCallHintForwardRefException: Forward reference "ArrayTree" unimportable from module "openpi.training.utils".

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 451, in wrapped_fn_impl
    argmsg = _get_problem_arg(
             ^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 814, in _get_problem_arg
    raise TypeCheckError(
jaxtyping.TypeCheckError:
The problem arose whilst typechecking parameter 'opt_state'.

Here is my config file:


TrainConfig(
    name="pi0_fractal",
    # Here is an example of loading a pi0 model for LoRA fine-tuning.
    model=pi0.Pi0Config(
        action_dim=7,
        action_horizon=10,
        max_token_len=180,
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m",
    ),
    data=LeRobotFractalDataConfig(
        repo_id="IPEC-COMMUNITY/fractal20220817_data_lerobot",
        base_config=DataConfig(
            local_files_only=False,  # Set to True for local-only datasets.
            prompt_from_task=True,
        ),
    ),
    batch_size=8,
    num_workers=64,
    weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
    # The freeze filter defines which parameters should be frozen during training.
    # We have a convenience function in the model config that returns the default freeze filter
    # for the given model config for LoRA finetuning. Just make sure it matches the model config
    # you chose above.
    freeze_filter=pi0.Pi0Config(
        paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m"
    ).get_freeze_filter(),
    # Turn off EMA for LoRA finetuning.
    ema_decay=None,
),

Any help is welcome!!
Thanks

@t-rakko

t-rakko commented May 28, 2025

Hi, I'm new to OpenPi and ran into the same error while trying to fine-tune the pi0_base model on an HSR dataset.

Since the error seemed to be caused by type-checking "opt_state", I changed the annotation of that field in the TrainState class in openpi/src/openpi/training/utils.py (starting at line 15) from:
opt_state: optax.OptState
to:
opt_state: Any
This change resolved the error, and I was able to start training successfully.
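To illustrate why this workaround helps, here is a minimal pure-Python sketch. It assumes that optax.OptState behaves like a recursive pytree alias in the style of chex.ArrayTree, whose nested "ArrayTree" entries are string forward references, and it uses typing.get_type_hints as a stand-in for beartype's forward-reference resolution. Both are assumptions inferred from the "Forward reference "ArrayTree" unimportable" message; ArrayTreeAlias, StrictTrainState and RelaxedTrainState are hypothetical names, not openpi code.

```python
import dataclasses
from typing import Any, Iterable, Mapping, Union, get_type_hints

# Hypothetical stand-in for a recursive pytree alias such as chex.ArrayTree
# (assumption: optax.OptState is defined along these lines). The string
# "ArrayTree" is a forward reference that must later be looked up by name.
ArrayTreeAlias = Union[float, Iterable["ArrayTree"], Mapping[Any, "ArrayTree"]]


@dataclasses.dataclass
class StrictTrainState:
    # Mirrors the original annotation `opt_state: optax.OptState`.
    opt_state: ArrayTreeAlias


@dataclasses.dataclass
class RelaxedTrainState:
    # Mirrors the workaround `opt_state: Any`.
    opt_state: Any


def annotations_resolve(cls: type) -> bool:
    """Return True if every annotation on `cls` can be fully evaluated."""
    try:
        get_type_hints(cls)
        return True
    except NameError:
        # "ArrayTree" is not defined in this module, so the nested forward
        # reference cannot be resolved, analogous to beartype's error.
        return False


print(annotations_resolve(StrictTrainState))   # False
print(annotations_resolve(RelaxedTrainState))  # True
```

Annotating the field with Any leaves the checker nothing to resolve, which is why the error disappears; the trade-off is that opt_state is no longer type-checked at all.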

However, I'm a bit concerned that removing the type check might cause issues later. I’d appreciate it if someone more experienced with OpenPi could confirm whether this workaround is acceptable or if there's a better fix.
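On the question of a better fix: the error message suggests a failed name lookup rather than a wrong value, since beartype reports that "ArrayTree" cannot be imported from openpi.training.utils. The sketch below, again using typing.get_type_hints as a stand-in for beartype and a hypothetical chex-style alias, shows that the same annotation resolves once the name is visible at lookup time. Whether making ArrayTree available to openpi.training.utils fixes the real error without dropping the check is an untested assumption; TrainStateSketch and resolve are hypothetical names.

```python
from typing import Any, Iterable, Mapping, Optional, Union, get_type_hints

# Hypothetical recursive alias in the style of chex.ArrayTree (assumption).
ArrayTreeAlias = Union[float, Iterable["ArrayTree"], Mapping[Any, "ArrayTree"]]


class TrainStateSketch:
    # Hypothetical stand-in for openpi's TrainState, reduced to one field.
    opt_state: ArrayTreeAlias


def resolve(namespace: Optional[dict] = None) -> bool:
    """Try to evaluate the annotations, optionally with extra names in scope."""
    try:
        get_type_hints(TrainStateSketch, localns=namespace)
        return True
    except NameError:
        return False


# Without "ArrayTree" in scope, the nested forward reference cannot be resolved.
print(resolve())                               # False
# With the name visible at lookup time, the full annotation resolves,
# so in principle the check could be kept instead of being replaced by Any.
print(resolve({"ArrayTree": ArrayTreeAlias}))  # True
```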

Hope this helps.
