Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

add diff command #5109

Merged
merged 23 commits into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Ported the following Huggingface `LambdaLR`-based schedulers: `ConstantLearningRateScheduler`, `ConstantWithWarmupLearningRateScheduler`, `CosineWithWarmupLearningRateScheduler`, `CosineHardRestartsWithWarmupLearningRateScheduler`.
- Added new `sub_token_mode` parameter to `pretrained_transformer_mismatched_embedder` class to support first sub-token embedding
- Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files.
- Added `allennlp.nn.util.load_state_dict` helper function.

### Changed

Expand Down
1 change: 1 addition & 0 deletions allennlp/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from allennlp import __version__
from allennlp.commands.build_vocab import BuildVocab
from allennlp.commands.cached_path import CachedPath
from allennlp.commands.diff import Diff
from allennlp.commands.evaluate import Evaluate
from allennlp.commands.find_learning_rate import FindLearningRate
from allennlp.commands.predict import Predict
Expand Down
214 changes: 214 additions & 0 deletions allennlp/commands/diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
"""
# Examples

```bash
allennlp diff \
https://huggingface.co/roberta-large/resolve/main/pytorch_model.bin \
https://storage.googleapis.com/allennlp-public-models/transformer-qa-2020-10-03.tar.gz!weights.th \
--strip-prefix-1 'roberta.' \
--strip-prefix-2 '_text_field_embedder.token_embedder_tokens.transformer_model.'
```
"""
import argparse
import logging
from typing import Union, Dict, List, Tuple, NamedTuple, cast

from overrides import overrides
import termcolor
import torch

from allennlp.commands.subcommand import Subcommand
from allennlp.common.file_utils import cached_path
from allennlp.nn.util import load_state_dict


logger = logging.getLogger(__name__)


@Subcommand.register("diff")
class Diff(Subcommand):
requires_plugins: bool = False

@overrides
def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
description = """Display a diff between two model checkpoints."""
subparser = parser.add_parser(
self.name,
description=description,
help=description,
)
subparser.set_defaults(func=_diff)
subparser.add_argument(
"checkpoint1",
type=str,
help="""The URL or path to the first PyTorch checkpoint file (e.g. '.pt' or '.bin').""",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice if these could also point to model archives.

)
subparser.add_argument(
"checkpoint2",
type=str,
help="""The URL or path to the second PyTorch checkpoint file.""",
)
subparser.add_argument(
"--strip-prefix-1",
type=str,
help="""A prefix to remove from all of the first checkpoint's keys.""",
)
subparser.add_argument(
"--strip-prefix-2",
type=str,
help="""A prefix to remove from all of the second checkpoint's keys.""",
)
return subparser


class Keep(NamedTuple):
key: str
shape: Tuple[int, ...]

def display(self):
termcolor.cprint(f" {self.key}, shape = {self.shape}")


class Insert(NamedTuple):
key: str
shape: Tuple[int, ...]

def display(self):
termcolor.cprint(f"+{self.key}, shape = {self.shape}", "green")


class Remove(NamedTuple):
key: str
shape: Tuple[int, ...]

def display(self):
termcolor.cprint(f"-{self.key}, shape = {self.shape}", "red")


class Modify(NamedTuple):
key: str
shape: Tuple[int, ...]
distance: float

def display(self):
termcolor.cprint(f"!{self.key}, shape = {self.shape}, △ = {self.distance:.4f}", "yellow")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally don't think that the L2 distance is too insightful, especially when all the tensors are so high-dimensional.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That being said, I can't think of a better single metric to succinctly describe how different two tensors are. If we have such a metric, I don't personally think that sorting by that metric is helpful. I would much rather prefer to see the differing parameters in the order in which they are used during forward propagation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a heatmap of the element-wise differences of 2D parameters could be very useful. I don't know how to extend this well to more than 2 dimensions though, perhaps aggregate over the channel dimension?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we flatten the tensors, then bin the weights into a manageable number of bins, then do something with those bins, like show a bar chart of the distance between bins of two tensors?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't think that L2 distance is going to show up meaningfully when you want to validate, for example, that your gradual unfreezing schedule works? Maybe we can do better than L2, but I think it's a great start.

Copy link
Member Author

@epwalsh epwalsh Apr 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking over the example I gave, there is certainly a strong correlation with the number of elements in a modified parameter and the corresponding Euclidean distance, suggesting this is not the best metric to use.

That said, I'm pretty sure any other L*-based metric (even L∞) would suffer from the same correlation unless we add an additional size-based normalization term.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any way, the exact metric to use should be a configurable option IMO

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just normalizing by $\sqrt{n}$ is good enough. So then we are really doing

