Commit f28f240

fix owlvit tests, update docstring examples (#18586)
1 parent 05d3a43 commit f28f240

3 files changed: +7, -8 lines

docs/source/en/model_doc/owlvit.mdx

Lines changed: 2 additions & 2 deletions
@@ -57,8 +57,8 @@ OWL-ViT is a zero-shot text-conditioned object detection model. OWL-ViT uses [CL
 ...     box = [round(i, 2) for i in box.tolist()]
 ...     if score >= score_threshold:
 ...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
-Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48]
-Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61]
+Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
+Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
 ```
 
 This model was contributed by [adirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit).
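For context, the updated lines above are the tail end of the zero-shot detection example in the OWL-ViT docs: they show the printed detections for the two cats in the sample image. A minimal, self-contained sketch of that example is below; the checkpoint name, the COCO image URL, and the `post_process_object_detection` call are assumptions based on the general `transformers` API rather than part of this diff, so the exact calls in the docstring may differ.

```python
# Hedged sketch of the OWL-ViT zero-shot detection example whose output lines
# appear in the diff above. Checkpoint, image URL, and post-processing call
# are assumptions, not taken verbatim from this commit.
import requests
import torch
from PIL import Image
from transformers import OwlViTForObjectDetection, OwlViTProcessor

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

# Standard COCO image with two cats, commonly used in transformers examples
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

texts = [["a photo of a cat", "a photo of a dog"]]
inputs = processor(text=texts, images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Rescale normalized boxes to the original image size (height, width)
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(
    outputs=outputs, threshold=0.1, target_sizes=target_sizes
)

i = 0  # predictions for the first (and only) image and its text queries
text = texts[i]
boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
for box, score, label in zip(boxes, scores, labels):
    box = [round(coord, 2) for coord in box.tolist()]
    print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
```

Run end to end, a snippet like this should print two "Detected a photo of a cat ..." lines of the kind updated in the docstring, though exact scores and boxes depend on the checkpoint and library version.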

src/transformers/models/owlvit/modeling_owlvit.py

Lines changed: 2 additions & 2 deletions
@@ -1323,8 +1323,8 @@ def forward(
         ...     box = [round(i, 2) for i in box.tolist()]
         ...     if score >= score_threshold:
         ...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
-        Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48]
-        Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61]
+        Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
+        Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
         ```"""
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

tests/models/owlvit/test_modeling_owlvit.py

Lines changed: 3 additions & 4 deletions
@@ -733,7 +733,6 @@ def prepare_img():
 
 @require_vision
 @require_torch
-@unittest.skip("These tests are broken, fix me Alara")
 class OwlViTModelIntegrationTest(unittest.TestCase):
     @slow
     def test_inference(self):
@@ -763,8 +762,7 @@ def test_inference(self):
             outputs.logits_per_text.shape,
             torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
         )
-        expected_logits = torch.tensor([[4.4420, 0.6181]], device=torch_device)
-
+        expected_logits = torch.tensor([[3.4613, 0.9403]], device=torch_device)
         self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
 
     @slow
@@ -788,7 +786,8 @@ def test_inference_object_detection(self):
 
         num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
         self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
+
         expected_slice_boxes = torch.tensor(
-            [[0.0948, 0.0471, 0.1915], [0.3194, 0.0583, 0.6498], [0.1441, 0.0452, 0.2197]]
+            [[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
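For reference, expected values like the ones updated in these integration tests can be regenerated by running the public checkpoint directly. The sketch below assumes the `google/owlvit-base-patch32` checkpoint and the standard COCO cats image; the real tests may prepare their inputs (padding, device placement) slightly differently, so it is a way to sanity-check the new constants rather than a copy of the test code.

```python
# Hedged sketch: regenerate the tensors asserted in the OWL-ViT integration tests.
# Checkpoint name, image URL, and text prompts are assumptions.
import requests
import torch
from PIL import Image
from transformers import OwlViTForObjectDetection, OwlViTModel, OwlViTProcessor

model_name = "google/owlvit-base-patch32"  # assumed checkpoint
processor = OwlViTProcessor.from_pretrained(model_name)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # COCO image with two cats
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(
    text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt"
)

# test_inference checks the image-text similarity logits of the base model
model = OwlViTModel.from_pretrained(model_name)
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits_per_image)  # should be close to [[3.4613, 0.9403]]

# test_inference_object_detection checks a slice of the normalized predicted boxes
detector = OwlViTForObjectDetection.from_pretrained(model_name)
with torch.no_grad():
    det_outputs = detector(**inputs)
print(det_outputs.pred_boxes[0, :3, :3])  # should be close to expected_slice_boxes
```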
