
BartLearnedPositionalEmbedding's forward method signature obstructs private (Opacus) training of BART #18425


System Info

- transformers version: 4.20.1
- Platform: Linux-5.4.0-1086-azure-x86_64-with-glibc2.17
- Python version: 3.8.13
- Huggingface_hub version: 0.8.1
- PyTorch version (GPU?): 1.9.1+cu102 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes (NA)
- Using distributed or parallel set-up in script?: no (NA)

Who can help?

Tagging @patil-suraj as the BART model owner.

Details:
The signature of BartLearnedPositionalEmbedding's forward method takes an input of type torch.Size, which breaks training under Opacus. The reason is that Opacus makes the (reasonable) assumption that all layers take inputs of type torch.Tensor.

In particular, opacus/grad_sample/grad_sample_module.py line 190 (the capture_activations_hook method) tries to detach the forward input via:

module.activations.append(forward_input[0].detach())

Since torch.Size is a tuple subclass rather than a tensor, it has no detach() method and the hook fails. If we pass the tensor itself instead, this allows fine-tuning BART-style summarization models with differential privacy.

Only a few lines of code in modeling_bart.py need to change: the signature of BartLearnedPositionalEmbedding.forward() and the call sites that reference it.
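For reference, here is a minimal sketch of the proposed change, based on the forward implementation in transformers 4.20.1; the position ids are built from the sequence length either way, so passing the input_ids tensor instead of its torch.Size leaves the computation unchanged while giving Opacus a tensor to hook:

# sketch: modeling_bart.py, forward taking the tensor rather than its torch.Size
import torch
from torch import nn


class BartLearnedPositionalEmbedding(nn.Embedding):
    """This module learns positional embeddings up to a fixed maximum size."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # BART offsets the embedding ids by 2
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids` is expected to be of shape [bsz x seqlen]."""
        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len,
            dtype=torch.long, device=self.weight.device,
        )
        return super().forward(positions + self.offset)

The call sites in BartEncoder and BartDecoder would then pass a tensor whose first two dimensions are [bsz, seq_len] rather than input_shape.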

I already have a change implemented, with the BART-related tests passing. More than happy to open a PR and tag you in it, @patil-suraj.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch

from transformers.models.bart.modeling_bart import BartLearnedPositionalEmbedding
from opacus.tests.grad_samples.common import GradSampleHooks_test


class TestPositionalEmbedding(GradSampleHooks_test):
    def test_grad_sample(self):
        """
        Verify that our custom implementation of the grad sample for huggingface's
        BartLearnedPositionalEmbedding layer works. Built on the test routines in opacus's library.
        """
        # register the custom per-sample gradient computation for the layer (see below)
        register_grad_sampler()
        batch_size = 1
        max_pos_embs = 10
        embed_dim = 3

        x = torch.randint(0, max_pos_embs - 1, (batch_size, embed_dim))
        layer = BartLearnedPositionalEmbedding(max_pos_embs, embed_dim)
        self.run_test(x, layer, batch_first=True)

where a custom register_grad_sampler() function registers a per-sample gradient computation for the BartLearnedPositionalEmbedding layer.
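For illustration, a minimal sketch of what such a registration could look like, assuming the signature change above (so the activation Opacus captures is the input_ids tensor), a past_key_values_length of 0, and Opacus's register_grad_sampler decorator; the function name compute_bart_pos_embedding_grad_sample is hypothetical:

import torch
from opacus.grad_sample import register_grad_sampler
from transformers.models.bart.modeling_bart import BartLearnedPositionalEmbedding


@register_grad_sampler(BartLearnedPositionalEmbedding)
def compute_bart_pos_embedding_grad_sample(layer, activations, backprops):
    # activations: the captured input_ids, shape [batch, seq_len]
    # backprops: gradient of the loss w.r.t. the layer output, shape [batch, seq_len, embed_dim]
    batch_size, seq_len = activations.shape[:2]
    # every sample in the batch embeds the same positions 0..seq_len-1, shifted by BART's offset
    positions = torch.arange(seq_len, device=layer.weight.device) + layer.offset
    grad_sample = torch.zeros(
        batch_size, *layer.weight.shape, dtype=backprops.dtype, device=layer.weight.device
    )
    # each position index occurs exactly once per sample, so its per-sample gradient
    # is simply the backpropagated gradient at that sequence position
    grad_sample[:, positions, :] = backprops
    return {layer.weight: grad_sample}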

Expected behavior

The test above should pass.
