
Commit cd9ae39

guyueh1 and NeMo Bot authored and committed
Align adapter shape with base linear output shape (#14009)
* Align adapter shape with base linear output shape
  Signed-off-by: Guyue Huang <[email protected]>
* Apply isort and black reformatting
  Signed-off-by: guyueh1 <[email protected]>
* Align adapter output shape to linear output shape in LoRALinearSplitFC1UpGate
  Signed-off-by: guyueh1 <[email protected]>

---------

Signed-off-by: Guyue Huang <[email protected]>
Signed-off-by: guyueh1 <[email protected]>
Signed-off-by: guyueh1 <[email protected]>
Co-authored-by: guyueh1 <[email protected]>
1 parent 4c6fb0c commit cd9ae39

2 files changed: 6 additions, 3 deletions

nemo/collections/llm/peft/canonical_lora.py

Lines changed: 2 additions & 2 deletions
@@ -81,7 +81,7 @@ def forward(self, x):
         qkv_4d = torch.cat([query_4d, key_4d, value_4d], dim=2)
         adapter_output = qkv_4d.reshape(qkv_4d.shape[0], qkv_4d.shape[1], -1)
 
-        return linear_output + adapter_output, bias
+        return linear_output + adapter_output.reshape(linear_output.shape), bias
 
 
 class LoRALinearSplitFC1UpGate(AdapterWrapper):
@@ -99,7 +99,7 @@ def forward(self, x):
         adapter_output_gate = self.adapter.adapter_gate(layernorm_output)
         adapter_output_up = self.adapter.adapter_up(layernorm_output)
         adapter_output = torch.cat([adapter_output_gate, adapter_output_up], dim=2)
-        return linear_output + adapter_output, bias
+        return linear_output + adapter_output.reshape(linear_output.shape), bias
 
 
 @dataclass

nemo/collections/llm/peft/dora.py

Lines changed: 4 additions & 1 deletion
@@ -120,7 +120,10 @@ def forward(self, x):
                 self.adapter.dropout(layernorm_output) - layernorm_output
             )[0]
 
-        return mag_norm_scale * (linear_output + adapter_output) + dropout_correction, bias
+        return (
+            mag_norm_scale * (linear_output + adapter_output.reshape(linear_output.shape)) + dropout_correction,
+            bias,
+        )
 
 
 @dataclass
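
All three changed return statements apply the same fix: the adapter output is reshaped to `linear_output.shape` before the addition. The commit does not say which configuration produces the mismatch, so the following is only a minimal sketch under an assumed layout, where the base linear output is flattened to 2D while the adapter output stays 3D. It uses plain shape tuples rather than tensors; `can_broadcast`, `can_reshape`, and the sizes are hypothetical names and values chosen for illustration, not part of NeMo.

```python
from math import prod


def can_broadcast(a, b):
    """True if shapes a and b are compatible under NumPy/PyTorch-style broadcasting."""
    # Compare dimensions right to left; missing leading dims are treated as 1.
    return all(x == y or x == 1 or y == 1 for x, y in zip(reversed(a), reversed(b)))


def can_reshape(src, dst):
    """True if a tensor of shape src can be reshaped to dst (same element count)."""
    return prod(src) == prod(dst)


# Hypothetical sizes, chosen only for illustration.
seq, batch, hidden = 4, 2, 6

adapter_shape = (seq, batch, hidden)  # adapter output kept as [seq, batch, hidden]
linear_shape = (seq * batch, hidden)  # ASSUMPTION: base linear output flattened to 2D

# Adding the two directly would fail: the dims 8 and 2 do not broadcast.
direct_add_ok = can_broadcast(linear_shape, adapter_shape)

# Reshaping the adapter output to the linear output's shape is valid,
# because both shapes hold the same number of elements.
reshape_ok = can_reshape(adapter_shape, linear_shape)

# When the two shapes already agree, the reshape leaves the shape unchanged.
same_shape_ok = can_reshape(adapter_shape, adapter_shape)

print(direct_add_ok, reshape_ok, same_shape_ok)  # prints: False True True
```

Under these assumptions the reshape only changes how the same elements are viewed, so call sites where the two shapes already match should be unaffected by the change.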
