Skip to content

Commit 7264532

Browse files
authored
[REQ] Support PyTorch 2.x (#307)
* [REQ] Remove upper version restrictions for `torch` and `torchvision` * [REQ] Bump python to 3.8+ * [REF] Replace `Tensor.symeig` with `torch.linalg.eigh` * [CI] Replace `python3.7` with `python3.8` * [REF] Try fixing syntax for `flake8` in `setup.cfg` * [TEST] Skip double-backward of LSTM for PyTorch2.0.1 See pytorch/pytorch#99413 * [FIX] flake8 * [TEST] Skip `jac_mat_prod` for LSTM in PyTorch2.0.1 double-backward not supported pytorch/pytorch#99413 * [CI] Use python3.8 in RTD build * [CI] Skip LSTM for PyTorch2.0.1 in DiagGGN tests * [FIX] Imports * [FIX] Turn off MKLDNN in RNN example --------- Co-authored-by: Felix Dangel <[email protected]>
1 parent 7b0b712 commit 7264532

File tree

15 files changed

+99
-47
lines changed

15 files changed

+99
-47
lines changed

.conda_env.yml

+3-6
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@ channels:
33
- pytorch
44
- defaults
55
dependencies:
6-
- pip=19.3.1
7-
- python=3.7.6
6+
- pip=21.2.4
7+
- python=3.8.5
88
- pip:
9-
- -e .
10-
- -e .[lint]
11-
- -e .[test]
12-
- -e .[docs]
9+
- -e .[lint,test,doc]

.github/workflows/lint.yaml

+14-14
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ jobs:
1515
runs-on: ubuntu-latest
1616
steps:
1717
- uses: actions/checkout@v1
18-
- name: Set up Python 3.7
18+
- name: Set up Python 3.8
1919
uses: actions/setup-python@v1
2020
with:
21-
python-version: 3.7
21+
python-version: 3.8
2222
- name: Install dependencies
2323
run: |
2424
python -m pip install --upgrade pip
@@ -30,10 +30,10 @@ jobs:
3030
runs-on: ubuntu-latest
3131
steps:
3232
- uses: actions/checkout@v1
33-
- name: Set up Python 3.7
33+
- name: Set up Python 3.8
3434
uses: actions/setup-python@v1
3535
with:
36-
python-version: 3.7
36+
python-version: 3.8
3737
- name: Install dependencies
3838
run: |
3939
python -m pip install --upgrade pip
@@ -45,10 +45,10 @@ jobs:
4545
runs-on: ubuntu-latest
4646
steps:
4747
- uses: actions/checkout@v1
48-
- name: Set up Python 3.7
48+
- name: Set up Python 3.8
4949
uses: actions/setup-python@v1
5050
with:
51-
python-version: 3.7
51+
python-version: 3.8
5252
- name: Install dependencies
5353
run: |
5454
python -m pip install --upgrade pip
@@ -61,10 +61,10 @@ jobs:
6161
runs-on: ubuntu-latest
6262
steps:
6363
- uses: actions/checkout@v1
64-
- name: Set up Python 3.7
64+
- name: Set up Python 3.8
6565
uses: actions/setup-python@v1
6666
with:
67-
python-version: 3.7
67+
python-version: 3.8
6868
- name: Install dependencies
6969
run: |
7070
python -m pip install --upgrade pip
@@ -77,10 +77,10 @@ jobs:
7777
runs-on: ubuntu-latest
7878
steps:
7979
- uses: actions/checkout@v1
80-
- name: Set up Python 3.7
80+
- name: Set up Python 3.8
8181
uses: actions/setup-python@v1
8282
with:
83-
python-version: 3.7
83+
python-version: 3.8
8484
- name: Install dependencies
8585
run: |
8686
python -m pip install --upgrade pip
@@ -92,10 +92,10 @@ jobs:
9292
runs-on: ubuntu-latest
9393
steps:
9494
- uses: actions/checkout@v1
95-
- name: Set up Python 3.7
95+
- name: Set up Python 3.8
9696
uses: actions/setup-python@v1
9797
with:
98-
python-version: 3.7
98+
python-version: 3.8
9999
- name: Install dependencies
100100
run: |
101101
python -m pip install --upgrade pip
@@ -107,10 +107,10 @@ jobs:
107107
runs-on: ubuntu-latest
108108
steps:
109109
- uses: actions/checkout@v1
110-
- name: Set up Python 3.7
110+
- name: Set up Python 3.8
111111
uses: actions/setup-python@v1
112112
with:
113-
python-version: 3.7
113+
python-version: 3.8
114114
- name: Install dependencies
115115
run: |
116116
python -m pip install --upgrade pip

.github/workflows/test.yaml

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,18 @@ jobs:
1515
name: "py${{ matrix.python-version }} torch${{ matrix.pytorch-version}}"
1616
runs-on: ubuntu-latest
1717
env:
18-
USING_COVERAGE: '3.7,3.9'
18+
USING_COVERAGE: '3.9'
1919

2020
strategy:
2121
matrix:
22-
python-version: [3.7, 3.8, 3.9]
22+
python-version: [3.8, 3.9]
2323
pytorch-version:
2424
- "==1.9.1"
2525
- "==1.10.1"
2626
- "==1.11.0"
2727
- "==1.12.1"
2828
- "==1.13.1"
29+
- "==2.0.1"
2930
- "" # latest
3031
steps:
3132
- uses: actions/checkout@v1

.readthedocs.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ sphinx:
77
configuration: docs_src/rtd/conf.py
88

99
python:
10-
version: 3.7
10+
version: 3.8
1111
install:
1212
- method: pip
1313
path: .

README-dev.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# <img alt="BackPACK" src="./logo/backpack_logo_torch.svg" height="90"> BackPACK developer manual
22

33
## General standards
4-
- Python version: support 3.7+, use 3.7 for development
4+
- Python version: support 3.8+, use 3.8 for development
55
- `git` [branching model](https://nvie.com/posts/a-successful-git-branching-model/)
66
- Docstring style: [Google](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html)
77
- Test runner: [`pytest`](https://docs.pytest.org/en/latest/)

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
[![Travis](https://travis-ci.org/f-dangel/backpack.svg?branch=master)](https://travis-ci.org/f-dangel/backpack)
44
[![Coveralls](https://coveralls.io/repos/github/f-dangel/backpack/badge.svg?branch=master)](https://coveralls.io/github/f-dangel/backpack)
5-
[![Python 3.7+](https://img.shields.io/badge/python-3.7+-blue.svg)](https://www.python.org/downloads/release/python-370/)
5+
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-370/)
66

77
BackPACK is built on top of [PyTorch](https://github.com/pytorch/pytorch). It efficiently computes quantities other than the gradient.
88

backpack/utils/kroneckers.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from torch import einsum
2+
from torch.linalg import eigh
23

34
from backpack.utils.unsqueeze import kfacmp_unsqueeze_if_missing_dim
45

@@ -101,7 +102,7 @@ def sym_mat_inv(mat, shift, truncate=1e-8):
101102
Computed by eigenvalue decomposition. Eigenvalues with small
102103
absolute values are truncated.
103104
"""
104-
eigvals, eigvecs = mat.symeig(eigenvectors=True)
105+
eigvals, eigvecs = eigh(mat)
105106
eigvals.add_(shift)
106107
inv_eigvals = 1.0 / eigvals
107108
inv_truncate = 1.0 / truncate

docs_src/examples/use_cases/example_rnn.py

+12
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@
2222
# Not all extensions support RNNs (yet). Please create a feature request in the
2323
# repository if the extension you need is not supported.
2424

25+
from pkg_resources import packaging
26+
2527
# %%
2628
# Let's get the imports out of the way.
2729
from torch import (
30+
_C,
2831
allclose,
2932
cat,
3033
device,
@@ -41,11 +44,20 @@
4144
from backpack.custom_module.permute import Permute
4245
from backpack.custom_module.reduce_tuple import ReduceTuple
4346
from backpack.extensions import BatchGrad, DiagGGNExact
47+
from backpack.utils import TORCH_VERSION
4448
from backpack.utils.examples import autograd_diag_ggn_exact
4549

4650
manual_seed(0)
4751
DEVICE = device("cpu") # Verification via autograd only works on CPU
4852

53+
# %%
54+
#
55+
# .. note::
56+
# Due to `#99413 <https://github.com/pytorch/pytorch/issues/99413>`_, we have to disable
57+
# MKLDNN for PyTorch 2.0.1 to get the double-backward through LSTMs working.
58+
if TORCH_VERSION == packaging.version.parse("2.0.1"):
59+
_C._set_mkldnn_enabled(False)
60+
4961

5062
# %%
5163
# For this demo, we will use the Tolstoi Char RNN from

setup.cfg

+25-16
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ classifiers =
2222
Development Status :: 4 - Beta
2323
License :: OSI Approved :: MIT License
2424
Operating System :: OS Independent
25-
Programming Language :: Python :: 3.7
2625
Programming Language :: Python :: 3.8
2726
Programming Language :: Python :: 3.9
27+
Programming Language :: Python :: 3.10
2828

2929
[options]
3030
zip_safe = False
@@ -34,12 +34,12 @@ setup_requires =
3434
setuptools_scm
3535
# Dependencies of the project (semicolon/line-separated):
3636
install_requires =
37-
torch >= 1.9.0, < 1.13.0
38-
torchvision >= 0.7.0, < 1.0.0
37+
torch >= 1.9.0
38+
torchvision >= 0.7.0
3939
einops >= 0.3.0, < 1.0.0
4040
unfoldNd >= 0.2.0, < 1.0.0
4141
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
42-
python_requires = >=3.7
42+
python_requires = >=3.8
4343

4444
[options.packages.find]
4545
exclude = test*
@@ -96,19 +96,28 @@ use_parentheses=True
9696
select = B,C,E,F,P,W,B9
9797
max-line-length = 88
9898
max-complexity = 10
99+
100+
# E501, # max-line-length
101+
# # ignored because pytorch uses dict
102+
# C408, # use {} instead of dict()
103+
# # Not Black-compatible
104+
# E203, # whitespace before :
105+
# E231, # missing whitespace after ','
106+
# W291, # trailing whitespace
107+
# W503, # line break before binary operator
108+
# W504, # line break after binary operator
109+
# B905, # 'zip()' without an explicit 'strict=' parameter
110+
# B028, # No explicit stacklevel keyword argument found (warn)
99111
ignore =
100-
# replaced by B950 (max-line-length + 10%)
101-
E501, # max-line-length
102-
# ignored because pytorch uses dict
103-
C408, # use {} instead of dict()
104-
# Not Black-compatible
105-
E203, # whitespace before :
106-
E231, # missing whitespace after ','
107-
W291, # trailing whitespace
108-
W503, # line break before binary operator
109-
W504, # line break after binary operator
110-
B905, # 'zip()' without an explicit 'strict=' parameter
111-
B028, # No explicit stacklevel keyword argument found (warn)
112+
E501,
113+
C408,
114+
E203,
115+
E231,
116+
W291,
117+
W503,
118+
W504,
119+
B905,
120+
B028,
112121
exclude = docs, build, .git, docs_src/rtd, docs_src/rtd_output, .eggs
113122

114123
# Differences with pytorch

test/converter/test_converter.py

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66
from test.converter.converter_cases import CONVERTER_MODULES, ConverterModule
77
from test.core.derivatives.utils import classification_targets, regression_targets
8+
from test.utils.skip_test import skip_torch_2_0_1_lstm
89
from typing import Tuple
910

1011
from pytest import fixture
@@ -31,6 +32,7 @@ def model_and_input(request) -> Tuple[Module, Tensor, Module]:
3132
"""
3233
manual_seed(0)
3334
model: ConverterModule = request.param()
35+
skip_torch_2_0_1_lstm(model)
3436
inputs: Tensor = model.input_fn()
3537
loss_fn: Module = model.loss_fn()
3638
yield model, inputs, loss_fn

test/core/derivatives/derivatives_test.py

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- Jacobian-matrix products with respect to layer parameters
77
- Transposed Jacobian-matrix products with respect to layer parameters
88
"""
9+
910
from contextlib import nullcontext
1011
from test.automated_test import check_sizes, check_sizes_and_values
1112
from test.core.derivatives.batch_norm_settings import BATCH_NORM_SETTINGS
@@ -27,6 +28,7 @@
2728
skip_BCEWithLogitsLoss,
2829
skip_BCEWithLogitsLoss_non_binary_labels,
2930
skip_subsampling_conflict,
31+
skip_torch_2_0_1_lstm,
3032
)
3133
from typing import List, Union
3234
from warnings import warn
@@ -136,6 +138,7 @@ def test_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None:
136138
V: Number of vectorized Jacobian-vector products. Default: ``3``.
137139
"""
138140
problem.set_up()
141+
skip_torch_2_0_1_lstm(problem.module)
139142
mat = rand(V, *problem.input_shape).to(problem.device)
140143

141144
backpack_res = BackpackDerivatives(problem).jac_mat_prod(mat)

test/extensions/problem.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def collect_data(self, savefield: str) -> List[Any]:
227227
else:
228228
if hasattr(p, savefield):
229229
raise RuntimeError(
230-
f"Found non-differentiable parameter with attribute '{savefield}'."
230+
f"Found non-differentiable parameter with attribute {savefield}."
231231
)
232232

233233
return data

test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from test.extensions.problem import make_test_problems
66
from test.extensions.secondorder.diag_ggn.diag_ggn_settings import DiagGGN_SETTINGS
77
from test.utils.skip_extension_test import skip_BCEWithLogitsLoss_non_binary_labels
8-
from test.utils.skip_test import skip_adaptive_avg_pool3d_cuda
8+
from test.utils.skip_test import skip_adaptive_avg_pool3d_cuda, skip_torch_2_0_1_lstm
99

1010
import pytest
1111

@@ -23,6 +23,7 @@ def test_diag_ggn_exact_batch(problem, request):
2323
"""
2424
skip_adaptive_avg_pool3d_cuda(request)
2525
problem.set_up()
26+
skip_torch_2_0_1_lstm(problem.model)
2627

2728
backpack_res = BackpackExtensions(problem).diag_ggn_exact_batch()
2829
autograd_res = AutogradExtensions(problem).diag_ggn_exact_batch()
@@ -47,6 +48,7 @@ def test_diag_ggn_mc_batch_light(problem):
4748
"""
4849
problem.set_up()
4950
skip_BCEWithLogitsLoss_non_binary_labels(problem)
51+
skip_torch_2_0_1_lstm(problem.model)
5052

5153
backpack_res = BackpackExtensions(problem).diag_ggn_exact_batch()
5254
mc_samples = 6000
@@ -70,6 +72,7 @@ def test_diag_ggn_mc_batch(problem):
7072
"""
7173
problem.set_up()
7274
skip_BCEWithLogitsLoss_non_binary_labels(problem)
75+
skip_torch_2_0_1_lstm(problem.model)
7376

7477
backpack_res = BackpackExtensions(problem).diag_ggn_exact_batch()
7578
mc_samples = 300000

test/extensions/secondorder/diag_ggn/test_diag_ggn.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from test.extensions.problem import make_test_problems
66
from test.extensions.secondorder.diag_ggn.diag_ggn_settings import DiagGGN_SETTINGS
77
from test.utils.skip_extension_test import skip_BCEWithLogitsLoss_non_binary_labels
8-
from test.utils.skip_test import skip_adaptive_avg_pool3d_cuda
8+
from test.utils.skip_test import skip_adaptive_avg_pool3d_cuda, skip_torch_2_0_1_lstm
99

1010
import pytest
1111

@@ -23,6 +23,7 @@ def test_diag_ggn(problem, request):
2323
"""
2424
skip_adaptive_avg_pool3d_cuda(request)
2525
problem.set_up()
26+
skip_torch_2_0_1_lstm(problem.model)
2627

2728
backpack_res = BackpackExtensions(problem).diag_ggn()
2829
autograd_res = AutogradExtensions(problem).diag_ggn()
@@ -47,6 +48,7 @@ def test_diag_ggn_mc_light(problem):
4748
"""
4849
problem.set_up()
4950
skip_BCEWithLogitsLoss_non_binary_labels(problem)
51+
skip_torch_2_0_1_lstm(problem.model)
5052

5153
backpack_res = BackpackExtensions(problem).diag_ggn()
5254
mc_samples = 3000
@@ -70,6 +72,7 @@ def test_diag_ggn_mc(problem):
7072
"""
7173
problem.set_up()
7274
skip_BCEWithLogitsLoss_non_binary_labels(problem)
75+
skip_torch_2_0_1_lstm(problem.model)
7376

7477
backpack_res = BackpackExtensions(problem).diag_ggn()
7578
mc_samples = 300000

0 commit comments

Comments
 (0)