Skip to content

Update deprecated function and stricter requirements #98

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

chertianser
Copy link

  1. python>=3.9 for jax>=0.4.14
  2. deprecation of jax.linear_util when jax>=0.4.24, change to jax.extend.linear_util
  3. numpy<2.0.0 due to ml-dtypes library's incompatibility with numpy>=2.0.0. Installed version is ml-dtypes=0.2.0, released one year before numpy=2.0.0. I was not able to get later versions of ml-dtypes to install with pip install -e . in the GradDFT repository, but it is possible that the latest version ml-dtypes=0.5.0 can solve this incompatibility.

Full trace for issue 3 is reproduced here upon import of grad_dft:

A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
    self._run_once()
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
    handle._run()
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
    await self.process_one()
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 534, in process_one
    await dispatch(*args)
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
    await result
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
    await super().execute_request(stream, ident, parent)
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
    reply_content = await reply_content
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
    res = shell.run_cell(
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3048, in run_cell
    result = self._run_cell(
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell
    result = runner(coro)
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_3117551/3537282310.py", line 1, in <module>
    from grad_dft import (
  File "/h/292/ctser/GradDFT/grad_dft/__init__.py", line 15, in <module>
    from .molecule import (
  File "/h/292/ctser/GradDFT/grad_dft/molecule.py", line 19, in <module>
    from grad_dft.utils import vmap_chunked
  File "/h/292/ctser/GradDFT/grad_dft/utils/__init__.py", line 15, in <module>
    from .types import (
  File "/h/292/ctser/GradDFT/grad_dft/utils/types.py", line 16, in <module>
    import jax
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/jax/__init__.py", line 25, in <module>
    from jax._src.cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/jax/_src/cloud_tpu_init.py", line 17, in <module>
    from jax._src import config
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/jax/_src/config.py", line 27, in <module>
    from jax._src import lib
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/jax/_src/lib/__init__.py", line 87, in <module>
    import jaxlib.xla_client as xla_client
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/jaxlib/xla_client.py", line 30, in <module>
    import ml_dtypes
  File "/h/292/ctser/.conda/envs/graddftnp2/lib/python3.9/site-packages/ml_dtypes/__init__.py", line 32, in <module>
    from ml_dtypes._custom_floats import bfloat16
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
AttributeError: _ARRAY_API not found
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
ImportError: numpy.core._multiarray_umath failed to import
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[1], line 1
----> 1 from grad_dft import (
      2 	energy_predictor,
      3 	simple_energy_loss,
      4 	NeuralFunctional,
      5 	molecule_from_pyscf
      6 )
      7 from pyscf import gto, dft

File ~/GradDFT/grad_dft/__init__.py:15
      1 # Copyright 2023 Xanadu Quantum Technologies Inc.
      2 
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from .molecule import (
     16     Grid,
     17     Molecule, 
     18     Reaction, 
     19     make_reaction,
     20     abs_clip, 
     21     make_rdm1, 
     22     orbital_grad,
     23     density,
     24     grad_density,
     25     coulomb_energy
     26 )
     27 from .solid import (
     28     Solid
     29 )
     30 from .functional import (
     31     DispersionFunctional,
     32     Functional, 
   (...)
     44     dm21_hfgrads_densities,
     45 )

File ~/GradDFT/grad_dft/molecule.py:19
     16 from dataclasses import fields
     18 from typeguard import typechecked
---> 19 from grad_dft.utils import vmap_chunked
     20 from functools import partial
     22 from jax import numpy as jnp

File ~/GradDFT/grad_dft/utils/__init__.py:15
      1 # Copyright 2023 Xanadu Quantum Technologies Inc.
      2 
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from .types import (
     16     Ansatz,
     17     Key,
     18     PyTree,
     19     Array,
     20     Scalar,
     21     Optimizer,
     22     Device,
     23     DType,
     24     HartreeFock,
     25     DensityFunctional,
     26     default_dtype,
     27 )
     28 from .tree import tree_size, tree_isfinite, tree_randn_like, tree_func, tree_shape
     29 from .utils import to_device_arrays, Utils

File ~/GradDFT/grad_dft/utils/types.py:16
      1 # Copyright 2023 Xanadu Quantum Technologies Inc.
      2 
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     15 from typing import Union
---> 16 import jax
     17 from jax import numpy as jnp
     19 from flax import linen as nn

File ~/.conda/envs/graddftnp2/lib/python3.9/site-packages/jax/__init__.py:25
     22 from jax.version import __version_info__ as __version_info__
     24 # Set Cloud TPU env vars if necessary before transitively loading C++ backend
---> 25 from jax._src.cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
     26 try:
     27   _cloud_tpu_init()

File ~/.conda/envs/graddftnp2/lib/python3.9/site-packages/jax/_src/cloud_tpu_init.py:17
     15 import os
     16 from jax import version
---> 17 from jax._src import config
     18 from jax._src import hardware_utils
     19 from typing import Optional

File ~/.conda/envs/graddftnp2/lib/python3.9/site-packages/jax/_src/config.py:27
     24 import threading
     25 from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar, cast
---> 27 from jax._src import lib
     28 from jax._src.lib import jax_jit
     29 from jax._src.lib import transfer_guard_lib

File ~/.conda/envs/graddftnp2/lib/python3.9/site-packages/jax/_src/lib/__init__.py:87
     84 cpu_feature_guard.check_cpu_features()
     86 import jaxlib.utils as utils
---> 87 import jaxlib.xla_client as xla_client
     88 import jaxlib.lapack as lapack
     90 xla_extension = xla_client._xla

File ~/.conda/envs/graddftnp2/lib/python3.9/site-packages/jaxlib/xla_client.py:30
     27 import threading
     28 from typing import Any, Protocol, Union
---> 30 import ml_dtypes
     31 import numpy as np
     33 from . import xla_extension as _xla

File ~/.conda/envs/graddftnp2/lib/python3.9/site-packages/ml_dtypes/__init__.py:32
     16 __all__ = [
     17     '__version__',
     18     'bfloat16',
   (...)
     27     'uint4',
     28 ]
     30 from typing import Type
---> 32 from ml_dtypes._custom_floats import bfloat16
     33 from ml_dtypes._custom_floats import float8_e4m3b11fnuz
     34 from ml_dtypes._custom_floats import float8_e4m3fn

ImportError: numpy.core.umath failed to import

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant