Skip to content

Commit 3ad1f4a

Browse files
freeze decoder
1 parent 915edcf commit 3ad1f4a

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

neurometry/curvature/grid-cells-curvature/models/xu_rnn/default_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
###-----TRAINING PARAMETERS-----###
2323
load_pretrain=True
2424
pretrain_path=os.path.join(os.getcwd(),"logs/rnn_isometry/20240418-180712/ckpt/model/checkpoint-step25000.pth")
25-
num_steps_train=100#7500 # 10000
25+
num_steps_train=10000#7500 # 10000
2626
lr_decay_from=10000
2727
steps_per_logging=20
2828
steps_per_large_logging=500 # 500
@@ -81,4 +81,4 @@
8181

8282
###-----RAY TUNE PARAMETERS-----###
8383
sweep_metric= "error_reencode"
84-
num_samples = 1#1000
84+
num_samples = 1000#1000

neurometry/curvature/grid-cells-curvature/models/xu_rnn/experiment.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self, rng, config: ml_collections.ConfigDict, device):
3434
self.model_config = model.GridCellConfig(**config.model)
3535
self.model = model.GridCell(self.model_config).to(device)
3636

37-
if config.train.freeze_decoder:
37+
if config.model.freeze_decoder:
3838
logging.info("==== freeze decoder ====")
3939
for param in self.model.decoder.parameters():
4040
param.requires_grad = False
@@ -69,8 +69,8 @@ def __init__(self, rng, config: ml_collections.ConfigDict, device):
6969
logging.info(f"Loading pretrain model from {ckpt_model_path}")
7070
ckpt = torch.load(ckpt_model_path, map_location=device)
7171
self.model.load_state_dict(ckpt["state_dict"])
72-
logging.info("==== load pretrained optimizer ====")
73-
self.optimizer.load_state_dict(ckpt["optimizer"])
72+
# logging.info("==== load pretrained optimizer ====")
73+
# self.optimizer.load_state_dict(ckpt["optimizer"])
7474
self.starting_step = ckpt["step"]
7575
else:
7676
self.starting_step = 1

neurometry/curvature/grid-cells-curvature/models/xu_rnn/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
@dataclass
1111
class GridCellConfig:
12+
freeze_decoder: bool
1213
trans_type: str
1314
num_grid: int
1415
num_neurons: int

0 commit comments

Comments
 (0)