Skip to content

Commit e451865

Browse files
committed
Don't padd input
1 parent d35273d commit e451865

File tree

1 file changed

+0
-6
lines changed

1 file changed

+0
-6
lines changed

src/accelerate/inference.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,6 @@ def build_pipeline(model, split_points, args, kwargs, num_chunks):
8585
# We need to annotate the split points in the model for PiPPy
8686
state = PartialState()
8787
split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
88-
found_batch_size = find_pippy_batch_size(args, kwargs)
89-
if found_batch_size != num_chunks:
90-
if args is not None:
91-
args = pad_input_tensors(args, found_batch_size, num_chunks)
92-
if kwargs is not None:
93-
kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
9488
pipe = pipeline(
9589
model,
9690
mb_args=args,

0 commit comments

Comments
 (0)