$\sqrt{ \frac{1}{n} \sum_{i=0}^n (x_i - y_i)^2 }$

Or the square root of the mean squared "error". Is this meaningful?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are y'all running plugins that render LaTex properly in GitHub?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Human renderer plugin.



class _Frontier(NamedTuple):
x: int
history: List[Union[Keep, Insert, Remove]]


def _finalize(
history: List[Union[Keep, Insert, Remove]],
state_dict_a: Dict[str, torch.Tensor],
state_dict_b: Dict[str, torch.Tensor],
) -> List[Union[Keep, Insert, Remove, Modify]]:
out = cast(List[Union[Keep, Insert, Remove, Modify]], history)
for i, step in enumerate(out):
if isinstance(step, Keep):
a_tensor = state_dict_a[step.key]
b_tensor = state_dict_b[step.key]
with torch.no_grad():
dist = torch.nn.functional.mse_loss(a_tensor, b_tensor).sqrt()
if dist != 0.0:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we worry about loss of precision here? I.e. maybe we want to check that it's within some small threshold of 0, not exactly 0.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The threshold could be a configurable parameter.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a great point that I just thought of as well! I think we could maybe ensure better precision if we actually did (a_tensor != b_tensor).any(). But, I also like the idea of a configurable parameter. This way, if people train two models with a small difference in implementation, they can also use this tool to identify how similar the weights are with different thresholds epsilon.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a great point

out[i] = Modify(step.key, step.shape, dist)
return out


def checkpoint_diff(
state_dict_a: Dict[str, torch.Tensor], state_dict_b: Dict[str, torch.Tensor]
) -> List[Union[Keep, Insert, Remove, Modify]]:
"""
Uses a modified version of the Myers diff algorithm to compute a representation
of the diff between two model state dictionaries.
Comment on lines +159 to +160
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know that much about Myers, but isn't that only necessary if the order matters? Does the order of entries in the state_dict matter? I thought that's just alphabetical?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The order is meaningful. It is the order in which the corresponding modules were registered. Generally this is the order of data flow.


The only difference is that in addition to the `Keep`, `Insert`, and `Remove`
operations, we add `Modify`. This corresponds to keeping a parameter
but changing its weights (not the shape).

Adapted from [this gist]
(https://gist.github.com/adamnew123456/37923cf53f51d6b9af32a539cdfa7cc4).
"""
param_list_a = [(k, tuple(v.shape)) for k, v in state_dict_a.items()]
param_list_b = [(k, tuple(v.shape)) for k, v in state_dict_b.items()]

# This marks the farthest-right point along each diagonal in the edit
# graph, along with the history that got it there
frontier: Dict[int, _Frontier] = {1: _Frontier(0, [])}

def one(idx):
"""
The algorithm Myers presents is 1-indexed; since Python isn't, we
need a conversion.
"""
return idx - 1

a_max = len(param_list_a)
b_max = len(param_list_b)
for d in range(0, a_max + b_max + 1):
for k in range(-d, d + 1, 2):
# This determines whether our next search point will be going down
# in the edit graph, or to the right.
#
# The intuition for this is that we should go down if we're on the
# left edge (k == -d) to make sure that the left edge is fully
# explored.
#
# If we aren't on the top (k != d), then only go down if going down
# would take us to territory that hasn't sufficiently been explored
# yet.
go_down = k == -d or (k != d and frontier[k - 1].x < frontier[k + 1].x)

# Figure out the starting point of this iteration. The diagonal
# offsets come from the geometry of the edit grid - if you're going
# down, your diagonal is lower, and if you're going right, your
# diagonal is higher.
if go_down:
old_x, history = frontier[k + 1]
x = old_x
else:
old_x, history = frontier[k - 1]
x = old_x + 1

# We want to avoid modifying the old history, since some other step
# may decide to use it.
history = history[:]
y = x - k

# We start at the invalid point (0, 0) - we should only start building
# up history when we move off of it.
if 1 <= y <= b_max and go_down:
history.append(Insert(*param_list_b[one(y)]))
elif 1 <= x <= a_max:
history.append(Remove(*param_list_a[one(x)]))

# Chew up as many diagonal moves as we can - these correspond to common lines,
# and they're considered "free" by the algorithm because we want to maximize
# the number of these in the output.
while x < a_max and y < b_max and param_list_a[one(x + 1)] == param_list_b[one(y + 1)]:
x += 1
y += 1
history.append(Keep(*param_list_a[one(x)]))

if x >= a_max and y >= b_max:
# If we're here, then we've traversed through the bottom-left corner,
# and are done.
return _finalize(history, state_dict_a, state_dict_b)
else:
frontier[k] = _Frontier(x, history)

