Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Tango #5162

Merged
merged 166 commits into from
Aug 5, 2021
Merged

Tango #5162

Show file tree
Hide file tree
Changes from 165 commits
Commits
Show all changes
166 commits
Select commit Hold shift + click to select a range
6f42c6f
Basic step infrastructure
dirkgr Apr 28, 2021
0e95903
Formatting
dirkgr Apr 28, 2021
d1a8320
Merge branch 'main' into Tango
dirkgr Apr 28, 2021
08c3f55
Merge remote-tracking branch 'origin/main' into Tango
dirkgr May 4, 2021
8c72499
Adds a proper StepCache, plus a bunch of other stuff
dirkgr May 6, 2021
0f0d999
Formatting
dirkgr May 6, 2021
9d57f7f
Actually run a workflow
dirkgr May 6, 2021
c02772f
Typo
dirkgr May 6, 2021
10ad6a5
More Step infra
dirkgr May 10, 2021
b1abd2e
Fix the symlink
dirkgr May 11, 2021
35e3e04
Print some output
dirkgr May 11, 2021
672f255
ensure_results()
dirkgr May 11, 2021
18fb9ee
Write to a temporary location before writing to the final one
dirkgr May 11, 2021
5a9a7f4
Steps for reading and tokenizing a dataset
dirkgr May 11, 2021
c7062d0
Formatting
dirkgr May 11, 2021
2954779
Remove old TODO
dirkgr May 14, 2021
c54fa66
Show more errors when something can't be instantiated with from_params
dirkgr May 14, 2021
ae8ae8a
Batches have a length
dirkgr May 14, 2021
0d6cb33
Make mypy happy
dirkgr May 14, 2021
1c00563
Updated dataset definition
dirkgr May 14, 2021
d056bca
Use cached_transformers wherever possible
dirkgr May 14, 2021
fc207ae
Step for PIQA instances
dirkgr May 14, 2021
cef91b4
Some unimportant mypy stuff
dirkgr May 14, 2021
b400201
Compatibility with huggingface
dirkgr May 19, 2021
ab589dd
Create a mask if we have to
dirkgr May 19, 2021
d516ba5
Make LongTensors
dirkgr May 19, 2021
0924537
Easy access to the output dimension of an activation layer
dirkgr May 19, 2021
7a0952d
Take an ignore an attention mask in TransformerEmbeddings
dirkgr May 19, 2021
6c3a49b
Make it so a pooler can be derived from a huggingface module
dirkgr May 19, 2021
c4fb7f2
Improved type annotations
dirkgr May 19, 2021
f61a641
Formatting
dirkgr May 19, 2021
21c8a43
Big refactoring that makes training work
dirkgr May 19, 2021
d9fc729
Making mypy a little less grumpy
dirkgr May 19, 2021
efd4659
Formatting and unused variables
dirkgr May 19, 2021
f323544
Remove symlinks to old results
dirkgr May 19, 2021
2cdae81
Leave TODOs in the code
dirkgr May 19, 2021
02d58f4
StepCache doesn't need to be a MutableMapping
dirkgr May 19, 2021
d47ec53
Merge branch 'main' into Tango
dirkgr Jun 8, 2021
c5c1740
Fix duplicate line
dirkgr Jun 9, 2021
2f9b7c7
Pooler that can load from a transformer module
dirkgr Jun 9, 2021
ee8324f
GPU training
dirkgr Jun 9, 2021
634b0ee
Merge remote-tracking branch 'origin/main' into Tango
dirkgr Jun 10, 2021
750f243
Adds an evaluation step and TorchFormat
dirkgr Jun 11, 2021
e93ef1d
Merge remote-tracking branch 'origin/main' into Tango
dirkgr Jun 17, 2021
4626957
Fix duplicate line
dirkgr Jun 9, 2021
668022a
Easy access to the output dimension of an activation layer
dirkgr May 19, 2021
3194b2d
Take an ignore an attention mask in TransformerEmbeddings
dirkgr May 19, 2021
ad41685
Make it so a pooler can be derived from a huggingface module
dirkgr May 19, 2021
7c45761
Pooler that can load from a transformer module
dirkgr Jun 9, 2021
67041cb
Changelog
dirkgr Jun 17, 2021
b684bf0
Update transformer_embeddings.py
AkshitaB Jun 17, 2021
0df423e
Productivity through formatting
dirkgr Jun 17, 2021
77d189b
Don't break positional arguments
dirkgr Jun 18, 2021
2aa1b72
Merge branch 'AkshitaB-patch-1' into TransformerToolkitUpdates
dirkgr Jun 18, 2021
453d645
Some mode module names
dirkgr Jun 18, 2021
32fda86
Remove _get_input_arguments()
dirkgr Jun 18, 2021
43a200b
Merge branch 'TransformerToolkitUpdates' into Tango
dirkgr Jun 18, 2021
76bc409
Fix previously broken merge
dirkgr Jun 18, 2021
8c5a8da
Formatting
dirkgr Jun 18, 2021
6d606ec
Merge remote-tracking branch 'origin/main' into Tango
dirkgr Jun 23, 2021
aaf5eca
Fix the treatment of padding
dirkgr Jun 28, 2021
2b222da
Merge remote-tracking branch 'origin/main' into TransformerTextFieldFix
dirkgr Jul 2, 2021
0cc2025
Adds a test that exposes the problem
dirkgr Jul 6, 2021
f669579
Makes the epsilon for layer norm configurable and handles positions f…
dirkgr Jul 6, 2021
5032678
Fixes the end-to-end toolkit test
dirkgr Jul 6, 2021
27d5591
Merge remote-tracking branch 'origin/main' into TransformerTextFieldFix
Jul 7, 2021
e50838c
Changelog
dirkgr Jul 7, 2021
4924bf8
Merge remote-tracking branch 'origin/main' into Tango
dirkgr Jul 8, 2021
02aab16
Merge branch 'TransformerTextFieldFix' into Tango
dirkgr Jul 8, 2021
3658aae
Fix test
dirkgr Jul 8, 2021
298efe3
Merge branch 'main' into TransformerTextFieldFix
dirkgr Jul 8, 2021
655d72f
Merge remote-tracking branch 'origin/main' into tango
dirkgr Jul 8, 2021
486fa70
Merge remote-tracking branch 'origin/TransformerTextFieldFix' into tango
dirkgr Jul 8, 2021
1c50327
Slightly better error message
dirkgr Jul 8, 2021
dc9ff6a
Removing broken config
dirkgr Jul 8, 2021
a68b0ef
Moving the DSPT config where it belongs
dirkgr Jul 8, 2021
5af1aaf
Moving tango things into a tango directory
dirkgr Jul 8, 2021
8b6dc20
Moving the PIQA config to allennlp_models
dirkgr Jul 8, 2021
e8d482e
cleanup
dirkgr Jul 8, 2021
10e25a0
Lots of mypy changes
dirkgr Jul 10, 2021
b4eddb5
Fix type of covariance metric
dirkgr Jul 10, 2021
b24edd3
Silence two warnings
dirkgr Jul 10, 2021
5af127c
Merge remote-tracking branch 'origin/main' into Tango
dirkgr Jul 10, 2021
d78f38c
Moving example steps into their own files
dirkgr Jul 10, 2021
5303257
Removes a useless step
dirkgr Jul 14, 2021
6d177ab
Rename the huggingface tokenizer step
dirkgr Jul 14, 2021
bd12bd3
Removed obsolete TODO
dirkgr Jul 14, 2021
2069da1
Brings back the text_only step
dirkgr Jul 14, 2021
a508701
Fixes the text_only step
dirkgr Jul 15, 2021
5c78475
Move and fix the test for steps
dirkgr Jul 15, 2021
87417e0
Compatibility with Python 3.7
dirkgr Jul 15, 2021
767881e
Documentation
dirkgr Jul 15, 2021
93170eb
More docs
dirkgr Jul 15, 2021
e938841
More docs
dirkgr Jul 15, 2021
3cdae0a
Makes a step's format part of its unique name
dirkgr Jul 15, 2021
620eaa8
Changelog
dirkgr Jul 15, 2021
db8ea8d
Even more docs
dirkgr Jul 15, 2021
6f986c5
Turn off zipfile serialization
dirkgr Jul 16, 2021
e508644
More documentation
dirkgr Jul 16, 2021
7c19aeb
Merge branch 'main' into Tango
dirkgr Jul 19, 2021
660b9ff
Merge remote-tracking branch 'origin/main' into Tango
dirkgr Jul 20, 2021
089ae90
Make it clearer which steps are running and which ones are not
dirkgr Jul 20, 2021
4b1e931
Merge remote-tracking branch 'origin/main' into Tango
dirkgr Jul 20, 2021
55f171a
Fix the torch format
dirkgr Jul 20, 2021
7353005
Better error message
dirkgr Jul 20, 2021
14def67
Adds a dataset reader adapter for Tango
dirkgr Jul 20, 2021
7999c20
Formatting
dirkgr Jul 20, 2021
7266daa
Fix warning
dirkgr Jul 20, 2021
bf8bad2
Fix output of the dataset reader adapter
dirkgr Jul 20, 2021
af6ee27
Return type annotation for the dataset reader adapter
dirkgr Jul 20, 2021
0f647e1
Puts the documentation where it belongs
dirkgr Jul 20, 2021
f5a758f
Fix old name
dirkgr Jul 20, 2021
318d820
Adds an end-to-end tango test
dirkgr Jul 20, 2021
6a1e618
Formatting 🙄
dirkgr Jul 20, 2021
46a1907
More tests
dirkgr Jul 20, 2021
d896960
Test for dry run
dirkgr Jul 21, 2021
4551f8f
f
dirkgr Jul 21, 2021
c24b386
Fix iterable dill results
dirkgr Jul 21, 2021
ebceaf0
Test iterable dill format
dirkgr Jul 21, 2021
786c974
Refactored how dry_run() works
dirkgr Jul 21, 2021
e58002e
Test for running Tango programmatically
dirkgr Jul 21, 2021
aba9890
Removing one TODO
dirkgr Jul 21, 2021
78c9981
Merge remote-tracking branch 'origin/main' into Tango
dirkgr Jul 27, 2021
0f9bc19
Merge branch 'main' into Tango
dirkgr Jul 27, 2021
f14e6f0
Re-initializes random seeds for every step's run
dirkgr Jul 27, 2021
10ec287
Rename temp_dir to work_dir.
dirkgr Jul 27, 2021
444e3fb
Experimental warnings
dirkgr Jul 28, 2021
c2e74fa
Formatting
dirkgr Jul 28, 2021
d5c691a
Even more formatting
dirkgr Jul 28, 2021
abba3e4
Remove file that wasn't supposed to be checked in
dirkgr Jul 28, 2021
2b30dd4
Moved changelog entry
dirkgr Jul 28, 2021
da1213f
Pin datasets tighter
dirkgr Jul 28, 2021
867b643
JSON format
dirkgr Jul 28, 2021
68f0b3e
Test for JSON format
dirkgr Jul 28, 2021
8787a47
Factor out some stuff
dirkgr Jul 28, 2021
a2eabaf
Refactor some more
dirkgr Jul 28, 2021
3902a32
JSON format for the Evaluation step
dirkgr Jul 29, 2021
f0aeba7
Renamed to DatasetDict
dirkgr Jul 29, 2021
153e649
Shorter name
dirkgr Jul 29, 2021
de6ad81
Use torch inference mode
dirkgr Jul 29, 2021
7b47bed
mypy
dirkgr Jul 29, 2021
1faed56
Set bounds on datasets
dirkgr Jul 29, 2021
71411e1
ShuffledSequence
dirkgr Jul 29, 2021
badd836
Convenience methods for DatasetDict
dirkgr Jul 29, 2021
729f6ab
Use det_hash, giving a way to override how unique hashes are computed
dirkgr Jul 29, 2021
1816783
to_params()
dirkgr Jul 29, 2021
6372aa9
Some more documentation for to_params()
dirkgr Jul 29, 2021
58351e5
Merge branch 'main' into Tango
dirkgr Jul 30, 2021
423d964
Better dethash for types
dirkgr Jul 30, 2021
bd9b788
Better hash for formats in steps
dirkgr Jul 30, 2021
cd63e02
to_params and named parameters for data loaders
dirkgr Jul 30, 2021
6d25012
Merge branch 'Tango' of https://github.com/allenai/allennlp into Tango
dirkgr Jul 30, 2021
4c085c2
😳
dirkgr Jul 31, 2021
687f3b8
Support tuples in input
dirkgr Aug 3, 2021
2cb729e
Merge branch 'main' into Tango
AkshitaB Aug 3, 2021
8334b31
Stolen check for how to call __new__
dirkgr Aug 4, 2021
c0c42e8
Fix some type checks
dirkgr Aug 4, 2021
cb87db6
Makes RefStep as the default step work
dirkgr Aug 4, 2021
bb8bb0f
Changes how from_params works with steps
dirkgr Aug 4, 2021
1bcdf05
Merge branch 'Tango' of https://github.com/allenai/allennlp into Tango
dirkgr Aug 4, 2021
a1595d4
Fix tests
dirkgr Aug 4, 2021
b9624cf
Chasing that locktable error
dirkgr Aug 4, 2021
4e15aab
Hopefully fixes the MDB_BAD_RSLOT error
dirkgr Aug 4, 2021
b7fcb24
Typo
dirkgr Aug 5, 2021
b7d4a92
That wasn't how __new__ works. This is.
dirkgr Aug 5, 2021
153bade
Unique file ids
dirkgr Aug 5, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`self.ddp_accelerator` during distributed training. This is useful when, for example, instantiating submodules in your
model's `__init__()` method by wrapping them with `self.ddp_accelerator.wrap_module()`. See the `allennlp.modules.transformer.t5`
for an example.
- Added Tango components, to be explored in detail in a later post.

