Skip to content

Commit 7c63106

Browse files
authored
allow string argument for Pipeline device and remove multi call warning (#435)
* allow string argument for pipeline device * remove warning when called multiple times on gpu * fix * make pre-commit happy
1 parent 3145b17 commit 7c63106

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

src/pytorch_ie/pipeline.py

+5-10
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,15 @@ def __init__(
4949
self,
5050
model: PyTorchIEModel,
5151
taskmodule: TaskModule,
52-
# args_parser: ArgumentHandler = None,
53-
device: int = -1,
52+
device: Union[int, str] = "cpu",
5453
binary_output: bool = False,
5554
**kwargs,
5655
):
5756
self.taskmodule = taskmodule
58-
self.device = torch.device("cpu" if device < 0 else f"cuda:{device}")
57+
device_str = (
58+
("cpu" if device < 0 else f"cuda:{device}") if isinstance(device, int) else device
59+
)
60+
self.device = torch.device(device_str)
5961
self.binary_output = binary_output
6062

6163
# Module.to() returns just self, but moved to the device. This is not correctly
@@ -324,13 +326,6 @@ def __call__(
324326
forward_params = {**self._forward_params, **forward_params}
325327
postprocess_params = {**self._postprocess_params, **postprocess_params}
326328

327-
self.call_count += 1
328-
if self.call_count > 10 and self.device.type == "cuda":
329-
warnings.warn(
330-
"You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset",
331-
UserWarning,
332-
)
333-
334329
single_document = False
335330
if isinstance(documents, Document):
336331
single_document = True

0 commit comments

Comments
 (0)