File tree 1 file changed +5
-10
lines changed
1 file changed +5
-10
lines changed Original file line number Diff line number Diff line change @@ -49,13 +49,15 @@ def __init__(
49
49
self ,
50
50
model : PyTorchIEModel ,
51
51
taskmodule : TaskModule ,
52
- # args_parser: ArgumentHandler = None,
53
- device : int = - 1 ,
52
+ device : Union [int , str ] = "cpu" ,
54
53
binary_output : bool = False ,
55
54
** kwargs ,
56
55
):
57
56
self .taskmodule = taskmodule
58
- self .device = torch .device ("cpu" if device < 0 else f"cuda:{ device } " )
57
+ device_str = (
58
+ ("cpu" if device < 0 else f"cuda:{ device } " ) if isinstance (device , int ) else device
59
+ )
60
+ self .device = torch .device (device_str )
59
61
self .binary_output = binary_output
60
62
61
63
# Module.to() returns just self, but moved to the device. This is not correctly
@@ -324,13 +326,6 @@ def __call__(
324
326
forward_params = {** self ._forward_params , ** forward_params }
325
327
postprocess_params = {** self ._postprocess_params , ** postprocess_params }
326
328
327
- self .call_count += 1
328
- if self .call_count > 10 and self .device .type == "cuda" :
329
- warnings .warn (
330
- "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset" ,
331
- UserWarning ,
332
- )
333
-
334
329
single_document = False
335
330
if isinstance (documents , Document ):
336
331
single_document = True
You can’t perform that action at this time.
0 commit comments