Error when running MPO script run_mpo.py: ValueError: Graph parent item 0 is not a Tensor; #342

Open
@CharlieMou

This is the terminal output when I run python run_mpo.py:

(acme) root@autodl-container-fa5b45ae8c-68d88031:~/acme/examples/baselines/rl_continuous# python run_mpo.py
/root/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/python/__init__.py:57: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
  if (distutils.version.LooseVersion(tf.__version__) <
I0606 21:16:20.734650 140006953583808 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I0606 21:16:20.824139 140006953583808 xla_bridge.py:355] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Host CUDA Interpreter
I0606 21:16:20.824507 140006953583808 xla_bridge.py:355] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0606 21:16:20.824591 140006953583808 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
I0606 21:16:21.697998 139961080727296 courier_utils.py:120] Binding: run
I0606 21:16:21.699319 139961080727296 lp_utils.py:87] StepsLimiter: Starting with max_steps = 1000000 (actor_steps)
I0606 21:16:21.700335 139936644708096 node.py:61] Reverb client connecting to: localhost:34247
I0606 21:16:21.702090 139961089120000 savers.py:166] Attempting to restore checkpoint: None
/root/miniconda3/envs/acme/lib/python3.8/site-packages/gym/envs/registration.py:592: UserWarning: WARN: The environment HalfCheetah-v2 is out of date. You should consider upgrading to version `v4`.
  logger.warn(
I0606 21:16:21.715310 139936057513728 node.py:61] Reverb client connecting to: localhost:34247
I0606 21:16:21.716707 139916252006144 node.py:61] Reverb client connecting to: localhost:34247
I0606 21:16:21.717766 139915740313344 node.py:61] Reverb client connecting to: localhost:34247
I0606 21:16:21.724305 139915715135232 node.py:61] Reverb client connecting to: localhost:34247
I0606 21:16:21.725041 139961089120000 courier_utils.py:120] Binding: get_counts
I0606 21:16:21.861637 139961089120000 courier_utils.py:120] Binding: get_directory
I0606 21:16:21.876847 139961089120000 courier_utils.py:120] Binding: get_steps_key
I0606 21:16:21.892327 139961089120000 courier_utils.py:120] Binding: increment
I0606 21:16:21.914711 139961089120000 courier_utils.py:120] Binding: restore
I0606 21:16:21.921361 139961089120000 courier_utils.py:120] Binding: save
I0606 21:16:21.924899 139961089120000 savers.py:155] Saving checkpoint: /root/acme/20250606-211617/checkpoints/counter
/root/miniconda3/envs/acme/lib/python3.8/site-packages/Cython/Distutils/old_build_ext.py:15: DeprecationWarning: dep_util is Deprecated. Use functions from setuptools instead.
  from distutils.dep_util import newer, newer_group
I0606 21:16:22.730576 139961080727296 lp_utils.py:95] StepsLimiter: Reached 0 recorded steps
/root/miniconda3/envs/acme/lib/python3.8/site-packages/gym/envs/mujoco/mujoco_env.py:237: UserWarning: WARN: This version of the mujoco environments depends on the mujoco-py bindings, which are no longer maintained and may stop working. Please upgrade to the v4 versions of the environments (which depend on the mujoco python bindings instead), unless you are trying to precisely replicate previous works).
  logger.warn(
/root/miniconda3/envs/acme/lib/python3.8/site-packages/gym/core.py:329: DeprecationWarning: WARN: Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.
  deprecation(
/root/miniconda3/envs/acme/lib/python3.8/site-packages/gym/wrappers/step_api_compatibility.py:39: DeprecationWarning: WARN: Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.
  deprecation(
I0606 21:16:23.305870 139916252006144 csv.py:76] Logging to /root/acme/20250606-211617/logs/actor/logs.csv
I0606 21:16:23.306252 139916252006144 courier_utils.py:120] Binding: run
I0606 21:16:23.469376 139916252006144 courier_utils.py:120] Binding: run_episode
I0606 21:16:23.472962 139915715135232 csv.py:76] Logging to /root/acme/20250606-211617/logs/actor/logs.csv
I0606 21:16:23.473167 139915715135232 courier_utils.py:120] Binding: run
I0606 21:16:23.473235 139915715135232 courier_utils.py:120] Binding: run_episode
I0606 21:16:23.479296 139936057513728 csv.py:76] Logging to /root/acme/20250606-211617/logs/actor/logs.csv
I0606 21:16:23.479869 139936057513728 courier_utils.py:120] Binding: run
I0606 21:16:23.479939 139936057513728 courier_utils.py:120] Binding: run_episode
I0606 21:16:23.481003 139915740313344 csv.py:76] Logging to /root/acme/20250606-211617/logs/actor/logs.csv
I0606 21:16:23.481796 139915740313344 courier_utils.py:120] Binding: run
I0606 21:16:23.481884 139915740313344 courier_utils.py:120] Binding: run_episode
I0606 21:16:23.584140 139915706742528 csv.py:76] Logging to /root/acme/20250606-211617/logs/evaluator/logs.csv
I0606 21:16:23.589572 139915706742528 courier_utils.py:120] Binding: run
I0606 21:16:23.600646 139915706742528 courier_utils.py:120] Binding: run_episode
I0606 21:16:23.625711 139937162417920 builder.py:248] Creating off-policy replay buffer with replay fraction 1 of batch 256
[reverb/cc/platform/tfrecord_checkpointer.cc:162]  Initializing TFRecordCheckpointer in /tmp/tmph0y27ie2.
[reverb/cc/platform/tfrecord_checkpointer.cc:552] Loading latest checkpoint from /tmp/tmph0y27ie2
[reverb/cc/platform/default/server.cc:71] Started replay server on port 34247
I0606 21:16:25.226038 139936644708096 csv.py:76] Logging to /root/acme/20250606-211617/logs/learner/logs.csv
I0606 21:16:25.226341 139936644708096 learning.py:126] Learner process id: 0. Devices passed: [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
I0606 21:16:25.226398 139936644708096 learning.py:128] Learner process id: 0. Local devices from JAX API: [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (866) so Table priority_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (866) so Table priority_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (866) so Table priority_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (866) so Table priority_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (866) so Table priority_table is accessed directly without gRPC.
Node ThreadWorker(thread=<Thread(learner, stopped daemon 139936644708096)>, future=<Future at 0x7f55302f27c0 state=finished raised ValueError>) crashed:
Traceback (most recent call last):
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/launch/worker_manager.py", line 474, in _check_workers
    worker.future.result()
  File "/root/miniconda3/envs/acme/lib/python3.8/concurrent/futures/_base.py", line 437, in result
    return self.__get_result()
  File "/root/miniconda3/envs/acme/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/launch/worker_manager.py", line 250, in run_inner
    future.set_result(f())
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/nodes/python/node.py", line 75, in _construct_function
    return functools.partial(self._function, *args, **kwargs)()
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/nodes/courier/node.py", line 113, in run
    instance = self._construct_instance()  # pytype:disable=wrong-arg-types
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/nodes/python/node.py", line 180, in _construct_instance
    self._instance = self._constructor(*args, **kwargs)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/acme/jax/experiments/make_distributed_experiment.py", line 185, in build_learner
    learner = experiment.builder.make_learner(random_key, networks, iterator,
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/builder.py", line 165, in make_learner
    learner = learning.MPOLearner(
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/learning.py", line 223, in __init__
    network_params, _ = mpo_networks.init_params(
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/networks.py", line 118, in init_params
    params_policy_head = networks.policy_head.init(rng_keys[2], embeddings)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/transform.py", line 114, in init_fn
    params, state = f.init(*args, **kwargs)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/transform.py", line 338, in init_fn
    f(*args, **kwargs)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/networks.py", line 217, in policy_fn
    return networks_lib.MultivariateNormalDiagHead(
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/root/miniconda3/envs/acme/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/acme/jax/networks/distributional.py", line 297, in __call__
    return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/decorator.py", line 235, in fun
    return caller(func, *(extras + args), **kw)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 342, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_diag.py", line 235, in __init__
    scale = diag_cls(
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator_diag.py", line 191, in __init__
    self._set_graph_parents([self._diag])
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator.py", line 1178, in _set_graph_parents
    raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
ValueError: Graph parent item 0 is not a Tensor; [[0.500001 0.500001 0.500001 0.500001 0.500001 0.500001]].
I0606 21:16:27.433487 140006953583808 savers.py:220] Caught SIGTERM: forcing a checkpoint save.
I0606 21:16:27.433681 140006953583808 savers.py:155] Saving checkpoint: /root/acme/20250606-211617/checkpoints/counter
Error in atexit._run_exitfuncs:
Traceback (most recent call last):
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/launch/worker_manager.py", line 428, in wait
    raise failure
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/launch/worker_manager.py", line 474, in _check_workers
    worker.future.result()
  File "/root/miniconda3/envs/acme/lib/python3.8/concurrent/futures/_base.py", line 437, in result
    return self.__get_result()
  File "/root/miniconda3/envs/acme/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/launch/worker_manager.py", line 250, in run_inner
    future.set_result(f())
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/nodes/python/node.py", line 75, in _construct_function
    return functools.partial(self._function, *args, **kwargs)()
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/nodes/courier/node.py", line 113, in run
    instance = self._construct_instance()  # pytype:disable=wrong-arg-types
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/launchpad/nodes/python/node.py", line 180, in _construct_instance
    self._instance = self._constructor(*args, **kwargs)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/acme/jax/experiments/make_distributed_experiment.py", line 185, in build_learner
    learner = experiment.builder.make_learner(random_key, networks, iterator,
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/builder.py", line 165, in make_learner
    learner = learning.MPOLearner(
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/learning.py", line 223, in __init__
    network_params, _ = mpo_networks.init_params(
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/networks.py", line 118, in init_params
    params_policy_head = networks.policy_head.init(rng_keys[2], embeddings)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/transform.py", line 114, in init_fn
    params, state = f.init(*args, **kwargs)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/transform.py", line 338, in init_fn
    f(*args, **kwargs)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/acme/agents/jax/mpo/networks.py", line 217, in policy_fn
    return networks_lib.MultivariateNormalDiagHead(
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/root/miniconda3/envs/acme/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/acme/jax/networks/distributional.py", line 297, in __call__
    return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/decorator.py", line 235, in fun
    return caller(func, *(extras + args), **kw)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 342, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_diag.py", line 235, in __init__
    scale = diag_cls(
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator_diag.py", line 191, in __init__
    self._set_graph_parents([self._diag])
  File "/root/miniconda3/envs/acme/lib/python3.8/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator.py", line 1178, in _set_graph_parents
    raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
ValueError: Graph parent item 0 is not a Tensor; [[0.500001 0.500001 0.500001 0.500001 0.500001 0.500001]].
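
The last Acme frame in the traceback is the `tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)` call in distributional.py, so the failure may be reproducible without Acme, Launchpad, or Reverb. Below is a minimal sketch of such a reproduction, assuming the same environment. The function name `build_policy_distribution`, the zero-valued `loc`, and the (1, 6) shape (taken from the array printed in the error) are illustrative assumptions, not code from Acme.

```python
def build_policy_distribution(num_actions=6, scale_value=0.500001):
    """Build the same kind of distribution as the failing call in the traceback."""
    # Imported lazily so this file can be loaded even where JAX/TFP are missing.
    import jax.numpy as jnp
    from tensorflow_probability.substrates import jax as tfp

    tfd = tfp.distributions
    # Hypothetical inputs: the shape and scale value mirror the array in the error.
    loc = jnp.zeros((1, num_actions))
    scale = jnp.full((1, num_actions), scale_value)
    return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
```

If calling `build_policy_distribution()` in this environment raises the same ValueError, the problem would appear to sit between JAX and TensorFlow Probability rather than in run_mpo.py itself.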

Here is my environment setup:

  • Python 3.8.20
  • CUDA 11.8

Package                      Version
---------------------------- --------------------
absl-py                      1.4.0
ale-py                       0.7.5
astunparse                   1.6.3
atari-py                     0.2.9
attrs                        25.3.0
backports.zoneinfo           0.2.1
bsuite                       0.3.5
cached-property              2.0.1
cachetools                   5.5.2
certifi                      2025.4.26
cffi                         1.17.1
charset-normalizer           3.4.2
chex                         0.1.7
cloudpickle                  3.1.1
contourpy                    1.1.1
cycler                       0.12.1
Cython                       0.29.37
decorator                    5.2.1
dill                         0.4.0
distrax                      0.1.3
dm-acme                      0.4.1
dm_control                   1.0.23
dm-env                       1.6
dm-haiku                     0.0.9
dm-launchpad                 0.5.2
dm-reverb                    0.7.2
dm-sonnet                    2.0.2
dm-tree                      0.1.8
etils                        1.3.0
exceptiongroup               1.3.0
execnet                      2.1.1
fasteners                    0.19
flatbuffers                  25.2.10
flax                         0.7.2
fonttools                    4.57.0
frozendict                   2.4.6
gast                         0.6.0
glfw                         2.9.0
google-auth                  2.40.2
google-auth-oauthlib         0.4.6
google-pasta                 0.2.0
googleapis-common-protos     1.70.0
grpcio                       1.70.0
gym                          0.25.0
gym-notices                  0.0.8
h5py                         3.11.0
idna                         3.10
imageio                      2.35.1
importlab                    0.8.1
importlib_metadata           8.5.0
importlib_resources          6.4.5
iniconfig                    2.1.0
jax                          0.4.3
jaxlib                       0.4.3+cuda11.cudnn86
jmp                          0.0.4
keras                        2.8.0
Keras-Preprocessing          1.1.2
kiwisolver                   1.4.7
labmaze                      1.0.6
lazy_loader                  0.4
libclang                     18.1.1
libcst                       1.1.0
lockfile                     0.12.2
lxml                         5.4.0
Markdown                     3.7
markdown-it-py               3.0.0
MarkupSafe                   2.1.5
matplotlib                   3.7.5
mdurl                        0.1.2
mizani                       0.9.3
ml-dtypes                    0.2.0
mock                         5.2.0
msgpack                      1.1.0
mujoco                       3.2.3
mujoco-py                    2.0.2.5
mypy_extensions              1.1.0
nest-asyncio                 1.6.0
networkx                     3.1
ninja                        1.11.1.4
numpy                        1.22.4
oauthlib                     3.2.2
opt_einsum                   3.4.0
optax                        0.1.4
orbax-checkpoint             0.1.6
packaging                    25.0
pandas                       2.0.3
patsy                        1.0.1
pillow                       10.4.0
pip                          25.0.1
plotnine                     0.10.1
pluggy                       1.5.0
portpicker                   1.6.0
promise                      2.3
protobuf                     3.20.3
psutil                       7.0.0
pyasn1                       0.6.1
pyasn1_modules               0.4.2
pycparser                    2.22
pygame                       2.1.0
Pygments                     2.19.1
PyOpenGL                     3.1.9
pyparsing                    3.1.4
pytest                       8.3.5
pytest-xdist                 3.6.1
python-dateutil              2.9.0.post0
pytype                       2021.8.11
pytz                         2025.2
PyWavelets                   1.4.1
PyYAML                       6.0.2
requests                     2.32.3
requests-oauthlib            2.0.0
rich                         14.0.0
rlax                         0.1.5
rlds                         0.1.8
rsa                          4.9.1
scikit-image                 0.21.0
scipy                        1.10.1
setuptools                   75.3.2
six                          1.17.0
statsmodels                  0.14.1
tabulate                     0.9.0
tensorboard                  2.8.0
tensorboard-data-server      0.6.1
tensorboard-plugin-wit       1.8.1
tensorflow                   2.8.0
tensorflow-datasets          4.6.0
tensorflow-io-gcs-filesystem 0.34.0
tensorflow-metadata          1.14.0
tensorflow-probability       0.15.0
tensorstore                  0.1.45
termcolor                    2.4.0
tf-estimator-nightly         2.8.0.dev2021122109
tifffile                     2023.7.10
toml                         0.10.2
tomli                        2.2.1
toolz                        1.0.0
tqdm                         4.67.1
trfl                         1.2.0
typed-ast                    1.5.5
typing_extensions            4.13.2
typing-inspect               0.9.0
tzdata                       2025.2
urllib3                      2.2.3
Werkzeug                     3.0.6
wheel                        0.45.1
wrapt                        1.17.2
zipp                         3.20.2
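
The traceback passes through jax, dm-haiku, dm-acme, and the JAX substrate of tensorflow-probability, so those are the versions most worth comparing against a working setup. Below is a minimal sketch that prints only those entries using the standard-library `importlib.metadata`. The helper name `installed_versions` and the choice of packages are assumptions made for illustration. One unconfirmed possibility is that tensorflow-probability 0.15.0 and jax 0.4.3 are not a compatible pair; this is a guess, not something established in this issue.

```python
from importlib import metadata

# Distribution names as they appear in the package list above.
PACKAGES = ("jax", "jaxlib", "tensorflow-probability", "dm-haiku", "dm-acme")


def installed_versions(names=PACKAGES):
    """Return {name: version string, or None if the package is not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:<24} {version or 'not installed'}")
```

Running this inside the acme conda environment gives a short list that is easier to paste into a comment than the full package table.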

A similar open issue can be found at #282.
