-
Notifications
You must be signed in to change notification settings - Fork 314
/
Copy pathauto_model.py
38 lines (26 loc) · 1.51 KB
/
auto_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import inspect
from transformers import AutoConfig
from transformers import AutoModelForCausalLM
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
def _get_model_config(model_dir, **model_init_kwargs):
config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs)
return config
class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
"""
This class is a drop-in replacement for AutoModelForCausalLM that applies the Liger Kernel to the model
if applicable.
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
model_config = _get_model_config(pretrained_model_name_or_path, **kwargs)
# Determine the model type and apply the Liger Kernel if applicable
# Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function
model_type = model_config.model_type
_apply_liger_kernel(model_type, **kwargs)
# Filter out kwargs that were passed to the apply_liger_* function, which will cause
# model initialization errors otherwise
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
apply_fn_signature = inspect.signature(apply_fn)
applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)