|
| 1 | +""" |
| 2 | +# Examples |
| 3 | +
|
| 4 | +```bash |
| 5 | +allennlp diff \ |
| 6 | + hf://roberta-large/pytorch_model.bin \ |
| 7 | + https://storage.googleapis.com/allennlp-public-models/transformer-qa-2020-10-03.tar.gz \ |
| 8 | + --strip-prefix-1 'roberta.' \ |
| 9 | + --strip-prefix-2 '_text_field_embedder.token_embedder_tokens.transformer_model.' |
| 10 | +``` |
| 11 | +""" |
| 12 | +import argparse |
| 13 | +import logging |
| 14 | +from typing import Union, Dict, List, Tuple, NamedTuple, cast |
| 15 | + |
| 16 | +from overrides import overrides |
| 17 | +import termcolor |
| 18 | +import torch |
| 19 | + |
| 20 | +from allennlp.commands.subcommand import Subcommand |
| 21 | +from allennlp.common.file_utils import cached_path |
| 22 | +from allennlp.nn.util import load_state_dict |
| 23 | + |
| 24 | + |
| 25 | +logger = logging.getLogger(__name__) |
| 26 | + |
| 27 | + |
@Subcommand.register("diff")
class Diff(Subcommand):
    """
    Registers the `allennlp diff` CLI subcommand, which displays a diff between
    the parameters of two model checkpoints.
    """

    # This command only needs core AllenNLP, so plugin discovery can be skipped.
    requires_plugins: bool = False

    @overrides
    def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
        """
        Add the `diff` parser to `parser` and return it.
        """
        description = """Display a diff between two model checkpoints."""
        long_description = (
            description
            + """
        In the output, lines start with either a "+", "-", "!", or empty space " ".
        "+" means the corresponding parameter is present in the 2nd checkpoint but not the 1st.
        "-" means the corresponding parameter is present in the 1st checkpoint but not the 2nd.
        "!" means the corresponding parameter is present in both, but has different weights (same shape)
        according to the distance calculation and the '--threshold' value.
        And " " means the corresponding parameter is considered identical in both, i.e.
        the distance falls below the threshold.
        The distance between two tensors is calculated as the root of the
        mean squared difference, multiplied by the '--scale' parameter.
        """
        )
        subparser = parser.add_parser(
            self.name,
            description=long_description,
            help=description,
        )
        subparser.set_defaults(func=_diff)
        # NOTE: the original help text here referenced '--checkpoint-type-1' /
        # '--checkpoint-type-2' options that are not defined on this parser.
        subparser.add_argument(
            "checkpoint1",
            type=str,
            help="""the URL, path, or other identifier of the 1st PyTorch checkpoint.""",
        )
        subparser.add_argument(
            "checkpoint2",
            type=str,
            help="""the URL, path, or other identifier of the 2nd PyTorch checkpoint.""",
        )
        subparser.add_argument(
            "--strip-prefix-1",
            type=str,
            help="""a prefix to remove from all of the 1st checkpoint's keys.""",
        )
        subparser.add_argument(
            "--strip-prefix-2",
            type=str,
            help="""a prefix to remove from all of the 2nd checkpoint's keys.""",
        )
        subparser.add_argument(
            "--scale",
            type=float,
            default=1.0,
            help="""controls the scale of the distance calculation.""",
        )
        subparser.add_argument(
            "--threshold",
            type=float,
            default=1e-5,
            help="""the threshold for the distance between two tensors,
            under which the two tensors are considered identical.""",
        )
        return subparser
| 91 | + |
| 92 | + |
class Keep(NamedTuple):
    """A parameter that is considered identical in both checkpoints."""

    # State-dict key of the parameter.
    key: str
    # Shape of the parameter tensor.
    shape: Tuple[int, ...]

    def display(self):
        """Print this entry with a leading space (no color)."""
        line = f" {self.key}, shape = {self.shape}"
        termcolor.cprint(line)
