Skip to content

Commit 66c6e7e

Browse files
authored
Fix point interpolation (#1344)
* Extend formats tests with different track types * Add unordered list comparison * Skip empty list comparison * fix * fix * Reproduce problem * Fix point interpolation for single point * undo rest api refactor
1 parent 1feeef6 commit 66c6e7e

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

cvat/apps/engine/data_manager.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,10 @@ def _modify_unmached_object(obj, end_frame):
293293

294294
@staticmethod
295295
def normalize_shape(shape):
296-
points = np.asarray(shape["points"]).reshape(-1, 2)
296+
points = list(shape["points"])
297+
if len(points) == 2:
298+
points.extend(points) # duplicate points for single point case
299+
points = np.asarray(points).reshape(-1, 2)
297300
broken_line = geometry.LineString(points)
298301
points = []
299302
for off in range(0, 100, 1):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (C) 2020 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: MIT
4+
5+
from cvat.apps.engine.data_manager import TrackManager
6+
7+
from unittest import TestCase
8+
9+
10+
class TrackManagerTest(TestCase):
11+
def test_single_point_interpolation(self):
12+
track = {
13+
"frame": 0,
14+
"label_id": 0,
15+
"group": None,
16+
"attributes": [],
17+
"shapes": [
18+
{
19+
"frame": 0,
20+
"points": [1.0, 2.0],
21+
"type": "points",
22+
"occluded": False,
23+
"outside": False,
24+
"attributes": []
25+
},
26+
{
27+
"frame": 2,
28+
"attributes": [],
29+
"points": [3.0, 4.0, 5.0, 6.0],
30+
"type": "points",
31+
"occluded": False,
32+
"outside": True
33+
},
34+
]
35+
}
36+
37+
interpolated = TrackManager.get_interpolated_shapes(track, 0, 2)
38+
39+
self.assertEqual(len(interpolated), 3)

0 commit comments

Comments
 (0)