add diff command #5109
@@ -0,0 +1,214 @@
""" | ||
# Examples | ||
|
||
```bash | ||
allennlp diff \ | ||
https://huggingface.co/roberta-large/resolve/main/pytorch_model.bin \ | ||
https://storage.googleapis.com/allennlp-public-models/transformer-qa-2020-10-03.tar.gz!weights.th \ | ||
--strip-prefix-1 'roberta.' \ | ||
--strip-prefix-2 '_text_field_embedder.token_embedder_tokens.transformer_model.' | ||
``` | ||
""" | ||
import argparse
import logging
from typing import Union, Dict, List, Tuple, NamedTuple, cast

from overrides import overrides
import termcolor
import torch

from allennlp.commands.subcommand import Subcommand
from allennlp.common.file_utils import cached_path
from allennlp.nn.util import load_state_dict


logger = logging.getLogger(__name__)


@Subcommand.register("diff")
class Diff(Subcommand):
    requires_plugins: bool = False

    @overrides
    def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
        description = """Display a diff between two model checkpoints."""
        subparser = parser.add_parser(
            self.name,
            description=description,
            help=description,
        )
        subparser.set_defaults(func=_diff)
        subparser.add_argument(
            "checkpoint1",
            type=str,
            help="""The URL or path to the first PyTorch checkpoint file (e.g. '.pt' or '.bin').""",
        )
        subparser.add_argument(
            "checkpoint2",
            type=str,
            help="""The URL or path to the second PyTorch checkpoint file.""",
        )
        subparser.add_argument(
            "--strip-prefix-1",
            type=str,
            help="""A prefix to remove from all of the first checkpoint's keys.""",
        )
        subparser.add_argument(
            "--strip-prefix-2",
            type=str,
            help="""A prefix to remove from all of the second checkpoint's keys.""",
        )
        return subparser


class Keep(NamedTuple):
    key: str
    shape: Tuple[int, ...]

    def display(self):
        termcolor.cprint(f" {self.key}, shape = {self.shape}")


class Insert(NamedTuple):
    key: str
    shape: Tuple[int, ...]

    def display(self):
        termcolor.cprint(f"+{self.key}, shape = {self.shape}", "green")


class Remove(NamedTuple):
    key: str
    shape: Tuple[int, ...]

    def display(self):
        termcolor.cprint(f"-{self.key}, shape = {self.shape}", "red")


class Modify(NamedTuple):
    key: str
    shape: Tuple[int, ...]
    distance: float

    def display(self):
        termcolor.cprint(f"!{self.key}, shape = {self.shape}, △ = {self.distance:.4f}", "yellow")
> **Review thread (on the `△` distance metric):**
>
> - I personally don't think the L2 distance is very insightful, especially when the tensors are so high-dimensional.
> - That said, I can't think of a better single metric to succinctly describe how different two tensors are. And even given such a metric, I don't think sorting by it would be helpful: I would much rather see the differing parameters in the order in which they are used during forward propagation.
> - I think a heatmap of the element-wise differences of 2D parameters could be very useful. I don't know how to extend this well to more than two dimensions, though; perhaps aggregate over the channel dimension?
> - Maybe we flatten the tensors, bin the weights into a manageable number of bins, and then do something with those bins, like show a bar chart of the distance between the bins of the two tensors?
> - You don't think the L2 distance is going to show up meaningfully when you want to validate, for example, that your gradual unfreezing schedule works? Maybe we can do better than L2, but I think it's a great start.
> - Looking over the example I gave, there is certainly a strong correlation between the number of elements in a modified parameter and the corresponding Euclidean distance, suggesting this is not the best metric to use. That said, I'm pretty sure any other Lp-based metric (even L∞) would suffer from the same correlation unless we add an additional size-based normalization term.
> - Anyway, the exact metric to use should be a configurable option, IMO.
> - Maybe just normalize by [formula did not render], or take the square root of the mean squared "error"? Is this meaningful?
> - Are y'all running plugins that render LaTeX properly in GitHub?
> - Human renderer plugin.
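The size correlation raised above is easy to see numerically. Here is a minimal sketch (not part of this PR) comparing raw Euclidean distance against RMSE, the square root of the mean squared difference, which is what `_finalize` below computes via `mse_loss(...).sqrt()`; the raw L2 distance grows with the parameter size while RMSE stays flat:

```python
# Sketch: raw L2 distance scales with tensor size, RMSE does not.
import torch

torch.manual_seed(0)
for n in (10, 1_000, 100_000):
    a = torch.randn(n)
    b = a + 0.01 * torch.randn(n)  # same per-element perturbation at every size
    l2 = torch.dist(a, b)  # grows roughly like 0.01 * sqrt(n)
    rmse = torch.nn.functional.mse_loss(a, b).sqrt()  # stays near 0.01
    print(f"n={n:>7}  L2={l2:.4f}  RMSE={rmse:.4f}")
```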


class _Frontier(NamedTuple):
    x: int
    history: List[Union[Keep, Insert, Remove]]


def _finalize(
    history: List[Union[Keep, Insert, Remove]],
    state_dict_a: Dict[str, torch.Tensor],
    state_dict_b: Dict[str, torch.Tensor],
) -> List[Union[Keep, Insert, Remove, Modify]]:
    out = cast(List[Union[Keep, Insert, Remove, Modify]], history)
    for i, step in enumerate(out):
        if isinstance(step, Keep):
            a_tensor = state_dict_a[step.key]
            b_tensor = state_dict_b[step.key]
            with torch.no_grad():
                dist = torch.nn.functional.mse_loss(a_tensor, b_tensor).sqrt()
            if dist != 0.0:
> **Review thread (on the exact `dist != 0.0` comparison):**
>
> - Should we worry about loss of precision here? I.e., maybe we want to check that the distance is within some small threshold of 0, not exactly 0.
> - The threshold could be a configurable parameter.
> - This is a great point that I just thought of as well! I think we could ensure better precision if we actually did `(a_tensor != b_tensor).any()`. But I also like the idea of a configurable parameter. That way, if people train two models with a small difference in implementation, they can also use this tool to identify how similar the weights are under different thresholds epsilon.
> - That's a great point.
                out[i] = Modify(step.key, step.shape, dist)
    return out
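A hedged sketch of the configurable-threshold idea from the thread above; the `threshold` parameter is hypothetical and not part of this PR:

```python
import torch


def tensors_differ(a: torch.Tensor, b: torch.Tensor, threshold: float = 0.0) -> bool:
    """Return True if the RMSE between `a` and `b` exceeds `threshold`.

    With the default of 0.0 this matches the exact `dist != 0.0` check above;
    a small positive value (e.g. 1e-8) absorbs floating-point noise. An exact
    element-wise comparison is also possible via (a != b).any().
    """
    with torch.no_grad():
        return torch.nn.functional.mse_loss(a, b).sqrt().item() > threshold
```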


def checkpoint_diff(
    state_dict_a: Dict[str, torch.Tensor], state_dict_b: Dict[str, torch.Tensor]
) -> List[Union[Keep, Insert, Remove, Modify]]:
    """
    Uses a modified version of the Myers diff algorithm to compute a representation
    of the diff between two model state dictionaries.
> **Review thread (on the two docstring lines above, +159 to +160):**
>
> - I don't know that much about Myers, but isn't that only necessary if the order matters? Does the order of entries in the state dict matter?
> - The order is meaningful. It is the order in which the corresponding modules were registered. Generally this is the order of data flow.
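A quick sketch (not from the PR) backing up that reply: `state_dict` keys come out in module-registration order, which is why an order-aware diff is a reasonable fit here:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
print(list(model.state_dict().keys()))
# ['0.weight', '0.bias', '2.weight', '2.bias'] -- registration (data-flow) order
```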

    The only difference is that in addition to the `Keep`, `Insert`, and `Remove`
    operations, we add `Modify`. This corresponds to keeping a parameter
    but changing its weights (not the shape).

    Adapted from [this gist]
    (https://gist.github.com/adamnew123456/37923cf53f51d6b9af32a539cdfa7cc4).
    """
    param_list_a = [(k, tuple(v.shape)) for k, v in state_dict_a.items()]
    param_list_b = [(k, tuple(v.shape)) for k, v in state_dict_b.items()]

    # This marks the farthest-right point along each diagonal in the edit
    # graph, along with the history that got it there.
    frontier: Dict[int, _Frontier] = {1: _Frontier(0, [])}

    def one(idx):
        """
        The algorithm Myers presents is 1-indexed; since Python isn't, we
        need a conversion.
        """
        return idx - 1

    a_max = len(param_list_a)
    b_max = len(param_list_b)
    for d in range(0, a_max + b_max + 1):
        for k in range(-d, d + 1, 2):
            # This determines whether our next search point will be going down
            # in the edit graph, or to the right.
            #
            # The intuition for this is that we should go down if we're on the
            # left edge (k == -d) to make sure that the left edge is fully
            # explored.
            #
            # If we aren't on the top (k != d), then only go down if going down
            # would take us to territory that hasn't sufficiently been explored
            # yet.
            go_down = k == -d or (k != d and frontier[k - 1].x < frontier[k + 1].x)

            # Figure out the starting point of this iteration. The diagonal
            # offsets come from the geometry of the edit grid: if you're going
            # down, your diagonal is lower, and if you're going right, your
            # diagonal is higher.
            if go_down:
                old_x, history = frontier[k + 1]
                x = old_x
            else:
                old_x, history = frontier[k - 1]
                x = old_x + 1

            # We want to avoid modifying the old history, since some other step
            # may decide to use it.
            history = history[:]
            y = x - k

            # We start at the invalid point (0, 0); we should only start building
            # up history when we move off of it.
            if 1 <= y <= b_max and go_down:
                history.append(Insert(*param_list_b[one(y)]))
            elif 1 <= x <= a_max:
                history.append(Remove(*param_list_a[one(x)]))

            # Chew up as many diagonal moves as we can. These correspond to common
            # lines, and they're considered "free" by the algorithm because we want
            # to maximize the number of them in the output.
            while x < a_max and y < b_max and param_list_a[one(x + 1)] == param_list_b[one(y + 1)]:
                x += 1
                y += 1
                history.append(Keep(*param_list_a[one(x)]))

            if x >= a_max and y >= b_max:
                # If we're here, then we've traversed through the bottom-left corner,
                # and are done.
                return _finalize(history, state_dict_a, state_dict_b)
            else:
                frontier[k] = _Frontier(x, history)

    assert False, "Could not find edit script"


def _diff(args: argparse.Namespace):
    checkpoint_1_path = cached_path(args.checkpoint1, extract_archive=True)
    checkpoint_2_path = cached_path(args.checkpoint2, extract_archive=True)
    checkpoint_1 = load_state_dict(
        checkpoint_1_path, strip_prefix=args.strip_prefix_1, strict=False
    )
    checkpoint_2 = load_state_dict(
        checkpoint_2_path, strip_prefix=args.strip_prefix_2, strict=False
    )
    for step in checkpoint_diff(checkpoint_1, checkpoint_2):
        step.display()
> **Review comment:** It would be nice if these could also point to model archives.
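For reference, a hypothetical in-memory usage example (not part of the PR) that exercises `checkpoint_diff` directly, without going through the CLI; the state dicts and key names are made up:

```python
import torch

sd_a = {
    "encoder.weight": torch.zeros(2, 2),
    "head.weight": torch.zeros(3),
}
sd_b = {
    "encoder.weight": torch.ones(2, 2),  # same key and shape, new values -> Modify
    "classifier.weight": torch.zeros(3),  # renamed key -> Remove + Insert
}

for step in checkpoint_diff(sd_a, sd_b):
    step.display()
# Expected output:
# !encoder.weight, shape = (2, 2), △ = 1.0000
# -head.weight, shape = (3,)
# +classifier.weight, shape = (3,)
```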