| 99 | + |
| 100 | + |
class Insert(NamedTuple):
    """A parameter present only in the 2nd checkpoint."""

    # State-dict key of the parameter.
    key: str
    # Shape of the parameter tensor.
    shape: Tuple[int, ...]

    def display(self):
        """Print this entry prefixed with '+' in green."""
        line = f"+{self.key}, shape = {self.shape}"
        termcolor.cprint(line, "green")
| 107 | + |
| 108 | + |
class Remove(NamedTuple):
    """A parameter present only in the 1st checkpoint."""

    # State-dict key of the parameter.
    key: str
    # Shape of the parameter tensor.
    shape: Tuple[int, ...]

    def display(self):
        """Print this entry prefixed with '-' in red."""
        line = f"-{self.key}, shape = {self.shape}"
        termcolor.cprint(line, "red")
| 115 | + |
| 116 | + |
class Modify(NamedTuple):
    """A parameter present in both checkpoints (same shape) but with different weights."""

    # State-dict key of the parameter.
    key: str
    # Shape of the parameter tensor.
    shape: Tuple[int, ...]
    # Scaled RMS distance between the two weight tensors.
    distance: float

    def display(self):
        """Print this entry prefixed with '!' in yellow, including the distance."""
        line = f"!{self.key}, shape = {self.shape}, distance = {self.distance:.4f}"
        termcolor.cprint(line, "yellow")
| 126 | + |
| 127 | + |
class _Frontier(NamedTuple):
    """
    The farthest-right point reached along one diagonal of the Myers edit
    graph, together with the edit history that led there.
    """

    x: int
    history: List[Union[Keep, Insert, Remove]]
| 131 | + |
| 132 | + |
def _finalize(
    history: List[Union[Keep, Insert, Remove]],
    state_dict_a: Dict[str, torch.Tensor],
    state_dict_b: Dict[str, torch.Tensor],
    scale: float,
    threshold: float,
) -> List[Union[Keep, Insert, Remove, Modify]]:
    """
    Post-process a raw Myers edit script: any `Keep` whose weights actually
    differ (scaled RMS distance above `threshold`) is downgraded to a `Modify`.

    Mutates `history` in place and returns it.
    """
    steps = cast(List[Union[Keep, Insert, Remove, Modify]], history)
    for index, step in enumerate(steps):
        if not isinstance(step, Keep):
            continue
        with torch.no_grad():
            # Root of the mean squared difference between the two tensors,
            # multiplied by the user-supplied scale.
            distance = (
                scale
                * torch.nn.functional.mse_loss(
                    state_dict_a[step.key], state_dict_b[step.key]
                ).sqrt()
            ).item()
        if distance > threshold:
            steps[index] = Modify(step.key, step.shape, distance)
    return steps
