forked from deepmodeling/deepmd-kit
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathbatch_size.py
40 lines (34 loc) · 1.06 KB
/
batch_size.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# SPDX-License-Identifier: LGPL-3.0-or-later
from packaging.version import (
Version,
)
from deepmd.tf.env import (
TF_VERSION,
tf,
)
from deepmd.tf.utils.errors import (
OutOfMemoryError,
)
from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase
class AutoBatchSize(AutoBatchSizeBase):
    """TensorFlow backend of the automatic batch size.

    Extends the framework-agnostic ``AutoBatchSizeBase`` with
    TF-specific GPU detection and OOM-error identification.
    """

    def is_gpu_available(self) -> bool:
        """Check if GPU is available.

        Returns
        -------
        bool
            True if GPU is available
        """
        # TF >= 1.14 exposes tf.config.experimental.get_visible_devices;
        # older versions fall back to the (deprecated) tf.test.is_gpu_available.
        # bool() ensures we return an actual bool rather than the truthy
        # device *list* that get_visible_devices() yields.
        return bool(
            (
                Version(TF_VERSION) >= Version("1.14")
                and tf.config.experimental.get_visible_devices("GPU")
            )
            or tf.test.is_gpu_available()
        )

    def is_oom_error(self, e: Exception) -> bool:
        """Check if the exception is an OOM error.

        Parameters
        ----------
        e : Exception
            Exception

        Returns
        -------
        bool
            True if the exception is an out-of-memory error
        """
        # TODO: it's very slow to catch OOM error; I don't know what TF is doing here
        # but luckily we only need to catch once
        return isinstance(e, (tf.errors.ResourceExhaustedError, OutOfMemoryError))