assert False, "Could not find edit script"


def _diff(args: argparse.Namespace):
checkpoint_1_path = cached_path(args.checkpoint1, extract_archive=True)
checkpoint_2_path = cached_path(args.checkpoint2, extract_archive=True)
checkpoint_1 = load_state_dict(
checkpoint_1_path, strip_prefix=args.strip_prefix_1, strict=False
)
checkpoint_2 = load_state_dict(
checkpoint_2_path, strip_prefix=args.strip_prefix_2, strict=False
)
for step in checkpoint_diff(checkpoint_1, checkpoint_2):
step.display()
2 changes: 1 addition & 1 deletion allennlp/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def _load(

# Load state dict. We pass `strict=False` so PyTorch doesn't raise a RuntimeError
# if the state dict is missing keys because we handle this case below.
model_state = torch.load(weights_file, map_location=util.device_mapping(cuda_device))
model_state = util.load_state_dict(weights_file, cuda_device=cuda_device)
missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)

# Modules might define a class variable called `authorized_missing_keys`,
Expand Down
93 changes: 92 additions & 1 deletion allennlp/nn/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
"""

import copy
from collections import defaultdict, OrderedDict
import json
import logging
from collections import defaultdict
from os import PathLike
import re
from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union

import math
Expand Down Expand Up @@ -924,6 +926,95 @@ def inner_device_mapping(storage: torch.Storage, location) -> torch.Storage:
return inner_device_mapping


def load_state_dict(
path: Union[PathLike, str],
strip_prefix: Optional[str] = None,
ignore: Optional[List[str]] = None,
strict: bool = True,
cuda_device: int = -1,
) -> Dict[str, torch.Tensor]:
"""
Load a PyTorch model state dictionary from a checkpoint at the given `path`.

# Parameters

path : `Union[PathLike, str]`, required

strip_prefix : `Optional[str]`, optional (default = `None`)
A prefix to remove from all of the state dict keys.

ignore : `Optional[List[str]]`, optional (default = `None`)
Optional list of regular expressions. Keys that match any of these will be removed
from the state dict.

!!! Note
If `strip_prefix` is given, the regular expressions in `ignore` are matched
before the prefix is stripped.

strict : `bool`, optional (default = `True`)
If `True` (the default) and `strip_prefix` was never used or any of the regular expressions
in `ignore` never matched, a `ValueError` will be raised.

cuda_device : `int`, optional (default = `-1`)
The device to load the parameters onto. Use `-1` (the default) for CPU.

# Returns

`Dict[str, torch.Tensor]`
An ordered dictionary of the state.
"""
state = torch.load(path, map_location=device_mapping(cuda_device))
out: Dict[str, torch.Tensor] = OrderedDict()

if ignore is not None and not isinstance(ignore, list):
# If user accidentally passed in something that is not a list - like a string,
# which is easy to do - the user would be confused why the resulting state dict
# is empty.
raise ValueError("'ignore' parameter should be a list")

# In 'strict' mode, we need to keep track of whether we've used `strip_prefix`
# and which regular expressions in `ignore` we've used.
strip_prefix_used: Optional[bool] = None
ignore_used: Optional[List[bool]] = None
if strict and strip_prefix is not None:
strip_prefix_used = False
if strict and ignore:
ignore_used = [False] * len(ignore)

for key in state.keys():
ignore_key = False
if ignore:
for i, pattern in enumerate(ignore):
if re.match(pattern, key):
if ignore_used:
ignore_used[i] = True
logger.warning("ignoring %s from state dict", key)
ignore_key = True
break

if ignore_key:
continue

new_key = key

if strip_prefix and key.startswith(strip_prefix):
strip_prefix_used = True
new_key = key[len(strip_prefix) :]
if not new_key:
raise ValueError("'strip_prefix' resulted in an empty string for a key")

out[new_key] = state[key]

if strip_prefix_used is False:
raise ValueError(f"'strip_prefix' of '{strip_prefix}' was never used")
if ignore is not None and ignore_used is not None:
for pattern, used in zip(ignore, ignore_used):
if not used:
raise ValueError(f"'ignore' pattern '{pattern}' didn't have any matches")

return out


def combine_tensors(combination: str, tensors: List[torch.Tensor]) -> torch.Tensor:
"""
Combines a list of tensors using element-wise operations and concatenation, specified by a
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
"filelock>=3.0,<3.1",
"lmdb",
"more-itertools",
"termcolor==1.1.0",
"wandb>=0.10.0,<0.11.0",
],
entry_points={"console_scripts": ["allennlp=allennlp.__main__:run"]},
Expand Down