
T5 #4969

Merged
merged 56 commits into from
Apr 22, 2021

Conversation

dirkgr
Member

@dirkgr dirkgr commented Feb 10, 2021

This adds a transformer toolkit version of T5. The code is mostly stolen from huggingface and then given type annotations and some minor improvements.

@dirkgr dirkgr self-assigned this Feb 10, 2021
@epwalsh
Member

epwalsh commented Mar 15, 2021

Currently I'm debugging. The loss doesn't match what you get from HF.

@dirkgr
Member Author

dirkgr commented Mar 16, 2021

Holy fixes Batman. Did I screw up that much in my port?

@epwalsh epwalsh marked this pull request as ready for review March 17, 2021 22:16
@epwalsh
Member

epwalsh commented Mar 17, 2021

Ready for review now. Not necessarily ready to merge, because I definitely want feedback on the API.

Comment on lines 13 to 27
# Imports, presumably from the top of this reader file:
from allennlp.common.file_utils import cached_path
from allennlp.data import DatasetReader
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from overrides import overrides


@DatasetReader.register("t5")
class T5DatasetReader(DatasetReader):
    def __init__(self, model_name: str, **kwargs) -> None:
        super().__init__(
            manual_distributed_sharding=True, manual_multiprocess_sharding=True, **kwargs
        )
        self.tokenizer = PretrainedTransformerTokenizer(model_name)
        self.token_indexers = {"tokens": PretrainedTransformerIndexer(model_name)}

    @overrides
    def _read(self, file_path):
        with open(cached_path(file_path)) as data_file:
            for line in self.shard_iterable(data_file):
                source, target = line.strip().split("\t")
                yield self.text_to_instance(source, target)
Member

As of right now, this reader is really designed for fine-tuning tasks, not LM-ing. Should we also have a reader for LM-ing? Or just have a flag that makes it work for LM-ing?

Member

The main difference would be that we'd just have source text/tokens, no target.

Member Author

Since T5 always does seq2seq, and there are different ways language modeling is done in this setting (predict next, fix masks, re-order words, etc.), I think this reader should always produce a source sequence and a target sequence. If they are the same in some settings, that's fine.
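For illustration, a standalone sketch of how an LM-style setting could feed this same reader, assuming the tab-separated `source\ttarget` format from the snippet above (the helper below is hypothetical, not part of this PR):

```python
# Hypothetical pre-processing: turn a plain-text corpus into the
# "source<TAB>target" lines the reader expects. Here the target is simply
# the source; a denoising objective would derive a different target.
def plain_text_to_tsv(lines):
    for line in lines:
        text = line.strip()
        if text:
            yield f"{text}\t{text}"


print(next(plain_text_to_tsv(["The quick brown fox jumps over the lazy dog.\n"])))
```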

@epwalsh epwalsh self-assigned this Mar 17, 2021
Member Author

@dirkgr dirkgr left a comment

I can only comment since it's my own PR :-)

Comment on lines 13 to 27
@DatasetReader.register("t5")
class T5DatasetReader(DatasetReader):
def __init__(self, model_name: str, **kwargs) -> None:
super().__init__(
manual_distributed_sharding=True, manual_multiprocess_sharding=True, **kwargs
)
self.tokenizer = PretrainedTransformerTokenizer(model_name)
self.token_indexers = {"tokens": PretrainedTransformerIndexer(model_name)}

@overrides
def _read(self, file_path):
with open(cached_path(file_path)) as data_file:
for line in self.shard_iterable(data_file):
source, target = line.strip().split("\t")
yield self.text_to_instance(source, target)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since T5 always does seq2seq, and there are different ways language modeling is done in this setting (predict next, fix masks, re-order words, etc.), I think this reader should always produce a source sequence and a target sequence. If they are the same in some settings, that's fine.

self.wi_1 = nn.Linear(hidden_size, ff_size, bias=False)
self.wi_1.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5)
self.wo = nn.Linear(ff_size, hidden_size, bias=False)
self.hidden_size = hidden_size
Member Author

Do we have to store the hidden size? Can't we just read it out of self.wi_0 whenever we need it?
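For reference, a minimal standalone illustration of that suggestion (plain PyTorch, with made-up sizes):

```python
import torch.nn as nn

hidden_size, ff_size = 512, 2048
wi_0 = nn.Linear(hidden_size, ff_size, bias=False)

# nn.Linear already stores its dimensions, so the hidden size can be read
# back whenever it is needed instead of being cached as self.hidden_size.
assert wi_0.in_features == hidden_size
```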

Member Author

Same question for T5DenseReluDense

Member

Good catch, didn't need this after all.

key_value_proj_dim: int = 64,
num_heads: int = 16,
hidden_size: int = 512,
key_value_proj_dim: int = 64, # d_kv
Member Author

d_kv?

position_bias, mask
) # (batch_size, n_heads, seq_length, key_length)
# Shape: (batch_size, num_heads, seq_length, key_length)
position_bias = position_bias + mask
Member Author

I changed this to use our apply_mask(), and you changed it back. I assume you looked into this and you know what's going on?

I seem to remember that with masking, HF likes to make a float mask early on and then use that everywhere, whereas I was trying to pass around boolean masks and make them into floats at the last possible moment (if ever). I think my approach is better, because you can use the boolean mask in different ways in different places, for example, apply in an fp16 way in one place, but apply in a fp32 way in another place, or use sum() to find the number of non-masked items, etc.
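For illustration, a minimal standalone sketch of that flexibility in plain PyTorch (not the toolkit's `apply_mask()` itself):

```python
import torch

# A boolean mask can be consumed in different ways at different call sites.
mask = torch.tensor([[True, True, False], [True, False, False]])

# 1. Count the non-masked items directly.
num_tokens = mask.sum()

# 2. Convert to an additive float mask at the last possible moment, in
#    whatever dtype the attention scores happen to use (fp16 here).
scores = torch.randn(2, 3, dtype=torch.float16)
additive = torch.zeros_like(scores).masked_fill(~mask, torch.finfo(scores.dtype).min)
masked_scores = scores + additive
```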

Member

Done.

Comment on lines 724 to 733
extended_attention_mask = get_extended_attention_mask(
    attention_mask, input_shape, inputs_embeds.dtype, is_decoder=self.is_decoder
)

if self.is_decoder and encoder_attention_mask is not None:
    encoder_extended_attention_mask = invert_attention_mask(
        encoder_attention_mask, inputs_embeds.dtype
    )
else:
    encoder_extended_attention_mask = None
Member Author

More masking code that I removed and you brought back ...

loss: Optional[FloatT] = None
past_key_values: Optional[List[KeyValueStates]] = None
decoder_hidden_states: Optional[List[FloatT]] = None
class T5Output:
Member Author

Even if we do nothing else, this version of T5, with the types and the docs and everything, is already a superior starting point if you want to hack on top of T5. I wonder if we should see if HF would want some of this ported back? Let's talk about that once this PR is done.

Comment on lines +1093 to +1094
# Currently tied embeddings is the only option we have, but if we make
# that configurable then we should put this in an 'if' block.
Member Author

The T5 paper does a bunch of experiments tying things here and there, but ultimately concludes that almost none of them work. The HF code supports all the weight tying options, but I removed that stuff because I just wanted to run T5-Large at the time. I suppose it would be good to have for someone to do future experiments, but not very critical.
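For reference, a minimal sketch of what tying the input embedding to the output projection looks like in plain PyTorch (not the removed HF options themselves; sizes are made up):

```python
import torch.nn as nn

vocab_size, hidden_size = 32128, 512
token_embeddings = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Share a single Parameter between the embedding and the LM head, so there
# is one weight matrix to train, save, and load.
lm_head.weight = token_embeddings.weight
assert lm_head.weight.data_ptr() == token_embeddings.weight.data_ptr()
```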

decoder_input_ids = self._shift_right(labels, self.decoder_start_token_id)

# Replace possible -100 values in labels by `pad_token_id`
decoder_input_ids.masked_fill_(decoder_input_ids == -100, self.pad_token_id)
Member Author

A backwards-compatibility hack we carried over, I assume?

Member

Carried over, but I don't know why 😏
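For context, this most likely comes from the Hugging Face convention of padding label tensors with -100, which PyTorch's cross-entropy ignores by default; a standalone sketch (made-up sizes) of why the shifted decoder inputs then need the replacement:

```python
import torch
import torch.nn.functional as F

pad_token_id = 0
# Labels are padded with -100 so the loss skips those positions...
labels = torch.tensor([[37, 1712, 55, 1, -100, -100]])
logits = torch.randn(1, 6, 32128)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))  # ignore_index defaults to -100

# ...but -100 is not a real token id, so when labels are shifted right to
# build decoder inputs, those positions must become the pad token.
decoder_input_ids = torch.roll(labels, shifts=1, dims=1)
decoder_input_ids[:, 0] = pad_token_id  # T5 starts decoding from the pad token
decoder_input_ids.masked_fill_(decoder_input_ids == -100, pad_token_id)
```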

mapping: Optional[Dict[str, str]] = None,
):
"""
Returns the mapping to be used, based on the optional `pretrained_module`.
If `pretrained_module` is not given, the default module-level mapping is returned.
"""
combined_mapping = {}
if "huggingface" in source:
if "huggingface" == source:
Member Author

@AkshitaB, do you sign off on this change? Is it safe? Was there a good reason you were doing a substring lookup before?

Member Author

@AkshitaB, ping?

Contributor

Yes, this is fine. Initially, the thought was that the source could be different huggingface configs, with some minor differences that could be addressed. But we haven't used it like that anywhere, so this should work fine.

Member Author

@dirkgr dirkgr left a comment

Changes look good, just one question.

I can't hit "Approve" because I started this PR, but I'm fine with merging this.

sanitize(self._finalize_output(o)) for o in self._model.forward_on_instances(instances)
]

def _finalize_output(self, output: JsonDict) -> JsonDict:
Member Author

Why are we not doing this anymore?

Member

I added this initially, then after doing some other refactoring I realized I really didn't need it, and instead the logic I put there should live in Model.make_output_human_readable().
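A minimal standalone sketch of that pattern, assuming a model whose forward pass puts predicted token ids into the output dict (names and vocabulary below are illustrative, not the PR's actual model):

```python
from typing import Any, Dict

import torch

ID_TO_TOKEN = {0: "<pad>", 1: "hello", 2: "world"}  # toy vocabulary


def make_output_human_readable(output_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Convert raw predicted ids to text in the model itself, so predictors,
    # evaluation scripts, and demos all get readable output for free.
    ids = output_dict["predictions"].tolist()
    output_dict["predicted_text"] = " ".join(ID_TO_TOKEN[i] for i in ids)
    return output_dict


print(make_output_human_readable({"predictions": torch.tensor([1, 2])}))
```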

dirkgr and others added 24 commits April 22, 2021 10:04
* Fixes distributed training with gradient accumulation

* Fix in case we don't do anything in a batch group

* Test for the problematic condition

* Formatting

* More formatting

* Changelog

* Fix another test

* Fix even more tests

* Fixes one more test

* I can fix these tests all day.
* Add link to gallery in README

* Update README.md

* try emojis

Is this overkill?
* Adding metadata parameter to BasicClassifier

* Fix

* Updating the changelog

* reformatting

* updating parameter type

* fixing import

Co-authored-by: Dirk Groeneveld <[email protected]>
* additional W&B params

* add wandb_kwargs

* fix

* fix docs
* Add eval_mode argument to pretrained transformer embedder

* Edit changelog entry

* Lint

* Update allennlp/modules/token_embedders/pretrained_transformer_embedder.py

* Apply suggestions from code review

Co-authored-by: Evan Pete Walsh <[email protected]>
Co-authored-by: Evan Pete Walsh <[email protected]>
* specify 'truncation' to avoid transformers warning

* Update docs

* Remove `stride` param

* Update CHANGELOG.md

Co-authored-by: Dirk Groeneveld <[email protected]>
* Create a way to use allennlp predict with a dataset and a multitask model

* Fix type ignoration

* Changelog

* Fix to the predictor
* fix bug with interleaving dataset reader

* more tests

* Update allennlp/data/dataset_readers/interleaving_dataset_reader.py

* Update allennlp/data/dataset_readers/interleaving_dataset_reader.py
* Take the number of runs in the test for distributed metrics

* Changelog
* creating new functionality for fields and instances to support outputting instances to json files

* creating tests for the new functionality

* fixing docs

* Delete __init__.py

* Delete influence_interpreter.py

* Delete use_if.py

* Delete simple_influence_test.py

* fixing docs

* finishing up SimpleInfluence

* passing lint

* passing format

* making small progress in coding

* Delete fast_influence.py

Submit to the wrong branch

* Delete faiss_utils.py

wrong branch

* Delete gpt2_bug.py

not sure why it's included

* Delete text_class.py

not sure why it's included

* adding test file

* adding testing files

* deleted unwanted files

* deleted unwanted files and rearrange test files

* small bug

* adjust function call to save instance in json

* Update allennlp/interpret/influence_interpreters/influence_interpreter.py

Co-authored-by: Evan Pete Walsh <[email protected]>

* Update allennlp/interpret/influence_interpreters/influence_interpreter.py

Co-authored-by: Evan Pete Walsh <[email protected]>

* Update allennlp/interpret/influence_interpreters/influence_interpreter.py

Co-authored-by: Evan Pete Walsh <[email protected]>

* move some documentation of parameters to base class

* delete one comment

* delete one deprecated abstract method

* changing interface

* formatting

* formatting err

* passing mypy

* passing mypy

* passing mypy

* passing mypy

* passing integration test

* passing integration test

* adding a new option to the do-all function

* modifying the callable function to the interface

* update API, fixes

* doc fixes

* add `from_path` and `from_archive` methods

* fix docs, improve logging

* add test

* address @matt-gardner's comments

* fixes to documentation

* update docs

Co-authored-by: Evan Pete Walsh <[email protected]>
Co-authored-by: Evan Pete Walsh <[email protected]>
* Update CONTRIBUTING.md

* updated changelog

Co-authored-by: Akshita Bhagia <[email protected]>
Co-authored-by: Arjun Subramonian <[email protected]>
* Added three definitions of fairness

* Updated CHANGELOG

* Added DemographicParityWithoutGroundTruth and finished tests

* finished refactoring Independence, Separation, and Sufficiency to accumulate

* added distributed functionality to Independence, Sufficiency, and Separation

* Finished aggregate and distributed functionality for DemographicParityWithoutGroundTruth

* fixed GPU and doc issues

* fixed GPU and doc issues

* fixed GPU and doc issues

* fixed GPU issues

* fixed GPU issues

* added init file

* fixed typo

* minor docstring changes

* minor changes to docstring

* Added simple explanations of fairness metrics to docstrings

* Further vectorized all metric implementations

* Fixed device issue

Co-authored-by: Arjun Subramonian <[email protected]>
Co-authored-by: Akshita Bhagia <[email protected]>
Co-authored-by: Dirk Groeneveld <[email protected]>
* fix cached_path for hub downloads

* fix test name

* fix type hint

* Update allennlp/common/file_utils.py

Co-authored-by: Lysandre Debut <[email protected]>

Co-authored-by: Lysandre Debut <[email protected]>
Member Author

@dirkgr dirkgr left a comment

No concerns, though I didn't read the big file a second time. In fact, I'm not sure why GitHub thinks it changed so much since the last time I looked at it.

@epwalsh epwalsh merged commit 4e862a5 into main Apr 22, 2021
@epwalsh epwalsh deleted the T5 branch April 22, 2021 19:11
dirkgr added a commit that referenced this pull request May 10, 2021
* Formatting

* New activation functions

* Makes position embeddings optional in the transformer embeddings

* Adds T5

* Various fixes to make this start up

* Share weights

* Adds one test that passes, and one test that fails

* use min_value_of_dtype in apply_mask

* fixes, add beam search

* encoder fixes

* fix

* fix beam search

* fix tests

* rename to just 'T5'

* fix initialization from pretrained

* add Model, DatasetReader, and Predictor

* remove useless dataset reader

* move high-level pieces to allennlp-models

* revert predictor changes

* remove unneeded hidden_size

* remove stray comment

* bool masks

* CHANGELOG

* fix test file name

* revert other change

* revert other change

* Distributed training with gradient accumulation (#5100)

* Fixes distributed training with gradient accumulation

* Fix in case we don't do anything in a batch group

* Test for the problematic condition

* Formatting

* More formatting

* Changelog

* Fix another test

* Fix even more tests

* Fixes one more test

* I can fix these tests all day.

* Add link to gallery and demo in README (#5103)

* Add link to gallery in README

* Update README.md

* try emojis

Is this overkill?

* Adding a metadata field to the basic classifier (#5104)

* Adding metadata parameter to BasicClassifier

* Fix

* Updating the changelog

* reformatting

* updating parameter type

* fixing import

Co-authored-by: Dirk Groeneveld <[email protected]>

* additional W&B params (#5114)

* additional W&B params

* add wandb_kwargs

* fix

* fix docs

* Add eval_mode argument to pretrained transformer embedder (#5111)

* Add eval_mode argument to pretrained transformer embedder

* Edit changelog entry

* Lint

* Update allennlp/modules/token_embedders/pretrained_transformer_embedder.py

* Apply suggestions from code review

Co-authored-by: Evan Pete Walsh <[email protected]>
Co-authored-by: Evan Pete Walsh <[email protected]>

* specify 'truncation' to avoid transformers warning (#5120)

* specify 'truncation' to avoid transformers warning

* Update docs

* Remove `stride` param

* Update CHANGELOG.md

Co-authored-by: Dirk Groeneveld <[email protected]>

* Predicting with a dataset reader on a multitask model (#5115)

* Create a way to use allennlp predict with a dataset and a multitask model

* Fix type ignoration

* Changelog

* Fix to the predictor

* fix bug with interleaving dataset reader (#5122)

* fix bug with interleaving dataset reader

* more tests

* Update allennlp/data/dataset_readers/interleaving_dataset_reader.py

* Update allennlp/data/dataset_readers/interleaving_dataset_reader.py

* remove jsonpickle from dependencies (#5121)

Co-authored-by: Dirk Groeneveld <[email protected]>

* Update docstring for basic_classifier (#5124)

* improve error message from Registrable class (#5125)

Co-authored-by: Akshita Bhagia <[email protected]>

* Prepare for release v2.3.0

* fix docs CI

* Take the number of runs in the test for distributed metrics (#5127)

* Take the number of runs in the test for distributed metrics

* Changelog

* Add influence functions to interpret module (#4988)

* creating new functionality for fields and instances to support outputting instances to json files

* creating tests for the new functionality

* fixing docs

* Delete __init__.py

* Delete influence_interpreter.py

* Delete use_if.py

* Delete simple_influence_test.py

* fixing docs

* finishing up SimpleInfluence

* passing lint

* passing format

* making small progress in coding

* Delete fast_influence.py

Submit to the wrong branch

* Delete faiss_utils.py

wrong branch

* Delete gpt2_bug.py

not sure why it's included

* Delete text_class.py

not sure why it's included

* adding test file

* adding testing files

* deleted unwanted files

* deleted unwanted files and rearrange test files

* small bug

* adjust function call to save instance in json

* Update allennlp/interpret/influence_interpreters/influence_interpreter.py

Co-authored-by: Evan Pete Walsh <[email protected]>

* Update allennlp/interpret/influence_interpreters/influence_interpreter.py

Co-authored-by: Evan Pete Walsh <[email protected]>

* Update allennlp/interpret/influence_interpreters/influence_interpreter.py

Co-authored-by: Evan Pete Walsh <[email protected]>

* move some documentation of parameters to base class

* delete one comment

* delete one deprecated abstract method

* changing interface

* formatting

* formatting err

* passing mypy

* passing mypy

* passing mypy

* passing mypy

* passing integration test

* passing integration test

* adding a new option to the do-all function

* modifying the callable function to the interface

* update API, fixes

* doc fixes

* add `from_path` and `from_archive` methods

* fix docs, improve logging

* add test

* address @matt-gardner's comments

* fixes to documentation

* update docs

Co-authored-by: Evan Pete Walsh <[email protected]>
Co-authored-by: Evan Pete Walsh <[email protected]>

* Update CONTRIBUTING.md (#5133)

* Update CONTRIBUTING.md

* updated changelog

Co-authored-by: Akshita Bhagia <[email protected]>
Co-authored-by: Arjun Subramonian <[email protected]>

* fix #5132 (#5134)

* fix

* Prepare for release v2.3.1

* Fairness Metrics (#5093)

* Added three definitions of fairness

* Updated CHANGELOG

* Added DemographicParityWithoutGroundTruth and finished tests

* finished refactoring Independence, Separation, and Sufficiency to accumulate

* added distributed functionality to Independence, Sufficiency, and Separation

* Finished aggregate and distributed functionality for DemographicParityWithoutGroundTruth

* fixed GPU and doc issues

* fixed GPU and doc issues

* fixed GPU and doc issues

* fixed GPU issues

* fixed GPU issues

* added init file

* fixed typo

* minor docstring changes

* minor changes to docstring

* Added simple explanations of fairness metrics to docstrings

* Further vectorized all metric implementations

* Fixed device issue

Co-authored-by: Arjun Subramonian <[email protected]>
Co-authored-by: Akshita Bhagia <[email protected]>
Co-authored-by: Dirk Groeneveld <[email protected]>

* fix cached_path for hub downloads (#5141)

* fix cached_path for hub downloads

* fix test name

* fix type hint

* Update allennlp/common/file_utils.py

Co-authored-by: Lysandre Debut <[email protected]>

Co-authored-by: Lysandre Debut <[email protected]>

* fix

* fix

Co-authored-by: epwalsh <[email protected]>
Co-authored-by: Evan Pete Walsh <[email protected]>
Co-authored-by: Jacob Morrison <[email protected]>
Co-authored-by: Nelson Liu <[email protected]>
Co-authored-by: Akshita Bhagia <[email protected]>
Co-authored-by: Leo Liu <[email protected]>
Co-authored-by: ArjunSubramonian <[email protected]>
Co-authored-by: Arjun Subramonian <[email protected]>
Co-authored-by: Arjun Subramonian <[email protected]>
Co-authored-by: Lysandre Debut <[email protected]>