File tree 1 file changed +17
-8
lines changed
1 file changed +17
-8
lines changed Original file line number Diff line number Diff line change @@ -28,7 +28,7 @@ def test_scalar_invariance():
28
28
torch .testing .assert_allclose (y , y_rot )
29
29
30
30
31
- @pytest .mark .parametrize ("model_name" , ["equivariant-transformer" , "equivariant- tensornet" ])
31
+ @pytest .mark .parametrize ("model_name" , ["equivariant-transformer" , "tensornet" ])
32
32
def test_vector_equivariance (model_name ):
33
33
torch .manual_seed (1234 )
34
34
rotate = torch .tensor (
@@ -38,14 +38,23 @@ def test_vector_equivariance(model_name):
38
38
[- 0.0626055 , 0.3134752 , 0.9475304 ],
39
39
]
40
40
)
41
-
42
- model = create_model (
43
- load_example_args (
44
- model_name ,
45
- prior_model = None ,
46
- output_model = "VectorOutput" ,
41
+ if model_name == "equivariant_transformer"
42
+ model = create_model (
43
+ load_example_args (
44
+ model_name ,
45
+ prior_model = None ,
46
+ output_model = "VectorOutput" ,
47
+ )
48
+ )
49
+ if model_name == "tensornet"
50
+ model = create_model (
51
+ load_example_args (
52
+ model_name ,
53
+ prior_model = None ,
54
+ vector_output = True ,
55
+ output_model = "VectorOutput" ,
56
+ )
47
57
)
48
- )
49
58
z = torch .ones (100 , dtype = torch .long )
50
59
pos = torch .randn (100 , 3 )
51
60
batch = torch .arange (50 , dtype = torch .long ).repeat_interleave (2 )
You can’t perform that action at this time.
0 commit comments