
Commit 5d46bab

llama : initial Mamba-2 support (ggml-org#9126)
* llama : initial Mamba-2 support
* ggml : SIMD ggml_ssm_scan for Mamba-2
* ggml : improve ggml_mul speed when masking recurrent states
* llama : support running Mamba-Codestral-7B-v0.1
* llama : fix Mamba-2 conv state saving
* ggml : make the ggml_mul fast broadcast path more consistently formatted
* llama : remove unused variable
* llama : add missing break
* convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present
  The tokenizer.json of Mamba-Codestral-7B-v0.1 otherwise requires workarounds to work correctly.
* llama : avoid redundant state copy for Mamba 1 and 2
* metal : attempt to adapt SSM_SCAN for Mamba-2
* metal : fix SSM_SCAN pipeline scope
* metal : use log and exp instead of log1pf and expf in SSM_SCAN
* metal : remove unused arguments for SSM_SCAN
  The max index is 31, so trimming the arguments is necessary.
* metal : add back n_seqs to SSM_SCAN args
  Whoops, this is needed for the offset in the concatenated output.
* metal : fix SSM_SCAN state head offset
* metal : fix wrong number of tokens per sequence in SSM_SCAN
* ggml : remove unused fast broadcast path in GGML_MUL
  This was initially added because states were masked with ggml_mul, but this is no longer done and so this "optimisation" is no longer necessary, or at least not worth the additional code complexity.
* ggml : avoid multiply by D in GGML_OP_SSM_SCAN
  This makes the weight buft detection in src/llama.cpp simpler.
* convert : transpose Mamba-2 A, D and reshape SSM_NORM
  This breaks existing conversions of Mamba-2 models to avoid some reshapes. Not sure if it's a good idea, but it makes the graph slightly cleaner.
* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks
* convert : fix flake8 lint
* metal : fix confusion between ; and ,
* metal : add missing args for nb references in ssm_scan_f32_group
* metal : single-user mamba2 inference works
* kv-cache : remove const_cast when setting inputs for s_copy
  Also fix multi-user inference for recurrent models by using cell_id instead of i as the kv cell index when populating s_copy.
* convert : avoid AutoConfig for Mamba and Mamba2 hparams
* kv-cache : allow context shift for recurrent models
* graph : fix recurrent state copies when avoiding copies
  Works, but using lambda functions might not be that clean.
* ggml : fix mamba2 ssm scan when compiled with SVE
* ggml-cpu : reorder SVE FMA for consistency with other SIMD arches
* cuda : implement ssm scan for Mamba2
  There is still room for improvement, but it works!
* cuda : adapt Mamba1 ssm scan to shape changes from Mamba2
* mamba : fix mismatched new and delete size for llm_build_mamba
  Subclasses of llm_graph_context cannot have extra fields, because the called destructor is not the one from the subclass. This would otherwise cause problems when running Mamba-(1|2) inference when compiled with -DGGML_SANITIZE_ADDRESS=ON.
* cuda : graceful fallback for Mamba-1 models with weird embd size
1 parent e17991c commit 5d46bab

24 files changed (+1083, -319 lines)
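For orientation, the recurrence that the SSM_SCAN kernels added here compute for Mamba-2 can be written as a short reference loop. The sketch below is illustrative plain PyTorch, not the ggml code: the tensor names, shapes, and single-group simplification are assumptions for readability, and (matching the "avoid multiply by D in GGML_OP_SSM_SCAN" change) the D skip connection is left to the caller.

```python
# Reference sketch (not ggml code): per-token Mamba-2 state update, single sequence,
# single B/C group. A is the negated exponential of A_log (A = -exp(A_log)),
# so exp(dt * A) is a decay factor in (0, 1).
import torch

def ssm_scan_ref(x, dt, A, B, C):
    # x:  (n_tok, n_head, head_dim)  post-conv, post-activation input
    # dt: (n_tok, n_head)            per-head time step (already softplus'd)
    # A:  (n_head,)                  per-head scalar, negative
    # B:  (n_tok, d_state)           input projection (one group assumed)
    # C:  (n_tok, d_state)           output projection (one group assumed)
    n_tok, n_head, head_dim = x.shape
    h = torch.zeros(n_head, head_dim, B.shape[-1])                  # recurrent state
    ys = []
    for t in range(n_tok):
        dA = torch.exp(dt[t, :, None, None] * A[:, None, None])     # decay
        dBx = dt[t, :, None, None] * x[t, :, :, None] * B[t, None, None, :]
        h = dA * h + dBx                                             # state update
        ys.append((h * C[t, None, None, :]).sum(-1))                 # y_t = C_t · h_t
    # the D * x skip connection is added by the caller, outside the scan
    return torch.stack(ys), h
```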

convert_hf_to_gguf.py

Lines changed: 111 additions & 1 deletion
@@ -4781,6 +4781,14 @@ def set_gguf_parameters(self):
 class MambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 8
@@ -4855,6 +4863,100 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(new_name, data_torch)]
 
 
+@ModelBase.register("Mamba2ForCausalLM")
+class Mamba2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.MAMBA2
+
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
+    def set_vocab(self):
+        vocab_size = self.hparams["vocab_size"]
+        # Round vocab size to next multiple of 16
+        pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
+        # pad using ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
+        self.hparams["vocab_size"] = vocab_size
+
+        if (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+        elif (self.dir_model / "tokenizer.model.v3").is_file():
+            # mamba-codestral
+            raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
+        elif (self.dir_model / "tokenizer.json").is_file():
+            self._set_vocab_gpt2()
+        else:
+            # Use the GPT-NeoX tokenizer when no tokenizer files are present
+            self._set_vocab_builtin("gpt-neox", vocab_size)
+
+    def set_gguf_parameters(self):
+        d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
+        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
+        n_group = self.find_hparam(["n_groups"], optional=True) or 1
+
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+
+        # Fail early for models which don't have a block expansion factor of 2
+        # TODO: does this really matter?
+        assert d_inner == 2 * d_model
+        assert d_inner % head_dim == 0
+
+        self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_head_count(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
+        self.gguf_writer.add_ssm_group_count(n_group)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+
+        if name.startswith("model.backbone") or name.startswith("model.lm_head"):
+            # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2
+            name = name.removeprefix("model.")
+
+        if name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+
+        new_name = self.map_tensor_name(name)
+
+        if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
+            data_torch = data_torch.squeeze()
+        elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
+            gguf.MODEL_TENSOR.SSM_A,
+            gguf.MODEL_TENSOR.SSM_D,
+        ]):
+            # unsqueeze A to use similar shape semantics as Mamba-1
+            # (D is also unsqueezed, but for more straightforward broadcast internally)
+            data_torch = data_torch.reshape((*data_torch.shape, 1))
+        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
+            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+            n_group = self.hparams.get("n_groups", 1)
+            data_torch = data_torch.reshape((n_group, d_inner // n_group))
+
+        if name.endswith(".A_log"):
+            logger.debug("A_log --> A ==> " + new_name)
+            data_torch = -torch.exp(data_torch)
+
+        yield (new_name, data_torch)
+
+
 @ModelBase.register("CohereForCausalLM")
 class CommandR2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
@@ -6615,12 +6717,20 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
     # maybe we should fallback to text model's arch in that case, since not many models have both
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
-    arch = hparams["architectures"][0]
+    arch = None
+    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
+        arch = arches[0]
+    elif "ssm_cfg" in hparams:
+        # For non-hf Mamba and Mamba2 models
+        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
+
     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
        arch = text_config["architectures"][0]
     elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
+    if arch is None:
+        raise ValueError("Failed to detect model architecture")
     return arch
 
 
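A quick illustration of the ceiling-division rounding used in the new Mamba2Model.set_vocab above. This is a standalone sketch; the helper name is made up for the example.

```python
# -(a // -b) is integer ceil(a / b); multiplying back by b rounds a up to a multiple of b
# ref: https://stackoverflow.com/a/17511341/22827863
def round_up_to_multiple(vocab_size: int, pad_vocab: int = 16) -> int:
    return -(vocab_size // -pad_vocab) * pad_vocab

assert round_up_to_multiple(32768) == 32768   # already a multiple of 16
assert round_up_to_multiple(32769) == 32784   # rounded up to the next multiple
```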
ggml/include/ggml.h

Lines changed: 2 additions & 1 deletion
@@ -2031,7 +2031,8 @@ extern "C" {
             struct ggml_tensor * dt,
             struct ggml_tensor * A,
             struct ggml_tensor * B,
-            struct ggml_tensor * C);
+            struct ggml_tensor * C,
+            struct ggml_tensor * ids);
 
     // partition into non-overlapping windows with padding if needed
     // example:

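The new `ids` argument on `ggml_ssm_scan` is presumably what allows the "avoid redundant state copy" entries in the commit message: instead of copying each sequence's recurrent state into scan order beforehand, the scan can read its source state per sequence by index. A rough conceptual sketch follows; it is plain PyTorch with illustrative names and sizes, not ggml code.

```python
import torch

# states: pool of recurrent states kept by the cache, one slot per cell
# ids:    for each sequence in the batch, the slot its state should be read from
states = torch.zeros(8, 64, 128)    # (n_slots, d_inner, d_state), illustrative sizes
ids = torch.tensor([3, 0, 5])       # 3 sequences reading from slots 3, 0 and 5
selected = states[ids]              # the scan reads only the slots it needs, by index
```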