Skip to content

Commit ea8150a

Browse files
author
sanchit-gandhi
committed
fix class naming
1 parent 80188b6 commit ea8150a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/transformers/models/bert/modeling_flax_bert.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -552,12 +552,12 @@ class FlaxBertLayerCollection(nn.Module):
552552

553553
def setup(self):
554554
if self.gradient_checkpointing:
555-
FlaxBertBlockLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy)
555+
FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy)
556556
else:
557-
FlaxBertBlockLayer = FlaxBertLayer
557+
FlaxBertCheckpointLayer = FlaxBertLayer
558558

559559
self.layers = [
560-
FlaxBertBlockLayer(self.config, name=str(i), dtype=self.dtype)
560+
FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
561561
for i in range(self.config.num_hidden_layers)
562562
]
563563

0 commit comments

Comments
 (0)