Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 7473737

Browse files
epwalsh and dirkgr authored
add diff command (#5109)
* add diff command
* fix docs
* no silly geese
* update CHANGELOG
* move 'load_state_dict' to nn.util
* normalize by size
* handle different checkpoint types
* add integration tests
* add 'scale' and 'threshold' params
* HuggingFace Hub support
* support '_/' as well, add test
* revert some changes
* fix
* Update CHANGELOG.md
* Update codecov.yml

Co-authored-by: Dirk Groeneveld <[email protected]>
1 parent d85c5c3 commit 7473737

File tree

8 files changed

+532
-2
lines changed

8 files changed

+532
-2
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616
### Added
1717

1818
- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.sanity_checks.task_checklists` module.
19+
- Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files.
20+
- Added `allennlp.nn.util.load_state_dict` helper function.
1921
- Added a way to avoid downloading and loading pretrained weights in modules that wrap transformers
2022
such as the `PretrainedTransformerEmbedder` and `PretrainedTransformerMismatchedEmbedder`.
2123
You can do this by setting the parameter `load_weights` to `False`.

allennlp/commands/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from allennlp import __version__
99
from allennlp.commands.build_vocab import BuildVocab
1010
from allennlp.commands.cached_path import CachedPath
11+
from allennlp.commands.diff import Diff
1112
from allennlp.commands.evaluate import Evaluate
1213
from allennlp.commands.find_learning_rate import FindLearningRate
1314
from allennlp.commands.predict import Predict

allennlp/commands/diff.py

+259
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
"""
2+
# Examples
3+
4+
```bash
5+
allennlp diff \
6+
hf://roberta-large/pytorch_model.bin \
7+
https://storage.googleapis.com/allennlp-public-models/transformer-qa-2020-10-03.tar.gz \
8+
--strip-prefix-1 'roberta.' \
9+
--strip-prefix-2 '_text_field_embedder.token_embedder_tokens.transformer_model.'
10+
```
11+
"""
12+
import argparse
13+
import logging
14+
from typing import Union, Dict, List, Tuple, NamedTuple, cast
15+
16+
from overrides import overrides
17+
import termcolor
18+
import torch
19+
20+
from allennlp.commands.subcommand import Subcommand
21+
from allennlp.common.file_utils import cached_path
22+
from allennlp.nn.util import load_state_dict
23+
24+
25+
logger = logging.getLogger(__name__)
26+
27+
28+
@Subcommand.register("diff")
class Diff(Subcommand):
    """
    Registers the `allennlp diff` subcommand, which prints a `git diff`-like
    comparison of two model checkpoints.
    """

    # This command operates on raw checkpoints, so user plugins are not needed.
    requires_plugins: bool = False

    @overrides
    def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
        description = """Display a diff between two model checkpoints."""
        long_description = (
            description
            + """
        In the output, lines start with either a "+", "-", "!", or empty space " ".
        "+" means the corresponding parameter is present in the 2nd checkpoint but not the 1st.
        "-" means the corresponding parameter is present in the 1st checkpoint but not the 2nd.
        "!" means the corresponding parameter is present in both, but has different weights (same shape)
        according to the distance calculation and the '--threshold' value.
        And " " means the corresponding parameter is considered identical in both, i.e.
        the distance falls below the threshold.
        The distance between two tensors is calculated as the root of the
        mean squared difference, multiplied by the '--scale' parameter.
        """
        )
        # The short description doubles as the help blurb in `allennlp --help`.
        subparser = parser.add_parser(
            self.name,
            description=long_description,
            help=description,
        )
        subparser.set_defaults(func=_diff)
        subparser.add_argument(
            "checkpoint1",
            type=str,
            help="""the URL, path, or other identifier (see '--checkpoint-type-1')
            to the 1st PyTorch checkpoint.""",
        )
        subparser.add_argument(
            "checkpoint2",
            type=str,
            help="""the URL, path, or other identifier (see '--checkpoint-type-2')
            to the 2nd PyTorch checkpoint.""",
        )
        subparser.add_argument(
            "--strip-prefix-1",
            type=str,
            help="""a prefix to remove from all of the 1st checkpoint's keys.""",
        )
        subparser.add_argument(
            "--strip-prefix-2",
            type=str,
            help="""a prefix to remove from all of the 2nd checkpoint's keys.""",
        )
        subparser.add_argument(
            "--scale",
            type=float,
            default=1.0,
            help="""controls the scale of the distance calculation.""",
        )
        subparser.add_argument(
            "--threshold",
            type=float,
            default=1e-5,
            help="""the threshold for the distance between two tensors,
            under which the two tensors are considered identical.""",
        )
        return subparser
91+
92+
93+
class Keep(NamedTuple):
    """Diff operation: the parameter is identical in both checkpoints."""

    # Name of the parameter in the state dict.
    key: str
    # Shape of the parameter's tensor.
    shape: Tuple[int, ...]

    def display(self) -> None:
        # Unchanged entries are printed uncolored, prefixed with a space.
        termcolor.cprint(f" {self.key}, shape = {self.shape}")
99+
100+
101+
class Insert(NamedTuple):
    """Diff operation: the parameter exists only in the 2nd checkpoint."""

    # Name of the parameter in the state dict.
    key: str
    # Shape of the parameter's tensor.
    shape: Tuple[int, ...]

    def display(self) -> None:
        # Added entries are printed in green, prefixed with '+'.
        termcolor.cprint(f"+{self.key}, shape = {self.shape}", "green")
107+
108+
109+
class Remove(NamedTuple):
    """Diff operation: the parameter exists only in the 1st checkpoint."""

    # Name of the parameter in the state dict.
    key: str
    # Shape of the parameter's tensor.
    shape: Tuple[int, ...]

    def display(self) -> None:
        # Removed entries are printed in red, prefixed with '-'.
        termcolor.cprint(f"-{self.key}, shape = {self.shape}", "red")
115+
116+
117+
class Modify(NamedTuple):
    """Diff operation: same parameter and shape in both checkpoints, but different weights."""

    # Name of the parameter in the state dict.
    key: str
    # Shape of the parameter's tensor (identical in both checkpoints).
    shape: Tuple[int, ...]
    # Scaled RMS distance between the two tensors.
    distance: float

    def display(self) -> None:
        # Modified entries are printed in yellow, prefixed with '!'.
        termcolor.cprint(
            f"!{self.key}, shape = {self.shape}, distance = {self.distance:.4f}", "yellow"
        )
126+
127+
128+
class _Frontier(NamedTuple):
    """The farthest-right point reached on one diagonal of the Myers edit graph."""

    # The x coordinate (index into the 1st checkpoint's parameter list).
    x: int
    # The sequence of edit operations that reached this point.
    history: List[Union[Keep, Insert, Remove]]
131+
132+
133+
def _finalize(
    history: List[Union[Keep, Insert, Remove]],
    state_dict_a: Dict[str, torch.Tensor],
    state_dict_b: Dict[str, torch.Tensor],
    scale: float,
    threshold: float,
) -> List[Union[Keep, Insert, Remove, Modify]]:
    """
    Post-process a raw edit history: any `Keep` whose two tensors actually differ
    (scaled RMS distance greater than `threshold`) is re-labeled as a `Modify`.
    """
    result = cast(List[Union[Keep, Insert, Remove, Modify]], history)
    for idx, op in enumerate(result):
        if not isinstance(op, Keep):
            continue
        tensor_a = state_dict_a[op.key]
        tensor_b = state_dict_b[op.key]
        with torch.no_grad():
            # Distance = scale * sqrt(mean squared difference).
            distance = (scale * torch.nn.functional.mse_loss(tensor_a, tensor_b).sqrt()).item()
        if distance > threshold:
            result[idx] = Modify(op.key, op.shape, distance)
    return result
150+
151+
152+
def checkpoint_diff(
    state_dict_a: Dict[str, torch.Tensor],
    state_dict_b: Dict[str, torch.Tensor],
    scale: float,
    threshold: float,
) -> List[Union[Keep, Insert, Remove, Modify]]:
    """
    Uses a modified version of the Myers diff algorithm to compute a representation
    of the diff between two model state dictionaries.

    The only difference is that in addition to the `Keep`, `Insert`, and `Remove`
    operations, we add `Modify`. This corresponds to keeping a parameter
    but changing its weights (not the shape).

    Adapted from [this gist]
    (https://gist.github.com/adamnew123456/37923cf53f51d6b9af32a539cdfa7cc4).
    """
    # Diff on (key, shape) pairs; actual weights are compared later in `_finalize`.
    params_a = [(key, tuple(tensor.shape)) for key, tensor in state_dict_a.items()]
    params_b = [(key, tuple(tensor.shape)) for key, tensor in state_dict_b.items()]

    # For each diagonal in the edit graph, the farthest-right point reached so
    # far, along with the edit history that got it there.
    frontier: Dict[int, _Frontier] = {1: _Frontier(0, [])}

    def one(idx: int) -> int:
        # Myers presents the algorithm 1-indexed; Python lists are 0-indexed.
        return idx - 1

    a_max = len(params_a)
    b_max = len(params_b)
    for d in range(0, a_max + b_max + 1):
        for k in range(-d, d + 1, 2):
            # Decide whether the next search point moves down or right in the
            # edit graph. Go down on the left edge (k == -d) so that edge is
            # fully explored; otherwise (when not on the top edge, k != d)
            # go down only if that reaches less-explored territory.
            go_down = k == -d or (k != d and frontier[k - 1].x < frontier[k + 1].x)

            # Starting point for this iteration. The +/-1 diagonal offsets come
            # from the geometry of the edit grid: going down lands on a lower
            # diagonal, going right on a higher one.
            if go_down:
                old_x, history = frontier[k + 1]
                x = old_x
            else:
                old_x, history = frontier[k - 1]
                x = old_x + 1

            # Copy the history so sibling search points can still reuse the old one.
            history = history[:]
            y = x - k

            # (0, 0) is a virtual starting point — only record operations once
            # we have moved off of it.
            if 1 <= y <= b_max and go_down:
                history.append(Insert(*params_b[one(y)]))
            elif 1 <= x <= a_max:
                history.append(Remove(*params_a[one(x)]))

            # Consume as many "free" diagonal moves as possible: these are
            # matching (key, shape) pairs, and the algorithm maximizes them.
            while x < a_max and y < b_max and params_a[one(x + 1)] == params_b[one(y + 1)]:
                x += 1
                y += 1
                history.append(Keep(*params_a[one(x)]))

            if x >= a_max and y >= b_max:
                # Traversed through the bottom-right corner of the edit graph:
                # the edit script is complete.
                return _finalize(history, state_dict_a, state_dict_b, scale, threshold)
            frontier[k] = _Frontier(x, history)

    assert False, "Could not find edit script"
238+
239+
240+
def _get_checkpoint_path(checkpoint: str) -> str:
    """
    Resolve a checkpoint identifier (URL, local path, or archive reference) to a
    local file path via `cached_path`, extracting archives when necessary.
    """
    if checkpoint.endswith(".tar.gz"):
        # A whole model archive: pull the standard weights file out of it.
        return cached_path(checkpoint + "!weights.th", extract_archive=True)
    if ".tar.gz!" in checkpoint:
        # An explicit member of an archive (already uses the '!' syntax).
        return cached_path(checkpoint, extract_archive=True)
    # A plain file or URL.
    return cached_path(checkpoint)
247+
248+
249+
def _diff(args: argparse.Namespace):
    """
    Entry point for the `allennlp diff` command: resolve and load both
    checkpoints, then print one line per diff operation.
    """
    path_1 = _get_checkpoint_path(args.checkpoint1)
    path_2 = _get_checkpoint_path(args.checkpoint2)
    # Load leniently (strict=False) since a '--strip-prefix-*' option may
    # legitimately match nothing.
    state_dict_1 = load_state_dict(path_1, strip_prefix=args.strip_prefix_1, strict=False)
    state_dict_2 = load_state_dict(path_2, strip_prefix=args.strip_prefix_2, strict=False)
    for step in checkpoint_diff(state_dict_1, state_dict_2, args.scale, args.threshold):
        step.display()

allennlp/models/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def _load(
335335

336336
# Load state dict. We pass `strict=False` so PyTorch doesn't raise a RuntimeError
337337
# if the state dict is missing keys because we handle this case below.
338-
model_state = torch.load(weights_file, map_location=util.device_mapping(cuda_device))
338+
model_state = util.load_state_dict(weights_file, cuda_device=cuda_device)
339339
missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)
340340

341341
# Modules might define a class variable called `authorized_missing_keys`,

allennlp/nn/util.py

+92-1
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
"""
44

55
import copy
6+
from collections import defaultdict, OrderedDict
67
import json
78
import logging
8-
from collections import defaultdict
9+
from os import PathLike
10+
import re
911
from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
1012

1113
import math
@@ -924,6 +926,95 @@ def inner_device_mapping(storage: torch.Storage, location) -> torch.Storage:
924926
return inner_device_mapping
925927

926928

929+
def load_state_dict(
    path: Union[PathLike, str],
    strip_prefix: Optional[str] = None,
    ignore: Optional[List[str]] = None,
    strict: bool = True,
    cuda_device: int = -1,
) -> Dict[str, torch.Tensor]:
    """
    Load a PyTorch model state dictionary from a checkpoint at the given `path`.

    # Parameters

    path : `Union[PathLike, str]`, required

    strip_prefix : `Optional[str]`, optional (default = `None`)
        A prefix to remove from all of the state dict keys.

    ignore : `Optional[List[str]]`, optional (default = `None`)
        Optional list of regular expressions. Keys that match any of these will be removed
        from the state dict.

        !!! Note
            If `strip_prefix` is given, the regular expressions in `ignore` are matched
            before the prefix is stripped.

    strict : `bool`, optional (default = `True`)
        If `True` (the default) and `strip_prefix` was never used or any of the regular expressions
        in `ignore` never matched, a `ValueError` will be raised.

    cuda_device : `int`, optional (default = `-1`)
        The device to load the parameters onto. Use `-1` (the default) for CPU.

    # Returns

    `Dict[str, torch.Tensor]`
        An ordered dictionary of the state.
    """
    state = torch.load(path, map_location=device_mapping(cuda_device))
    out: Dict[str, torch.Tensor] = OrderedDict()

    if ignore is not None and not isinstance(ignore, list):
        # Passing, say, a string here would silently produce a confusingly
        # empty or unfiltered state dict, so fail loudly instead.
        raise ValueError("'ignore' parameter should be a list")

    # In 'strict' mode we track whether `strip_prefix` ever matched and which
    # of the `ignore` patterns were ever used, so we can raise otherwise.
    strip_prefix_used: Optional[bool] = None
    ignore_used: Optional[List[bool]] = None
    if strict and strip_prefix is not None:
        strip_prefix_used = False
    if strict and ignore:
        ignore_used = [False] * len(ignore)

    for key in state.keys():
        # First see whether this key should be dropped entirely.
        matched_pattern: Optional[int] = None
        if ignore:
            for i, pattern in enumerate(ignore):
                if re.match(pattern, key):
                    matched_pattern = i
                    break
        if matched_pattern is not None:
            if ignore_used:
                ignore_used[matched_pattern] = True
            logger.warning("ignoring %s from state dict", key)
            continue

        # Otherwise keep it, possibly under a prefix-stripped name.
        new_key = key
        if strip_prefix and key.startswith(strip_prefix):
            strip_prefix_used = True
            new_key = key[len(strip_prefix) :]
            if not new_key:
                raise ValueError("'strip_prefix' resulted in an empty string for a key")

        out[new_key] = state[key]

    # Strict-mode sanity checks: every given option must have had an effect.
    if strip_prefix_used is False:
        raise ValueError(f"'strip_prefix' of '{strip_prefix}' was never used")
    if ignore is not None and ignore_used is not None:
        for pattern, used in zip(ignore, ignore_used):
            if not used:
                raise ValueError(f"'ignore' pattern '{pattern}' didn't have any matches")

    return out
1016+
1017+
9271018
def combine_tensors(combination: str, tensors: List[torch.Tensor]) -> torch.Tensor:
9281019
"""
9291020
Combines a list of tensors using element-wise operations and concatenation, specified by a

0 commit comments

Comments
 (0)