Skip to content

Commit 5657b17

Browse files
committed
fix torch.onnx.symbolic_opset12 import
1 parent 01e83e1 commit 5657b17

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

src/transformers/models/deberta/modeling_deberta.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@ def backward(ctx, grad_output):
187187

188188
@staticmethod
189189
def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
190+
from torch.onnx import symbolic_opset12
191+
190192
dropout_p = local_ctx
191193
if isinstance(local_ctx, DropoutContext):
192194
dropout_p = local_ctx.dropout
@@ -198,7 +200,7 @@ def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, D
198200
# Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
199201
# if opset_version < 12:
200202
# return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
201-
return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train)
203+
return symbolic_opset12.dropout(g, input, dropout_p, train)
202204

203205

204206
class StableDropout(nn.Module):

src/transformers/models/deberta_v2/modeling_deberta_v2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ def backward(ctx, grad_output):
193193

194194
@staticmethod
195195
def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
196+
from torch.onnx import symbolic_opset12
197+
196198
dropout_p = local_ctx
197199
if isinstance(local_ctx, DropoutContext):
198200
dropout_p = local_ctx.dropout
@@ -204,7 +206,7 @@ def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, D
204206
# Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
205207
# if opset_version < 12:
206208
# return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
207-
return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train)
209+
return symbolic_opset12.dropout(g, input, dropout_p, train)
208210

209211

210212
# Copied from transformers.models.deberta.modeling_deberta.StableDropout

src/transformers/models/sew_d/modeling_sew_d.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,8 @@ def backward(ctx, grad_output):
597597

598598
@staticmethod
599599
def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
600+
from torch.onnx import symbolic_opset12
601+
600602
dropout_p = local_ctx
601603
if isinstance(local_ctx, DropoutContext):
602604
dropout_p = local_ctx.dropout
@@ -608,7 +610,7 @@ def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, D
608610
# Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
609611
# if opset_version < 12:
610612
# return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
611-
return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train)
613+
return symbolic_opset12.dropout(g, input, dropout_p, train)
612614

613615

614616
# Copied from transformers.models.deberta.modeling_deberta.StableDropout

0 commit comments

Comments
 (0)