# coding=utf-8
# Copyright 2022 the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import subprocess
import time
from typing import Any, Dict, List

import numpy as np
import torch
from packaging import version
from transformers.utils import is_torch_available

from optimum.utils import logging

from .version import __version__


logger = logging.get_logger(__name__)


CURRENTLY_VALIDATED_SYNAPSE_VERSION = version.parse("1.18.0")


def to_device_dtype(my_input: Any, target_device: torch.device = None, target_dtype: torch.dtype = None):
"""
Move a state_dict to the target device and convert it into target_dtype.
Args:
my_input : input to transform
target_device (torch.device, optional): target_device to move the input on. Defaults to None.
target_dtype (torch.dtype, optional): target dtype to convert the input into. Defaults to None.
Returns:
: transformed input
"""
if isinstance(my_input, torch.Tensor):
if target_device is None:
target_device = my_input.device
if target_dtype is None:
target_dtype = my_input.dtype
return my_input.to(device=target_device, dtype=target_dtype)
elif isinstance(my_input, list):
return [to_device_dtype(i, target_device, target_dtype) for i in my_input]
elif isinstance(my_input, tuple):
return tuple(to_device_dtype(i, target_device, target_dtype) for i in my_input)
elif isinstance(my_input, dict):
return {k: to_device_dtype(v, target_device, target_dtype) for k, v in my_input.items()}
else:
return my_input
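
# Illustrative usage sketch (not part of the original module): move a model's state_dict to CPU
# before saving. `my_model` is a hypothetical torch.nn.Module assumed to exist elsewhere.
#
#     cpu_state_dict = to_device_dtype(my_model.state_dict(), target_device=torch.device("cpu"))
#     torch.save(cpu_state_dict, "checkpoint.pt")
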
def speed_metrics(
split: str,
start_time: float,
num_samples: int = None,
num_steps: int = None,
num_tokens: int = None,
start_time_after_warmup: float = None,
log_evaluate_save_time: float = None,
) -> Dict[str, float]:
"""
Measure and return speed performance metrics.
This function requires a time snapshot `start_time` before the operation to be measured starts and this function
should be run immediately after the operation to be measured has completed.
Args:
split (str): name to prefix metric (like train, eval, test...)
start_time (float): operation start time
num_samples (int, optional): number of samples processed. Defaults to None.
num_steps (int, optional): number of steps performed. Defaults to None.
num_tokens (int, optional): number of tokens processed. Defaults to None.
start_time_after_warmup (float, optional): time after warmup steps have been performed. Defaults to None.
log_evaluate_save_time (float, optional): time spent to log, evaluate and save. Defaults to None.
Returns:
Dict[str, float]: dictionary with performance metrics.
"""
runtime = time.time() - start_time
result = {f"{split}_runtime": round(runtime, 4)}
if runtime == 0:
return result
# Adjust runtime if log_evaluate_save_time should not be included
if log_evaluate_save_time is not None:
runtime = runtime - log_evaluate_save_time
# Adjust runtime if there were warmup steps
if start_time_after_warmup is not None:
runtime = runtime + start_time - start_time_after_warmup
# Compute throughputs
if num_samples is not None:
samples_per_second = num_samples / runtime
result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
if num_steps is not None:
steps_per_second = num_steps / runtime
result[f"{split}_steps_per_second"] = round(steps_per_second, 3)
if num_tokens is not None:
tokens_per_second = num_tokens / runtime
result[f"{split}_tokens_per_second"] = round(tokens_per_second, 3)
return result
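
# Illustrative usage sketch (not part of the original module), assuming a hypothetical
# `run_evaluation()` callable and 512 evaluated samples:
#
#     start = time.time()
#     run_evaluation()
#     metrics = speed_metrics("eval", start, num_samples=512)
#     # -> {"eval_runtime": ..., "eval_samples_per_second": ...}
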
def warmup_inference_steps_time_adjustment(
start_time_after_warmup, start_time_after_inference_steps_warmup, num_inference_steps, warmup_steps
):
"""
Adjust start time after warmup to account for warmup inference steps.
When warmup is applied to multiple inference steps within a single sample generation we need to account for
skipped inference steps time to estimate "per sample generation time". This function computes the average
inference time per step and adjusts the start time after warmup accordingly.
Args:
start_time_after_warmup: time after warmup steps have been performed
start_time_after_inference_steps_warmup: time after warmup inference steps have been performed
num_inference_steps: total number of inference steps per sample generation
warmup_steps: number of warmup steps
Returns:
[float]: adjusted start time after warmup which accounts for warmup inference steps based on average non-warmup steps time
"""
if num_inference_steps > warmup_steps:
avg_time_per_inference_step = (time.time() - start_time_after_inference_steps_warmup) / (
num_inference_steps - warmup_steps
)
start_time_after_warmup -= avg_time_per_inference_step * warmup_steps
return start_time_after_warmup
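
# Worked example (illustrative only): with num_inference_steps=10, warmup_steps=2 and 4 s elapsed
# since the warmup inference steps finished, the average step time is 4 / (10 - 2) = 0.5 s, so the
# start time after warmup is shifted back by 2 * 0.5 = 1 s to account for the skipped warmup steps.
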
def to_gb_rounded(mem: float) -> float:
    """
    Converts a memory amount from bytes to GB and rounds it.

    Args:
        mem (float): memory in bytes

    Returns:
        float: memory in GB rounded to the second decimal
    """
    return np.round(mem / 1024**3, 2)


def get_hpu_memory_stats(device=None) -> Dict[str, float]:
"""
Returns memory stats of HPU as a dictionary:
- current memory allocated (GB)
- maximum memory allocated (GB)
- total memory available (GB)
Returns:
Dict[str, float]: memory stats.
"""
from habana_frameworks.torch.hpu import memory_stats
mem_stats = memory_stats(device)
mem_dict = {
"memory_allocated (GB)": to_gb_rounded(mem_stats["InUse"]),
"max_memory_allocated (GB)": to_gb_rounded(mem_stats["MaxInUse"]),
"total_memory_available (GB)": to_gb_rounded(mem_stats["Limit"]),
}
return mem_dict
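
# Illustrative usage sketch (not part of the original module): log current HPU memory usage.
#
#     mem = get_hpu_memory_stats()
#     logger.info(f"HPU memory: {mem}")
#     # e.g. {"memory_allocated (GB)": 1.23, "max_memory_allocated (GB)": 2.05, "total_memory_available (GB)": 94.62}
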
def set_seed(seed: int):
"""
Helper function for reproducible behavior to set the seed in `random`, `numpy` and `torch`.
Args:
seed (`int`): The seed to set.
"""
random.seed(seed)
np.random.seed(seed)
if is_torch_available():
from habana_frameworks.torch.hpu import random as hpu_random
torch.manual_seed(seed)
        hpu_random.manual_seed_all(seed)


def check_synapse_version():
"""
Checks whether the versions of SynapseAI and drivers have been validated for the current version of Optimum Habana.
"""
# Change the logging format
logging.enable_default_handler()
logging.enable_explicit_format()
# Check the version of habana_frameworks
habana_frameworks_version_number = get_habana_frameworks_version()
if (
habana_frameworks_version_number.major != CURRENTLY_VALIDATED_SYNAPSE_VERSION.major
or habana_frameworks_version_number.minor != CURRENTLY_VALIDATED_SYNAPSE_VERSION.minor
):
logger.warning(
f"optimum-habana v{__version__} has been validated for SynapseAI v{CURRENTLY_VALIDATED_SYNAPSE_VERSION} but habana-frameworks v{habana_frameworks_version_number} was found, this could lead to undefined behavior!"
)
# Check driver version
driver_version = get_driver_version()
# This check is needed to make sure an error is not raised while building the documentation
# Because the doc is built on an instance that does not have `hl-smi`
if driver_version is not None:
if (
driver_version.major != CURRENTLY_VALIDATED_SYNAPSE_VERSION.major
or driver_version.minor != CURRENTLY_VALIDATED_SYNAPSE_VERSION.minor
):
logger.warning(
f"optimum-habana v{__version__} has been validated for SynapseAI v{CURRENTLY_VALIDATED_SYNAPSE_VERSION} but the driver version is v{driver_version}, this could lead to undefined behavior!"
)
else:
logger.warning(
"Could not run `hl-smi`, please follow the installation guide: https://docs.habana.ai/en/latest/Installation_Guide/index.html."
        )


def get_habana_frameworks_version():
"""
Returns the installed version of SynapseAI.
"""
output = subprocess.run(
"pip list | grep habana-torch-plugin",
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
    return version.parse(output.stdout.split("\n")[0].split()[-1])


def get_driver_version():
"""
Returns the driver version.
"""
# Enable console printing for `hl-smi` check
output = subprocess.run(
"hl-smi", shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env={"ENABLE_CONSOLE": "true"}
)
if output.returncode == 0 and output.stdout:
return version.parse(output.stdout.split("\n")[2].replace(" ", "").split(":")[1][:-1].split("-")[0])
    return None


class HabanaGenerationtime(object):
    def __init__(self, iteration_times: List[float] = None):
        # Default to a fresh list so that `step()` works even when no list is provided
        self.iteration_times = iteration_times if iteration_times is not None else []
        self.start_time = 0
        self.end_time = 0

    def start(self):
        self.start_time = time.perf_counter()

    def step(self):
        self.end_time = time.perf_counter()
        self.iteration_times.append(self.end_time - self.start_time)
        self.start_time = self.end_time
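
# Illustrative usage sketch (not part of the original module): time each generation iteration.
#
#     durations = []
#     timer = HabanaGenerationtime(iteration_times=durations)
#     timer.start()
#     for _ in range(3):
#         ...  # one generation iteration
#         timer.step()
#     # `durations` now holds the elapsed time of each iteration.
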
class HabanaProfile(object):
    """
    The HPU profiler can only be run once, so `HABANA_PROFILE_ENABLED`, a class static variable shared by
    all instances of `HabanaProfile`, is used to control which part of the run will be captured.
    """

    HABANA_PROFILE_ENABLED = True

    def __init__(
self,
warmup: int = 0,
active: int = 0,
record_shapes: bool = True,
with_stack: bool = False,
output_dir: str = "./hpu_profile",
wait: int = 0,
):
if active <= 0 or warmup < 0 or not HabanaProfile.HABANA_PROFILE_ENABLED:
def noop():
pass
self.start = noop
self.stop = noop
self.step = noop
else:
HabanaProfile.HABANA_PROFILE_ENABLED = False
schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1)
activities = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.HPU]
profiler = torch.profiler.profile(
schedule=schedule,
activities=activities,
on_trace_ready=torch.profiler.tensorboard_trace_handler(output_dir),
record_shapes=record_shapes,
with_stack=with_stack,
)
self.start = profiler.start
self.stop = profiler.stop
self.step = profiler.step
            # Mark enable()/disable() as invalid: once a real profiler instance exists,
            # toggling HABANA_PROFILE_ENABLED is no longer allowed.
            HabanaProfile.enable.invalid = True
            HabanaProfile.disable.invalid = True

    # The three methods below are shadowed by the instance attributes assigned in __init__
    # (either the profiler's own methods or no-ops).
    def stop(self):
        self.stop()

    def start(self):
        self.start()

    def step(self):
        self.step()

    @staticmethod
    def disable():
        """
        Must be called before profiling starts; has no effect once a profiler instance has been created.
        """
        if hasattr(HabanaProfile.disable, "invalid"):
            if not HabanaProfile.disable.invalid:
                HabanaProfile.HABANA_PROFILE_ENABLED = False
        else:
            HabanaProfile.HABANA_PROFILE_ENABLED = False

    @staticmethod
    def enable():
        """
        Must be called before profiling starts; has no effect once a profiler instance has been created.
        """
        if hasattr(HabanaProfile.enable, "invalid"):
            if not HabanaProfile.enable.invalid:
                HabanaProfile.HABANA_PROFILE_ENABLED = True
        else:
            HabanaProfile.HABANA_PROFILE_ENABLED = True
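
# Illustrative usage sketch (not part of the original module): profile a few steps of a training
# loop, assuming a hypothetical `train_step()` callable.
#
#     profiler = HabanaProfile(warmup=1, active=3, output_dir="./hpu_profile")
#     profiler.start()
#     for _ in range(10):
#         train_step()
#         profiler.step()
#     profiler.stop()
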
def check_optimum_habana_min_version(min_version):
"""
Checks if the installed version of `optimum-habana` is larger than or equal to `min_version`.
Copied from: https://github.com/huggingface/transformers/blob/c41291965f078070c5c832412f5d4a5f633fcdc4/src/transformers/utils/__init__.py#L212
"""
if version.parse(__version__) < version.parse(min_version):
error_message = (
f"This example requires `optimum-habana` to have a minimum version of {min_version},"
f" but the version found is {__version__}.\n"
)
if "dev" in min_version:
        error_message += (
            "You can install it from source with: "
            "`pip install git+https://github.com/huggingface/optimum-habana.git`."
        )
        raise ImportError(error_message)


def check_habana_frameworks_min_version(min_version):
    """
    Checks if the installed version of `habana_frameworks` is greater than or equal to `min_version`.
    """
    return get_habana_frameworks_version() >= version.parse(min_version)
def check_habana_frameworks_version(req_version):
"""
Checks if the installed version of `habana_frameworks` is equal to `req_version`.
"""
return (get_habana_frameworks_version().major == version.parse(req_version).major) and (
get_habana_frameworks_version().minor == version.parse(req_version).minor
    )


def check_neural_compressor_min_version(req_version):
"""
Checks if the installed version of `neural_compressor` is larger than or equal to `req_version`.
"""
import neural_compressor
return version.Version(neural_compressor.__version__) >= version.Version(req_version)
def get_device_name():
"""
Returns the name of the current device: Gaudi or Gaudi2.
Inspired from: https://github.com/HabanaAI/Model-References/blob/a87c21f14f13b70ffc77617b9e80d1ec989a3442/PyTorch/computer_vision/classification/torchvision/utils.py#L274
"""
import habana_frameworks.torch.utils.experimental as htexp
device_type = htexp._get_device_type()
if device_type == htexp.synDeviceType.synDeviceGaudi:
return "gaudi"
elif device_type == htexp.synDeviceType.synDeviceGaudi2:
return "gaudi2"
else:
raise ValueError(f"Unsupported device: the device type is {device_type}.")