
Commit 3d533b9

Convert Cuboid2D to/from KITTI 3D data (#1639)
### Summary

CVS-151427

#### New features

- New `Cuboid2D` methods:
  - `Cuboid2D.from_3d(dim, location, rotation_y, P, Tr_velo_to_cam)`: creates a `Cuboid2D` object from KITTI 3D bbox annotation data. Matrix `P` (`P2` in the KITTI format context) is the 3x4 projection matrix of the left color camera. Matrix `Tr_velo_to_cam` is the 3x4 transformation matrix between the LiDAR and camera coordinate systems.
  - `cuboid_2d.to_3d(P_inv)`: reconstructs approximate KITTI 3D bbox annotation data (`dimensions`, `location` and `rotation_y`) from the 2D projection coordinates. `P_inv` is the [pseudo-inverse](https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) of the camera-to-image projection matrix.

### How to test

See unit test changes.

### Checklist

- [x] I have added unit tests to cover my changes.
- [ ] I have added integration tests to cover my changes.
- [x] I have added the description of my changes into [CHANGELOG](https://github.com/openvinotoolkit/datumaro/blob/develop/CHANGELOG.md).
- [ ] I have updated the [documentation](https://github.com/openvinotoolkit/datumaro/tree/develop/docs) accordingly.

### License

- [x] I submit _my code changes_ under the same [MIT License](https://github.com/openvinotoolkit/datumaro/blob/develop/LICENSE) that covers the project. Feel free to contact the maintainers if that's a concern.
- [ ] I have updated the license header for each file (see an example below).

```python
# Copyright (C) 2024 Intel Corporation
#
# SPDX-License-Identifier: MIT
```

---------

Signed-off-by: Ilya Trushkin <[email protected]>
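A minimal usage sketch (not part of this commit) of how the new API is meant to be called; the calibration matrices and box parameters below mirror the fixtures added in `tests/unit/test_annotation.py`:

```python
import numpy as np

from datumaro.components.annotation import Cuboid2D

# KITTI calibration: left color camera projection matrix (P2) and the
# LiDAR-to-camera transform (values taken from the new unit-test fixtures)
P2 = np.array(
    [
        [721.5377, 0.0, 609.5593, 44.85728],
        [0.0, 721.5377, 172.854, 0.2163791],
        [0.0, 0.0, 1.0, 0.002745884],
    ]
)
Tr_velo_to_cam = np.array(
    [
        [7.533745e-03, -9.999714e-01, -6.166020e-04, -4.069766e-03],
        [1.480249e-02, 7.280733e-04, -9.998902e-01, -7.631618e-02],
        [9.998621e-01, 7.523790e-03, 1.480755e-02, -2.717806e-01],
    ]
)

# KITTI 3D bbox parameters: dimensions (h, w, l), location (x, y, z), rotation_y
cuboid = Cuboid2D.from_3d(
    dim=np.array([1.49, 1.56, 4.34]),
    location=np.array([2.51, 1.49, 14.75]),
    rotation_y=-1.59,
    P=P2,
    Tr_velo_to_cam=Tr_velo_to_cam,
)
print(cuboid.points)  # eight projected (x, y) corners

# Approximate round-trip back to KITTI-style 3D parameters
dimensions, location, rotation_y = cuboid.to_3d(np.linalg.pinv(P2))
```

The round trip is approximate by design (the new unit test compares with `atol=2`), since a single 2D projection constrains the box only up to the stored `y_3d` scale.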
1 parent eb19963 commit 3d533b9

File tree

4 files changed: +259, -9 lines

CHANGELOG.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   (<https://github.com/openvinotoolkit/datumaro/pull/1619>)
 - Add PseudoLabeling transform for unlabeled dataset
   (<https://github.com/openvinotoolkit/datumaro/pull/1594>)
+- Convert Cuboid2D annotation to/from 3D data
+  (<https://github.com/openvinotoolkit/datumaro/pull/1639>)
 
 ### Enhancements
 - Enhance 'id_from_image_name' transform to ensure each identifier is unique
```

src/datumaro/components/annotation.py

Lines changed: 191 additions & 7 deletions

```diff
@@ -24,6 +24,7 @@
 )
 
 import attr
+import cv2
 import numpy as np
 import shapely.geometry as sg
 from attr import asdict, attrs, field
@@ -1372,19 +1373,19 @@ class Cuboid2D(Annotation):
     [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x5, y5), (x6, y6), (x7, y7), (x8, y8)].
 
 
-          6---7
+          2---3
          /|  /|
-        5-+-8 |
-        | 2 + 3
+        1-+-4 |
+        | 5 + 6
         |/  |/
-        1---4
+        8---7
 
     Attributes:
-        _type (AnnotationType): The type of annotation, set to `AnnotationType.bbox`.
+        _type (AnnotationType): The type of annotation, set to `AnnotationType.cuboid_2d`.
 
     Methods:
         __init__: Initializes the Cuboid2D with its coordinates.
-        wrap: Creates a new Bbox instance with updated attributes.
+        wrap: Creates a new Cuboid2D instance with updated attributes.
     """
 
     _type = AnnotationType.cuboid_2d
@@ -1393,11 +1394,194 @@
         converter=attr.converters.optional(int), default=None, kw_only=True
     )
     z_order: int = field(default=0, validator=default_if_none(int), kw_only=True)
+    y_3d: float = field(default=None, validator=default_if_none(float), kw_only=True)
 
-    def __init__(self, _points: Iterable[Tuple[float, float]], *args, **kwargs):
+    def __init__(
+        self,
+        _points: Iterable[Tuple[float, float]],
+        *args,
+        **kwargs,
+    ):
         kwargs.pop("points", None)  # comes from wrap()
         self.__attrs_init__(points=_points, *args, **kwargs)
 
+    @staticmethod
+    def _get_plane_equation(points):
+        """Calculates coefficients of the plane equation from three points."""
+        x1, y1, z1 = points[0, 0], points[0, 1], points[0, 2]
+        x2, y2, z2 = points[1, 0], points[1, 1], points[1, 2]
+        x3, y3, z3 = points[2, 0], points[2, 1], points[2, 2]
+        a1 = x2 - x1
+        b1 = y2 - y1
+        c1 = z2 - z1
+        a2 = x3 - x1
+        b2 = y3 - y1
+        c2 = z3 - z1
+        a = b1 * c2 - b2 * c1
+        b = a2 * c1 - a1 * c2
+        c = a1 * b2 - b1 * a2
+        d = -a * x1 - b * y1 - c * z1
+        return np.array([a, b, c, d])
+
+    @staticmethod
+    def _get_denorm(Tr_velo_to_cam_homo):
+        """Calculates the denormalized vector perpendicular to the image plane.
+        Args:
+            Tr_velo_to_cam_homo (np.ndarray): Homogeneous (4x4) LiDAR-to-camera transformation matrix
+        Returns:
+            np.ndarray: vector"""
+        ground_points_lidar = np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]])
+        ground_points_lidar = np.concatenate(
+            (ground_points_lidar, np.ones((ground_points_lidar.shape[0], 1))), axis=1
+        )
+        ground_points_cam = np.matmul(Tr_velo_to_cam_homo, ground_points_lidar.T).T
+        denorm = -1 * Cuboid2D._get_plane_equation(ground_points_cam)
+        return denorm
+
+    @staticmethod
+    def _get_3d_points(dim, location, rotation_y, denorm):
+        """Get corner points according to the 3D bounding box parameters.
+
+        Args:
+            dim (List[float]): The dimensions of the 3D bounding box as [h, w, l].
+            location (List[float]): The location of the 3D bounding box as [x, y, z].
+            rotation_y (float): The rotation angle around the y-axis.
+
+        Returns:
+            np.ndarray: The corner points of the 3D bounding box.
+        """
+
+        c, s = np.cos(rotation_y), np.sin(rotation_y)
+        R = np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]], dtype=np.float32)
+        l, w, h = dim[2], dim[1], dim[0]
+        x_corners = [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2]
+        y_corners = [0, 0, 0, 0, -h, -h, -h, -h]
+        z_corners = [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2]
+
+        corners = np.array([x_corners, y_corners, z_corners], dtype=np.float32)
+        corners_3d = np.dot(R, corners)
+
+        denorm = denorm[:3]
+        denorm_norm = denorm / np.sqrt(denorm[0] ** 2 + denorm[1] ** 2 + denorm[2] ** 2)
+        ori_denorm = np.array([0.0, -1.0, 0.0])
+        theta = -1 * math.acos(np.dot(denorm_norm, ori_denorm))
+        n_vector = np.cross(denorm, ori_denorm)
+        n_vector_norm = n_vector / np.sqrt(n_vector[0] ** 2 + n_vector[1] ** 2 + n_vector[2] ** 2)
+        rotation_matrix, j = cv2.Rodrigues(theta * n_vector_norm)
+        corners_3d = np.dot(rotation_matrix, corners_3d)
+        corners_3d = corners_3d + np.array(location, dtype=np.float32).reshape(3, 1)
+        return corners_3d.transpose(1, 0)
+
+    @staticmethod
+    def _project_to_2d(pts_3d, P):
+        """Project 3D points to 2D image plane.
+
+        Args:
+            pts_3d (np.ndarray): The 3D points to project.
+            P (np.ndarray): The projection matrix.
+
+        Returns:
+            np.ndarray: The 2D points projected to the image plane.
+        """
+        # Convert to homogeneous coordinates
+        pts_3d = pts_3d.T
+        pts_3d_homo = np.vstack((pts_3d, np.ones(pts_3d.shape[1])))
+        pts_2d = P @ pts_3d_homo
+        pts_2d[0, :] = np.divide(pts_2d[0, :], pts_2d[2, :])
+        pts_2d[1, :] = np.divide(pts_2d[1, :], pts_2d[2, :])
+        pts_2d = pts_2d[:2, :].T
+
+        return pts_2d
+
+    @classmethod
+    def from_3d(
+        cls,
+        dim: np.ndarray,
+        location: np.ndarray,
+        rotation_y: float,
+        P: np.ndarray,
+        Tr_velo_to_cam: np.ndarray,
+    ) -> Cuboid2D:
+        """Creates an instance of Cuboid2D class from 3D bounding box parameters.
+
+        Args:
+            dim (np.ndarray): 3 scalars describing length, width and height of a 3D bounding box
+            location (np.ndarray): (x, y, z) coordinates of the middle of the top face.
+            rotation_y (float): rotation along the Y-axis (from -pi to pi)
+            P (np.ndarray): Camera-to-Image transformation matrix (3x4)
+            Tr_velo_to_cam (np.ndarray): LiDAR-to-Camera transformation matrix (3x4)
+
+        Returns:
+            Cuboid2D: Projection points for the given bounding box
+        """
+        Tr_velo_to_cam_homo = np.eye(4)
+        Tr_velo_to_cam_homo[:3, :4] = Tr_velo_to_cam
+        denorm = cls._get_denorm(Tr_velo_to_cam_homo)
+        pts_3d = cls._get_3d_points(dim, location, rotation_y, denorm)
+        y_3d = np.mean(pts_3d[:4, 1])
+        pts_2d = cls._project_to_2d(pts_3d, P)
+
+        return cls(list(map(tuple, pts_2d)), y_3d=y_3d)
+
+    def to_3d(self, P_inv: np.ndarray) -> tuple[np.ndarray, np.ndarray, float]:
+        """Reconstructs 3D object Velodyne coordinates (dimensions, location and rotation along the Y-axis)
+        from the given Cuboid2D instance.
+
+        Args:
+            P_inv (np.ndarray): Pseudo-inverse of Camera-to-Image projection matrix
+        Returns:
+            tuple: dimensions, location and rotation along the Y-axis
+        """
+        recon_3d = []
+        for idx, coord_2d in enumerate(self.points):
+            coord_2d = np.append(coord_2d, 1)
+            coord_3d = P_inv @ coord_2d
+            if idx < 4:
+                coord_3d = coord_3d * self.y_3d / coord_3d[1]
+            else:
+                coord_3d = coord_3d * recon_3d[idx - 4][0] / coord_3d[0]
+            recon_3d.append(coord_3d[:3])
+        recon_3d = np.array(recon_3d)
+
+        x = np.mean(recon_3d[:, 0])
+        z = np.mean(recon_3d[:, 2])
+
+        yaws = []
+        pairs = [(0, 1), (3, 2), (4, 5), (7, 6)]
+        for p in pairs:
+            delta_x = recon_3d[p[0]][0] - recon_3d[p[1]][0]
+            delta_z = recon_3d[p[0]][2] - recon_3d[p[1]][2]
+            yaws.append(np.arctan2(delta_x, delta_z))
+        yaw = np.mean(yaws)
+
+        widths = []
+        pairs = [(0, 1), (2, 3), (4, 5), (6, 7)]
+        for p in pairs:
+            delta_x = np.sqrt(
+                (recon_3d[p[0]][0] - recon_3d[p[1]][0]) ** 2
+                + (recon_3d[p[0]][2] - recon_3d[p[1]][2]) ** 2
+            )
+            widths.append(delta_x)
+        w = np.mean(widths)
+
+        lengths = []
+        pairs = [(1, 2), (0, 3), (5, 6), (4, 7)]
+        for p in pairs:
+            delta_z = np.sqrt(
+                (recon_3d[p[0]][0] - recon_3d[p[1]][0]) ** 2
+                + (recon_3d[p[0]][2] - recon_3d[p[1]][2]) ** 2
+            )
+            lengths.append(delta_z)
+        l = np.mean(lengths)
+
+        heights = []
+        pairs = [(0, 4), (1, 5), (2, 6), (3, 7)]
+        for p in pairs:
+            delta_y = np.abs(recon_3d[p[0]][1] - recon_3d[p[1]][1])
+            heights.append(delta_y)
+        h = np.mean(heights)
+        return np.array([h, w, l]), np.array([x, self.y_3d, z]), yaw
+
 
 @attrs(slots=True, order=False)
 class PointsCategories(Categories):
```

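For context (not part of the commit): `_project_to_2d` applies the standard pinhole projection, and `to_3d` inverts it with the pseudo-inverse `P_inv`. Back-projecting a pixel only defines a ray, i.e. each corner is recovered up to an unknown scale, which is why `from_3d` stores `y_3d` (the mean camera-frame height of the first four corners) and `to_3d` uses it to fix that scale:

$$
\tilde{p} = P\,\tilde{X}, \qquad
(u, v) = \left(\frac{\tilde{p}_1}{\tilde{p}_3},\ \frac{\tilde{p}_2}{\tilde{p}_3}\right), \qquad
X \approx \lambda\,\bigl[P^{+}(u, v, 1)^{\mathsf{T}}\bigr]_{1:3}
$$

where the scale $\lambda$ is chosen so that the first four corners land at $y = y_{3d}$, and each remaining corner matches the $x$ coordinate of its corresponding corner from the first four, exactly as done in `to_3d`.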
src/datumaro/components/visualizer.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -679,8 +679,8 @@ def _draw_cuboid_2d(
         # Define the faces based on vertex indices
 
         faces = [
-            [points[i] for i in [0, 1, 2, 3]],  # Bottom face
-            [points[i] for i in [4, 5, 6, 7]],  # Top face
+            [points[i] for i in [0, 1, 2, 3]],  # Top face
+            [points[i] for i in [4, 5, 6, 7]],  # Bottom face
             [points[i] for i in [0, 1, 5, 4]],  # Front face
             [points[i] for i in [1, 2, 6, 5]],  # Right face
             [points[i] for i in [2, 3, 7, 6]],  # Back face
```
tests/unit/test_annotation.py

Lines changed: 64 additions & 0 deletions

```diff
@@ -13,6 +13,7 @@
 
 from datumaro.components.annotation import (
     Annotations,
+    Cuboid2D,
     Ellipse,
     ExtractedMask,
     HashKey,
@@ -142,3 +143,66 @@ def test_get_semantic_seg_mask_binary_mask(self, fxt_index_mask, dtype):
         semantic_seg_mask = annotations.get_semantic_seg_mask(ignore_index=255, dtype=dtype)
 
         assert np.allclose(semantic_seg_mask, fxt_index_mask)
+
+
+class Cuboid2DTest:
+    @pytest.fixture
+    def fxt_cuboid_2d(self):
+        return Cuboid2D(
+            [
+                (684.172, 237.810),
+                (750.486, 237.673),
+                (803.791, 256.313),
+                (714.712, 256.542),
+                (684.035, 174.227),
+                (750.263, 174.203),
+                (803.399, 170.990),
+                (714.476, 171.015),
+            ],
+            y_3d=1.49,
+        )
+
+    @pytest.fixture
+    def fxt_kitti_data(self):
+        dimensions = np.array([1.49, 1.56, 4.34])
+        location = np.array([2.51, 1.49, 14.75])
+        rot_y = -1.59
+
+        return dimensions, location, rot_y
+
+    @pytest.fixture
+    def fxt_P2(self):
+        return np.array(
+            [
+                [7.215377000000e02, 0.000000000000e00, 6.095593000000e02, 4.485728000000e01],
+                [0.000000000000e00, 7.215377000000e02, 1.728540000000e02, 2.163791000000e-01],
+                [0.000000000000e00, 0.000000000000e00, 1.000000000000e00, 2.745884000000e-03],
+            ]
+        )
+
+    @pytest.fixture
+    def fxt_velo_to_cam(self):
+        return np.array(
+            [
+                [7.533745000000e-03, -9.999714000000e-01, -6.166020000000e-04, -4.069766000000e-03],
+                [1.480249000000e-02, 7.280733000000e-04, -9.998902000000e-01, -7.631618000000e-02],
+                [9.998621000000e-01, 7.523790000000e-03, 1.480755000000e-02, -2.717806000000e-01],
+            ]
+        )
+
+    def test_create_from_3d(self, fxt_kitti_data, fxt_cuboid_2d, fxt_P2, fxt_velo_to_cam):
+        actual = Cuboid2D.from_3d(
+            dim=fxt_kitti_data[0],
+            location=fxt_kitti_data[1],
+            rotation_y=fxt_kitti_data[2],
+            P=fxt_P2,
+            Tr_velo_to_cam=fxt_velo_to_cam,
+        )
+
+        assert np.allclose(actual.points, fxt_cuboid_2d.points, atol=1e-3)
+
+    def test_to_3d(self, fxt_kitti_data, fxt_cuboid_2d, fxt_P2):
+        P_inv = np.linalg.pinv(fxt_P2)
+        actual = fxt_cuboid_2d.to_3d(P_inv)
+        for act, exp in zip(actual, fxt_kitti_data):
+            assert np.allclose(act, exp, atol=2)
```