Skip to content

Commit 05c4335

Browse files
committed
sevennet/filter_bad_preds.py fix bad_mask df_wbm column access
1 parent df69736 commit 05c4335

File tree

5 files changed

+39
-38
lines changed

5 files changed

+39
-38
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ default_install_hook_types: [pre-commit, commit-msg]
88

99
repos:
1010
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.5.3
11+
rev: v0.5.5
1212
hooks:
1313
- id: ruff
1414
args: [--fix]
@@ -79,7 +79,7 @@ repos:
7979
- id: check-github-actions
8080

8181
- repo: https://github.com/RobertCraigie/pyright-python
82-
rev: v1.1.372
82+
rev: v1.1.373
8383
hooks:
8484
- id: pyright
8585
args: [--level, error]

models/sevennet/filter_bad_preds.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import pandas as pd
22

3-
from matbench_discovery.data import Key, df_wbm
3+
from matbench_discovery.data import df_wbm
4+
from matbench_discovery.enums import MbdKey
45

5-
E_FORM_COL = "e_form_per_atom_sevennet"
6+
e_form_7net_col = "e_form_per_atom_sevennet"
67

78
csv_path = "./2024-07-11-sevennet-preds.csv.gz"
89
df_preds = pd.read_csv(csv_path).set_index("material_id")
910

1011
# NOTE this filtering is necessary for both MACE and SevenNet because some outliers
1112
# have extremely low e_form (like -1e40)
12-
bad_mask = df_preds[E_FORM_COL] - df_wbm[Key.e_form] < -5
13+
bad_mask = abs(df_preds[e_form_7net_col] - df_wbm[MbdKey.e_form_wbm]) > 5
14+
n_preds = len(df_preds[e_form_7net_col].dropna())
15+
print(f"{sum(bad_mask)=} is {sum(bad_mask) / len(df_wbm):.2%} of {n_preds:,}")
1316
df_preds[~bad_mask].select_dtypes("number").to_csv(
1417
"./2024-07-11-sevennet-preds-no-bad.csv.gz"
1518
)

models/sevennet/test_sevennet.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424

2525

2626
# %% this config is editable
27-
SMOKE_TEST = True
27+
smoke_test = True
2828
sevennet_root = os.path.dirname(sevenn.__path__[0])
2929
module_dir = os.path.dirname(__file__)
3030
sevennet_chkpt = f"{module_dir}/sevennet_checkpoint.pth.tar"
31-
pot_name = "sevennet"
31+
model_name = "sevennet"
3232
task_type = Task.IS2RE
3333
ase_optimizer = "FIRE"
3434
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -55,22 +55,20 @@
5555
slurm_array_task_id = int(os.getenv("SLURM_ARRAY_TASK_ID", "0"))
5656

5757
os.makedirs(out_dir := "./results", exist_ok=True)
58-
out_path = f"{out_dir}/{pot_name}-{slurm_array_task_id:>03}.json.gz"
58+
out_path = f"{out_dir}/{model_name}-{slurm_array_task_id:>03}.json.gz"
5959

6060
data_path = {Task.IS2RE: DataFiles.wbm_initial_structures.path}[task_type]
61-
print(f"\nJob started running {timestamp}, eval {pot_name}", flush=True)
61+
print(f"\nJob started running {timestamp}, eval {model_name}", flush=True)
6262
print(f"{data_path=}", flush=True)
6363

64-
e_pred_col = "sevennet_energy"
65-
66-
# Init ASE SevenNet Calculator from checkpoint
64+
# Initialize ASE SevenNet Calculator from checkpoint
6765
sevennet_calc = SevenNetCalculator(sevennet_chkpt)
6866

6967

7068
# %%
7169
print(f"Read data from {data_path}")
7270
df_in = pd.read_json(data_path).set_index(Key.mat_id)
73-
if SMOKE_TEST:
71+
if smoke_test:
7472
df_in = df_in.head(10)
7573
else:
7674
df_in = df_in.sample(frac=1, random_state=7) # shuffle data for equal runtime
@@ -111,5 +109,5 @@
111109

112110

113111
# %%
114-
if not SMOKE_TEST:
112+
if not smoke_test:
115113
df_out.reset_index().to_json(out_path, default_handler=as_dict_handler)
Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,47 @@
11
model:
2-
chemical_species: "auto"
2+
chemical_species: auto
33
cutoff: 5.0
44
channel: 128
55
is_parity: False
66
lmax: 2
77
num_convolution_layer: 5
88
irreps_manual:
9-
- "128x0e"
10-
- "128x0e+64x1e+32x2e"
11-
- "128x0e+64x1e+32x2e"
12-
- "128x0e+64x1e+32x2e"
13-
- "128x0e+64x1e+32x2e"
14-
- "128x0e"
9+
- 128x0e
10+
- 128x0e+64x1e+32x2e
11+
- 128x0e+64x1e+32x2e
12+
- 128x0e+64x1e+32x2e
13+
- 128x0e+64x1e+32x2e
14+
- 128x0e
1515

1616
weight_nn_hidden_neurons: [64, 64]
1717
radial_basis:
18-
radial_basis_name: "bessel"
18+
radial_basis_name: bessel
1919
bessel_basis_num: 8
2020
cutoff_function:
21-
cutoff_function_name: "XPLOR"
21+
cutoff_function_name: XPLOR
2222
cutoff_on: 4.5
2323

24-
act_gate: { "e": "silu", "o": "tanh" }
25-
act_scalar: { "e": "silu", "o": "tanh" }
24+
act_gate: { "e": silu, "o": tanh }
25+
act_scalar: { "e": silu, "o": tanh }
2626

27-
conv_denominator: "avg_num_neigh"
27+
conv_denominator: avg_num_neigh
2828
train_shift_scale: False
2929
train_denominator: False
30-
self_connection_type: "linear"
30+
self_connection_type: linear
3131
train:
3232
train_shuffle: False
3333
random_seed: 1
3434
is_train_stress: True
3535
epoch: 600
3636

37-
loss: "Huber"
37+
loss: Huber
3838
loss_param:
3939
delta: 0.01
4040

41-
optimizer: "adam"
41+
optimizer: adam
4242
optim_param:
4343
lr: 0.01
44-
scheduler: "linearlr"
44+
scheduler: linearlr
4545
scheduler_param:
4646
start_factor: 1.0
4747
total_iters: 600
@@ -51,10 +51,10 @@ train:
5151
stress_loss_weight: 0.01
5252

5353
error_record:
54-
- ["Energy", "MAE"]
55-
- ["Force", "MAE"]
56-
- ["Stress", "MAE"]
57-
- ["TotalLoss", "None"]
54+
- [Energy, MAE]
55+
- [Force, MAE]
56+
- [Stress, MAE]
57+
- [TotalLoss, None]
5858

5959
per_epoch: 10
6060
# continue:
@@ -64,9 +64,9 @@ train:
6464
data:
6565
data_shuffle: False
6666
batch_size: 4 # batch size per gpu
67-
scale: "per_atom_energy_std"
68-
shift: "elemwise_reference_energies"
67+
scale: per_atom_energy_std
68+
shift: elemwise_reference_energies
6969

7070
save_by_train_valid: False
71-
load_dataset_path: ["train.sevenn_data"]
72-
load_validset_path: ["valid.sevenn_data"]
71+
load_dataset_path: [train.sevenn_data]
72+
load_validset_path: [valid.sevenn_data]

0 commit comments

Comments
 (0)