
FSDP checkpoint loading issue #20950

Open
@daniimurphy

Description

Bug description

Hello! I am running into a strange issue that I can’t seem to debug; it is similar to the largely unresolved issue #20373.

I have a PyTorch Lightning module that defines my pipeline. Training itself runs fine and behaves as expected. After the call to trainer.fit() returns, I call .test() and pass in the test dataloader.

The error I’m getting indicates that the model’s parameters are somehow modified at some point between training and testing. For instance, many of the parameters are multi-dimensional tensors, but they suddenly show up as one-dimensional, which throws an error when the model is loaded to run the test step.

Here’s an example of the error:

[rank0]: Traceback (most recent call last): (RANK 3)
[rank0]:   File "/.../lib/python3.9/site-packages/torch/distributed/checkpoint/utils.py", line 248, in all_gather
[rank0]:     result = map_fun()
[rank0]:   File "/.../lib/python3.9/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank0]:     result = func(*args, **kwargs)
[rank0]:   File "/.../lib/python3.9/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 228, in read_data
[rank0]:     all_reads = storage_reader.read_data(final_local_plan, planner)
[rank0]:   File "/.../lib/python3.9/site-packages/torch/distributed/checkpoint/filesystem.py", line 666, in read_data
[rank0]:     assert (
[rank0]: AssertionError: req MetadataIndex(fqn='model.preprocessor.conv.3.weight', offset=torch.Size([0, 0, 0, 0]), index=0) mismatch sizes torch.Size([32]) vs torch.Size([32, 64, 1, 1])
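
To check what the sharded checkpoint actually recorded for that tensor, the distributed-checkpoint metadata can be read directly. A small diagnostic sketch (the checkpoint path is a placeholder for my actual best-checkpoint directory):

from torch.distributed.checkpoint import FileSystemReader

# Print the shape that the sharded (DCP) checkpoint stored for the
# parameter named in the assertion above.
meta = FileSystemReader("path/to/best-miou-checkpoint.ckpt").read_metadata()
print(meta.state_dict_metadata["model.preprocessor.conv.3.weight"].size)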

In my test_step, I printed the output of model.named_parameters() and saw:

preprocessor.layer1.bias torch.Size([0]) 
preprocessor.layer2.weight torch.Size([0]) # should be [14, 32]
preprocessor.layer2.bias torch.Size([0])
encoder.model.patch_embed.proj.weight torch.Size([0]) # should be [512, 14, 4, 4]
encoder.model.patch_embed.proj.bias torch.Size([0])
encoder.model.patch_embed.norm.weight torch.Size([0])
encoder.model.patch_embed.norm.bias torch.Size([0])
encoder.model.layers.0.blocks.0.norm1.weight torch.Size([0])
encoder.model.layers.0.blocks.0.norm1.bias torch.Size([0])
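
For reference, I believe size-0 shapes like these are what FSDP-sharded parameters look like when named_parameters() is called outside an unsharding context (with use_orig_params, ranks that don’t own a shard see empty views). A diagnostic sketch I’m considering, assuming self.trainer.model is the FSDP-wrapped root module under Lightning’s FSDPStrategy:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Inside test_step: gather the full parameters before printing, so the
# shapes reflect the unsharded tensors rather than the local shards.
with FSDP.summon_full_params(self.trainer.model):
    for name, param in self.trainer.model.named_parameters():
        print(name, param.shape)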

I’m not sure what’s causing this behavior, but I’ll paste the code for my pipeline and the script I use to instantiate the Trainer and run training. Any insight into why this is happening would be greatly appreciated! 🙂

My LightningModule class (some parts omitted where I don’t think they’re relevant):

class PhasePred(pl.LightningModule):
    NUM_CLASSES: int = 5
    OUTPUT_SHAPE: Tuple[int, int] = (91, 40)

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters(ignore=['model'])
        self.config = config
        self.configure_models()
        self.configure_losses()
        self.configure_metrics()
        self.transform = AbiToaTransform(self.config)

    def configure_models(self):
        factory = ModelFactory()

        hidden_dim = 64

        self.input_mlp = InputMLP(in_channels=self.config.MODEL.IN_CHANS,
                                  out_channels=14, hidden_dim=hidden_dim)

        self.encoder = factory.get_component(
            component_type="encoder",
            name=self.config.MODEL.ENCODER,
            config=self.config)

        self.decoder = factory.get_component(
            component_type="decoder",
            name=self.config.MODEL.DECODER,
            num_features=self.encoder.num_features)

        self.segmentation_head = factory.get_component(
            component_type="head",
            name="segmentation_head",
            decoder_channels=14,
            num_classes=self.NUM_CLASSES,
            output_shape=self.OUTPUT_SHAPE)


        self.model = nn.Sequential(self.preprocessor,
                                   self.segmentation_head)

    def training_step(self, batch, batch_idx):

        inputs, targets = batch

        logits = self.forward(inputs)

        # resize logits output
        logits = F.interpolate(logits, size=targets.shape[-2:], mode="bilinear", align_corners=False)

        loss = self.criterion(logits, targets.long())

        preds = torch.argmax(logits, dim=1)
        iou = self.train_iou(preds, targets)

        iou_per_class = self.train_iou_per_class(preds, targets)
        for i, class_iou in enumerate(iou_per_class):
            self.log(f'train_iou_{i}', class_iou, on_step=False, on_epoch=True)

        self.train_loss_avg.update(loss)
        self.train_iou_avg.update(iou)
        self.log('train_loss', self.train_loss_avg.compute(), on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_iou', self.train_iou_avg.compute(), on_step=False, on_epoch=True, prog_bar=True)
        # on_step=False just to reduce logging noise while I verify things work

        return loss

    def validation_step(self, batch, batch_idx):

        inputs, targets = batch

        print(f"[validation_step] input shape: {inputs.shape}")

        logits = self.forward(inputs)

        logits = F.interpolate(logits, size=targets.shape[-2:], mode="bilinear", align_corners=False)

        val_loss = self.criterion(logits, targets.long())
        preds = torch.argmax(logits, dim=1)

        val_iou = self.val_iou(preds, targets.int())

        iou_per_class = self.val_iou_per_class(preds, targets)
        for i, class_iou in enumerate(iou_per_class):
            self.log(f'val_iou_{i}', class_iou, on_step=False, on_epoch=True)

        self.val_loss_avg.update(val_loss)
        self.val_iou_avg.update(val_iou)

        self.log('val_loss', self.val_loss_avg.compute(),
                    on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log('val_iou', self.val_iou_avg.compute(),
                    on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        return val_loss

    def test_step(self, batch, batch_idx):

        #if batch_idx == 0:
        #    print("=== test_step model structure ===")
        #    print(self)
        #    print("=== Named parameters ===")
        #    for name, param in self.named_parameters():
        #        print(name, param.shape, param.requires_grad)

        inputs, targets = batch
        #print(f"[test_step] Input batch shape: {inputs.shape}")
        logits = self.forward(inputs)

        logits = F.interpolate(logits, size=targets.shape[-2:], mode="bilinear", align_corners=False)

        test_loss = self.criterion(logits, targets.long())
        preds = torch.argmax(logits, dim=1)

        test_iou = self.test_iou(preds, targets.int())

        iou_per_class = self.test_iou_per_class(preds, targets)
        for i, class_iou in enumerate(iou_per_class):
            self.log(f'test_iou_{i}', class_iou, on_step=False, on_epoch=True)

        return test_loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)

        #optimizer = build_optimizer(self.config, self.model, is_pretrain=True)
        print(f'Using optimizer: {optimizer}')
        return optimizer

    def on_train_epoch_start(self):
        self.train_loss_avg.reset()
        self.train_iou_avg.reset()

    def on_validation_epoch_start(self):
        self.val_loss_avg.reset()

My code that runs training:

def main(config, output_dir):

    ptlPipeline = PhasePred(config)

    # === Build LightningModule ===
    #if config.MODEL.RESUME:
    #    ptlPipeline  = PhasePred.load_from_checkpoint(config.MODEL.RESUME)

    logger = CSVLogger(save_dir="PhasePredLogs", name=f"phasepred_{config.TAG}")

    checkpoint_callback = ModelCheckpoint(
        monitor="val_iou",
        mode="max",
        save_top_k=1,
        filename="best-miou-checkpoint",
        verbose=True,
        dirpath="/.../PhasePredLogs/checkpoints"
    )

    strategy = FSDPStrategy(state_dict_type="sharded")

    trainer = Trainer(
        accelerator=config.TRAIN.ACCELERATOR,
        strategy=strategy,
        precision=config.PRECISION,
        logger=logger,
        max_epochs=config.TRAIN.EPOCHS,
        devices=2,
        num_nodes=2,
        log_every_n_steps=config.PRINT_FREQ,
        fast_dev_run=False,
        default_root_dir=output_dir,
        enable_checkpointing=True,
        callbacks=[checkpoint_callback]
        )

    if config.TRAIN.LIMIT_TRAIN_BATCHES:
        trainer.limit_train_batches = get_distributed_train_batches(
            config, trainer)

    npz_paths = sorted(glob.glob("/path/to/directory/*.npz"))

    full_ds = CloudPhaseDataset(npz_paths)
    train_ds, val_ds, test_ds = torch.utils.data.random_split(full_ds, [0.70, 0.20, 0.10])

    # Create data loaders
    train_loader = DataLoader(train_ds, batch_size=config.DATA.BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config.DATA.BATCH_SIZE)
    test_loader = DataLoader(test_ds, batch_size=config.DATA.BATCH_SIZE)

    # Train
    trainer.fit(model=ptlPipeline, train_dataloaders=train_loader, val_dataloaders=val_loader)

    # === Reload the best model from checkpoint ===
    best_ckpt_path = checkpoint_callback.best_model_path
    print(f"Best checkpoint path: {best_ckpt_path}")


    # === Evaluate on test set ===
    test_results = trainer.test(dataloaders=test_loader)
    print("TEST RESULTS:", test_results)

I’ve tried calling .test() both with and without passing a specific checkpoint path; the same issue happens either way.
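
Concretely, the two variants look like this (sketched; both end in the same assertion):

# With an explicit checkpoint path:
trainer.test(model=ptlPipeline, dataloaders=test_loader,
             ckpt_path=checkpoint_callback.best_model_path)

# Without one, in which case Lightning falls back to the best checkpoint
# from the fit call:
trainer.test(dataloaders=test_loader)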

The issue also seems unrelated to the particular encoder/decoder used: if I define the model with just the simple multi-layer perceptron and a segmentation head, I still get the error.
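
One workaround I’m considering, since state_dict_type="sharded" means the best checkpoint is a distributed-checkpoint (DCP) directory rather than a single file: consolidate it into a regular torch.save file before testing. A sketch using torch.distributed.checkpoint’s format utilities (paths are placeholders, and I haven’t verified this avoids the mismatch):

from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Convert the sharded DCP checkpoint directory into a single
# torch.save-style file.
dcp_to_torch_save("path/to/best-miou-checkpoint.ckpt", "path/to/consolidated.ckpt")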

Thank you in advance for any help or debugging ideas! If any other code or details would be useful, I’m happy to share them.

What version are you seeing the problem on?

v2.4

Reproduced in studio

No response

How to reproduce the bug

Error messages and logs

See the traceback in the bug description above.

Environment

* CUDA:
        - GPU:
                - NVIDIA L40S
        - available:         True
        - version:           11.8
* Lightning:
        - efficientnet-pytorch: 0.7.1
        - lightning:         2.4.0
        - lightning-utilities: 0.12.0
        - pytorch-lightning: 2.5.0.post0
        - segmentation-models-pytorch: 0.4.0
        - torch:             2.5.1+cu118
        - torchaudio:        2.5.1+cu118
        - torchgeo:          0.5.2
        - torchmetrics:      1.6.1
        - torchvision:       0.20.1+cu118
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.9.18

More info

No response
