
Add "sub_module" argument in PretrainedTransformerMismatchedEmbedder #5580

Merged Feb 28, 2022 (2 commits)

Changes from 1 commit
@@ -26,6 +26,10 @@ class PretrainedTransformerMismatchedEmbedder(TokenEmbedder):
through the transformer model independently, and concatenate the final representations.
Should be set to the same value as the `max_length` option on the
`PretrainedTransformerMismatchedIndexer`.
sub_module: `str`, optional (default = `None`)
The name of a submodule of the transformer to be used as the embedder. Some transformers, such as
BERT, naturally act as embedders. Other models, however, consist of both an encoder and a decoder,
in which case we only want to use the encoder.
train_parameters: `bool`, optional (default = `True`)
If this is `True`, the transformer weights get updated during training.
last_layer_only: `bool`, optional (default = `True`)
@@ -65,6 +69,7 @@ def __init__(
self,
model_name: str,
max_length: int = None,
sub_module: str = None,
train_parameters: bool = True,
last_layer_only: bool = True,
override_weights_file: Optional[str] = None,
@@ -80,6 +85,7 @@ def __init__(
self._matched_embedder = PretrainedTransformerEmbedder(
model_name,
max_length=max_length,
sub_module=sub_module,
train_parameters=train_parameters,
last_layer_only=last_layer_only,
override_weights_file=override_weights_file,
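A minimal sketch of how the new argument might be used once this change is released. The model name `"t5-small"` and the value `sub_module="encoder"` are illustrative choices for an encoder-decoder checkpoint, not values taken from this PR, and the surrounding indexer/data-loading setup is omitted:

```python
# Sketch: embed word-level tokens using only the encoder stack of a
# seq2seq transformer. Assumes an allennlp build that includes this PR;
# "t5-small" and sub_module="encoder" are illustrative, not prescribed here.
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder

embedder = PretrainedTransformerMismatchedEmbedder(
    model_name="t5-small",
    sub_module="encoder",  # pick the encoder attribute of the encoder-decoder model
    max_length=512,        # should match max_length on the PretrainedTransformerMismatchedIndexer
)

# The embedder reports the hidden size of the underlying model.
print(embedder.get_output_dim())  # e.g. 512 for t5-small
```

In a Jsonnet training config the same option would appear as an extra key on the token embedder, e.g. `"sub_module": "encoder"`, since the argument is passed straight through to the wrapped `PretrainedTransformerEmbedder`.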