### Fixed

Expand Down
1 change: 1 addition & 0 deletions allennlp/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from allennlp.commands.test_install import TestInstall
from allennlp.commands.train import Train
from allennlp.commands.count_instances import CountInstances
from allennlp.commands.tango import Tango
from allennlp.common.plugins import import_plugins
from allennlp.common.util import import_module_and_submodules
from allennlp.commands.checklist import CheckList
Expand Down
155 changes: 155 additions & 0 deletions allennlp/commands/tango.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""
Subcommand for running Tango experiments
"""

import argparse
import logging
import os
from os import PathLike
from pathlib import Path
from typing import Union, Dict, Any, List, Optional

from overrides import overrides

from allennlp.commands.subcommand import Subcommand
from allennlp.common.params import Params
from allennlp.common import logging as common_logging
from allennlp.common import util as common_util
from allennlp.tango.step import step_graph_from_params, tango_dry_run
from allennlp.tango.step import DirectoryStepCache

logger = logging.getLogger(__name__)


@Subcommand.register("tango")
class Tango(Subcommand):
@overrides
def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
description = """Run a tango experiment file."""
subparser = parser.add_parser(self.name, description=description, help=description)

subparser.add_argument("config_path", type=str, help="path to a Tango experiment file")

subparser.add_argument(
"-s",
"--serialization-dir",
required=True,
type=str,
help="directory in which to save the results of the steps",
)

subparser.add_argument(
"-o",
"--overrides",
type=str,
default="",
help=(
"a json(net) structure used to override the experiment configuration, e.g., "
"'{\"vocabulary.min_count.labels\": 10}'. Nested parameters can be specified either"
" with nested dictionaries or with dot syntax."
),
)

subparser.add_argument(
"--dry-run",
action="store_true",
help="Only show what would run. Don't run anything.",
)

subparser.add_argument(
"--file-friendly-logging",
action="store_true",
default=False,
help="outputs tqdm status on separate lines and slows tqdm refresh rate",
)

subparser.set_defaults(func=run_tango_from_args)

return subparser


def run_tango_from_args(args: argparse.Namespace):
run_tango_from_file(
tango_filename=args.config_path,
serialization_dir=args.serialization_dir,
overrides=args.overrides,
include_package=args.include_package,
dry_run=args.dry_run,
file_friendly_logging=args.file_friendly_logging,
)


def run_tango_from_file(
tango_filename: Union[str, PathLike],
serialization_dir: Union[str, PathLike],
overrides: Union[str, Dict[str, Any]] = "",
include_package: Optional[List[str]] = None,
dry_run: bool = False,
file_friendly_logging: bool = False,
):
params = Params.from_file(tango_filename, overrides)
return run_tango(
params=params,
serialization_dir=serialization_dir,
include_package=include_package,
dry_run=dry_run,
file_friendly_logging=file_friendly_logging,
)


def run_tango(
params: Params,
serialization_dir: Union[str, PathLike],
include_package: Optional[List[str]] = None,
dry_run: bool = False,
file_friendly_logging: bool = False,
):
common_logging.FILE_FRIENDLY_LOGGING = file_friendly_logging

if include_package is not None:
for package_name in include_package:
common_util.import_module_and_submodules(package_name)

common_util.prepare_environment(params)

step_graph = step_graph_from_params(params.pop("steps"))

serialization_dir = Path(serialization_dir)
serialization_dir.mkdir(parents=True, exist_ok=True)
step_cache = DirectoryStepCache(serialization_dir / "step_cache")

if dry_run:
for step, cached in tango_dry_run(
(s for s in step_graph.values() if not s.only_if_needed), step_cache
):
if cached:
print(f"Getting {step.name} from cache")
else:
print(f"Computing {step.name}")
else:
# remove symlinks to old results
for filename in serialization_dir.glob("*"):
if filename.is_symlink():
relative_target = os.readlink(filename)
if not relative_target.startswith("step_cache/"):
continue
logger.info(
f"Removing symlink '{filename.name}' to previous result {relative_target}"
)
filename.unlink()

# produce results
for name, step in step_graph.items():
if not step.only_if_needed:
step.ensure_result(step_cache)

# symlink everything that has been computed
for name, step in step_graph.items():
if step in step_cache:
step_link = serialization_dir / name
step_link.unlink(missing_ok=True)
step_link.symlink_to(
step_cache.path_for_step(step).relative_to(serialization_dir),
target_is_directory=True,
)
print(f'The output for "{name}" is in {step_link}.')
2 changes: 0 additions & 2 deletions allennlp/common/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ def __init__(self, message: str):
self.message = message

def __str__(self):
# TODO(brendanr): Is there some reason why we need repr here? It
# produces horrible output for simple multi-line error messages.
return self.message


Expand Down
63 changes: 63 additions & 0 deletions allennlp/common/det_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import hashlib
import io
from typing import Any

import base58
import dill


class CustomDetHash:
def det_hash_object(self) -> Any:
"""
By default, `det_hash()` pickles an object, and returns the hash of the pickled
representation. Sometimes you want to take control over what goes into
that hash. In that case, implement this method. `det_hash()` will pickle the
result of this method instead of the object itself.
"""
raise NotImplementedError()


class DetHashFromInitParams(CustomDetHash):
"""
Add this class as a mixin base class to make sure your class's det_hash is derived
exclusively from the parameters passed to __init__().
"""

_det_hash_object: Any

def __new__(cls, *args, **kwargs):
super_new = super(DetHashFromInitParams, cls).__new__
if super().__new__ is object.__new__ and cls.__init__ is not object.__init__:
instance = super_new(cls)
else:
instance = super_new(cls, *args, **kwargs)
instance._det_hash_object = (args, kwargs)
return instance

def det_hash_object(self) -> Any:
return self._det_hash_object


class _DetHashPickler(dill.Pickler):
def persistent_id(self, obj: Any) -> Any:
if isinstance(obj, CustomDetHash):
return obj.__class__.__qualname__, obj.det_hash_object()
elif isinstance(obj, type):
return obj.__module__, obj.__qualname__
else:
return None


def det_hash(o: Any) -> str:
"""
Returns a deterministic hash code of arbitrary Python objects.

If you want to override how we calculate the deterministic hash, derive from the
`CustomDetHash` class and implement `det_hash_object()`.
"""
m = hashlib.blake2b()
with io.BytesIO() as buffer:
pickler = _DetHashPickler(buffer)
pickler.dump(o)
m.update(buffer.getbuffer())
return base58.b58encode(m.digest()).decode()
Loading