Skip to content

Commit 66521a4

Browse files
Update test_equivariance.py
1 parent 6925fc0 commit 66521a4

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

tests/test_equivariance.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_scalar_invariance():
2828
torch.testing.assert_allclose(y, y_rot)
2929

3030

31-
@pytest.mark.parametrize("model_name", ["equivariant-transformer", "equivariant-tensornet"])
31+
@pytest.mark.parametrize("model_name", ["equivariant-transformer", "tensornet"])
3232
def test_vector_equivariance(model_name):
3333
torch.manual_seed(1234)
3434
rotate = torch.tensor(
@@ -38,14 +38,23 @@ def test_vector_equivariance(model_name):
3838
[-0.0626055, 0.3134752, 0.9475304],
3939
]
4040
)
41-
42-
model = create_model(
43-
load_example_args(
44-
model_name,
45-
prior_model=None,
46-
output_model="VectorOutput",
41+
if model_name == "equivariant_transformer"
42+
model = create_model(
43+
load_example_args(
44+
model_name,
45+
prior_model=None,
46+
output_model="VectorOutput",
47+
)
48+
)
49+
if model_name == "tensornet"
50+
model = create_model(
51+
load_example_args(
52+
model_name,
53+
prior_model=None,
54+
vector_output=True,
55+
output_model="VectorOutput",
56+
)
4757
)
48-
)
4958
z = torch.ones(100, dtype=torch.long)
5059
pos = torch.randn(100, 3)
5160
batch = torch.arange(50, dtype=torch.long).repeat_interleave(2)

0 commit comments

Comments
 (0)