diff --git a/nemo/collections/llm/peft/canonical_lora.py b/nemo/collections/llm/peft/canonical_lora.py
index 954f174fd3e1..40c86253017e 100644
--- a/nemo/collections/llm/peft/canonical_lora.py
+++ b/nemo/collections/llm/peft/canonical_lora.py
@@ -81,7 +81,7 @@ def forward(self, x):
         qkv_4d = torch.cat([query_4d, key_4d, value_4d], dim=2)
         adapter_output = qkv_4d.reshape(qkv_4d.shape[0], qkv_4d.shape[1], -1)
 
-        return linear_output + adapter_output, bias
+        return linear_output + adapter_output.reshape(linear_output.shape), bias
 
 
 class LoRALinearSplitFC1UpGate(AdapterWrapper):
@@ -99,7 +99,7 @@ def forward(self, x):
         adapter_output_gate = self.adapter.adapter_gate(layernorm_output)
         adapter_output_up = self.adapter.adapter_up(layernorm_output)
         adapter_output = torch.cat([adapter_output_gate, adapter_output_up], dim=2)
-        return linear_output + adapter_output, bias
+        return linear_output + adapter_output.reshape(linear_output.shape), bias
 
 
 @dataclass
diff --git a/nemo/collections/llm/peft/dora.py b/nemo/collections/llm/peft/dora.py
index e96c20fadf0c..1a4c15473230 100644
--- a/nemo/collections/llm/peft/dora.py
+++ b/nemo/collections/llm/peft/dora.py
@@ -120,7 +120,10 @@ def forward(self, x):
             self.adapter.dropout(layernorm_output) - layernorm_output
         )[0]
 
-        return mag_norm_scale * (linear_output + adapter_output) + dropout_correction, bias
+        return (
+            mag_norm_scale * (linear_output + adapter_output.reshape(linear_output.shape)) + dropout_correction,
+            bias,
+        )
 
 
 @dataclass