diff --git a/src/transformers/pipelines/feature_extraction.py b/src/transformers/pipelines/feature_extraction.py
index c7d50c971955..48f7735b6ce0 100644
--- a/src/transformers/pipelines/feature_extraction.py
+++ b/src/transformers/pipelines/feature_extraction.py
@@ -31,6 +31,8 @@ class FeatureExtractionPipeline(Pipeline):
             If no framework is specified, will default to the one currently installed. If no framework is specified
             and both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model
             is provided.
+        return_tensors (`bool`, *optional*):
+            If `True`, returns a tensor according to the specified framework, otherwise returns a list.
         task (`str`, defaults to `""`):
             A task-identifier for the pipeline.
         args_parser ([`~pipelines.ArgumentHandler`], *optional*):
@@ -40,7 +42,7 @@ class FeatureExtractionPipeline(Pipeline):
             the associated CUDA device id.
     """
 
-    def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, **kwargs):
+    def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_tensors=None, **kwargs):
         if tokenize_kwargs is None:
             tokenize_kwargs = {}
 
@@ -53,7 +55,11 @@ def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, **kwargs):
 
         preprocess_params = tokenize_kwargs
 
-        return preprocess_params, {}, {}
+        postprocess_params = {}
+        if return_tensors is not None:
+            postprocess_params["return_tensors"] = return_tensors
+
+        return preprocess_params, {}, postprocess_params
 
     def preprocess(self, inputs, **tokenize_kwargs) -> Dict[str, GenericTensor]:
         return_tensors = self.framework
@@ -64,8 +70,10 @@ def _forward(self, model_inputs):
         model_outputs = self.model(**model_inputs)
         return model_outputs
 
-    def postprocess(self, model_outputs):
+    def postprocess(self, model_outputs, return_tensors=False):
         # [0] is the first available tensor, logits or last_hidden_state.
+        if return_tensors:
+            return model_outputs[0]
         if self.framework == "pt":
             return model_outputs[0].tolist()
         elif self.framework == "tf":
diff --git a/tests/pipelines/test_pipelines_feature_extraction.py b/tests/pipelines/test_pipelines_feature_extraction.py
index f75af6808bcc..2e431fa1d486 100644
--- a/tests/pipelines/test_pipelines_feature_extraction.py
+++ b/tests/pipelines/test_pipelines_feature_extraction.py
@@ -15,6 +15,8 @@
 import unittest
 
 import numpy as np
+import tensorflow as tf
+import torch
 
 from transformers import (
     FEATURE_EXTRACTOR_MAPPING,
@@ -133,6 +135,22 @@ def test_tokenization_small_model_tf(self):
                 tokenize_kwargs=tokenize_kwargs,
             )
 
+    @require_torch
+    def test_return_tensors_pt(self):
+        feature_extractor = pipeline(
+            task="feature-extraction", model="hf-internal-testing/tiny-random-distilbert", framework="pt"
+        )
+        outputs = feature_extractor("This is a test" * 100, return_tensors=True)
+        self.assertTrue(torch.is_tensor(outputs))
+
+    @require_tf
+    def test_return_tensors_tf(self):
+        feature_extractor = pipeline(
+            task="feature-extraction", model="hf-internal-testing/tiny-random-distilbert", framework="tf"
+        )
+        outputs = feature_extractor("This is a test" * 100, return_tensors=True)
+        self.assertTrue(tf.is_tensor(outputs))
+
     def get_shape(self, input_, shape=None):
         if shape is None:
             shape = []
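Usage sketch (not part of the patch): a minimal example of the `return_tensors` flag introduced above, using the same `hf-internal-testing/tiny-random-distilbert` checkpoint as the tests and assuming a PyTorch install.

from transformers import pipeline

# Build a feature-extraction pipeline on the PyTorch backend.
feature_extractor = pipeline(
    task="feature-extraction",
    model="hf-internal-testing/tiny-random-distilbert",
    framework="pt",
)

# Default behaviour is unchanged: postprocess() converts the hidden states to a nested Python list.
as_list = feature_extractor("This is a test")

# With the new flag, postprocess() returns the framework tensor directly
# (a torch.Tensor here; a tf.Tensor when framework="tf").
as_tensor = feature_extractor("This is a test", return_tensors=True)
print(type(as_list), type(as_tensor))  # list vs. torch.Tensor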