| 150 | + |
| 151 | + |
def checkpoint_diff(
    state_dict_a: Dict[str, torch.Tensor],
    state_dict_b: Dict[str, torch.Tensor],
    scale: float,
    threshold: float,
) -> List[Union[Keep, Insert, Remove, Modify]]:
    """
    Uses a modified version of the Myers diff algorithm to compute a representation
    of the diff between two model state dictionaries.

    The only difference is that in addition to the `Keep`, `Insert`, and `Remove`
    operations, we add `Modify`. This corresponds to keeping a parameter
    but changing its weights (not the shape).

    Adapted from [this gist]
    (https://gist.github.com/adamnew123456/37923cf53f51d6b9af32a539cdfa7cc4).

    # Parameters

    state_dict_a : `Dict[str, torch.Tensor]`
        Parameters of the 1st checkpoint.
    state_dict_b : `Dict[str, torch.Tensor]`
        Parameters of the 2nd checkpoint.
    scale : `float`
        Multiplier applied to the RMS weight distance in `_finalize`.
    threshold : `float`
        Scaled distances at or below this value are considered identical.
    """
    # The diff itself is computed over (key, shape) pairs only; weights are
    # compared afterwards by `_finalize`, which may turn a Keep into a Modify.
    param_list_a = [(k, tuple(v.shape)) for k, v in state_dict_a.items()]
    param_list_b = [(k, tuple(v.shape)) for k, v in state_dict_b.items()]

    # This marks the farthest-right point along each diagonal in the edit
    # graph, along with the history that got it there
    frontier: Dict[int, _Frontier] = {1: _Frontier(0, [])}

    def one(idx: int) -> int:
        """
        The algorithm Myers presents is 1-indexed; since Python isn't, we
        need a conversion.
        """
        return idx - 1

    a_max = len(param_list_a)
    b_max = len(param_list_b)
    # d is the number of edit operations; k indexes the diagonals (k = x - y).
    for d in range(0, a_max + b_max + 1):
        for k in range(-d, d + 1, 2):
            # This determines whether our next search point will be going down
            # in the edit graph, or to the right.
            #
            # The intuition for this is that we should go down if we're on the
            # left edge (k == -d) to make sure that the left edge is fully
            # explored.
            #
            # If we aren't on the top (k != d), then only go down if going down
            # would take us to territory that hasn't sufficiently been explored
            # yet.
            go_down = k == -d or (k != d and frontier[k - 1].x < frontier[k + 1].x)

            # Figure out the starting point of this iteration. The diagonal
            # offsets come from the geometry of the edit grid - if you're going
            # down, your diagonal is lower, and if you're going right, your
            # diagonal is higher.
            if go_down:
                old_x, history = frontier[k + 1]
                x = old_x
            else:
                old_x, history = frontier[k - 1]
                x = old_x + 1

            # We want to avoid modifying the old history, since some other step
            # may decide to use it.
            history = history[:]
            y = x - k

            # We start at the invalid point (0, 0) - we should only start building
            # up history when we move off of it.
            if 1 <= y <= b_max and go_down:
                history.append(Insert(*param_list_b[one(y)]))
            elif 1 <= x <= a_max:
                history.append(Remove(*param_list_a[one(x)]))

            # Chew up as many diagonal moves as we can - these correspond to common lines,
            # and they're considered "free" by the algorithm because we want to maximize
            # the number of these in the output.
            while x < a_max and y < b_max and param_list_a[one(x + 1)] == param_list_b[one(y + 1)]:
                x += 1
                y += 1
                history.append(Keep(*param_list_a[one(x)]))

            if x >= a_max and y >= b_max:
                # If we're here, then we've traversed through the bottom-right corner
                # (x == a_max, y == b_max) of the edit graph, and are done.
                return _finalize(history, state_dict_a, state_dict_b, scale, threshold)
            else:
                frontier[k] = _Frontier(x, history)

    assert False, "Could not find edit script"
| 238 | + |
| 239 | + |
def _get_checkpoint_path(checkpoint: str) -> str:
    """
    Resolve a checkpoint specifier to a local file path via `cached_path`.

    A bare model archive (`*.tar.gz`) is resolved to the `weights.th` file
    inside it; a specifier already containing an archive-member reference
    (`...tar.gz!member`) is extracted as-is; anything else is cached directly.
    """
    if checkpoint.endswith(".tar.gz"):
        # Point at the weights file inside the archive.
        checkpoint += "!weights.th"
    if ".tar.gz!" in checkpoint:
        return cached_path(checkpoint, extract_archive=True)
    return cached_path(checkpoint)
| 247 | + |
| 248 | + |
def _diff(args: argparse.Namespace):
    """
    Entry point for the `diff` subcommand: load both checkpoints, compute the
    diff, and print each step to the terminal.
    """
    specs = (
        (args.checkpoint1, args.strip_prefix_1),
        (args.checkpoint2, args.strip_prefix_2),
    )
    state_dicts = []
    for checkpoint, prefix in specs:
        path = _get_checkpoint_path(checkpoint)
        state_dicts.append(load_state_dict(path, strip_prefix=prefix, strict=False))
    for step in checkpoint_diff(state_dicts[0], state_dicts[1], args.scale, args.threshold):
        step.display()