Skip to content

Commit ebeb5dd

Browse files
Cameras (#198)
* opencv_lens_undistortion * fix k4 bug for undistortion, support fisheye * support k3 k4 k5 k6 * fix _opencv_len_distortion; format * naming: len->lens
1 parent 4f7965c commit ebeb5dd

File tree

7 files changed

+850
-31
lines changed

7 files changed

+850
-31
lines changed

nerfacc/cameras.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
"""
2+
Copyright (c) 2022 Ruilong Li, UC Berkeley.
3+
"""
4+
from typing import Tuple
5+
6+
import torch
7+
import torch.nn.functional as F
8+
from torch import Tensor
9+
10+
from . import cuda as _C
11+
12+
13+
def opencv_lens_undistortion(
14+
uv: Tensor, params: Tensor, eps: float = 1e-6, iters: int = 10
15+
) -> Tensor:
16+
"""Undistort the opencv distortion.
17+
18+
Note:
19+
This function is not differentiable to any inputs.
20+
21+
Args:
22+
uv: (..., 2) UV coordinates.
23+
params: (..., N) or (N) OpenCV distortion parameters. We support
24+
N = 0, 1, 2, 4, 8. If N = 0, we return the input uv directly.
25+
If N = 1, we assume the input is {k1}. If N = 2, we assume the
26+
input is {k1, k2}. If N = 4, we assume the input is {k1, k2, p1, p2}.
27+
If N = 8, we assume the input is {k1, k2, p1, p2, k3, k4, k5, k6}.
28+
29+
Returns:
30+
(..., 2) undistorted UV coordinates.
31+
"""
32+
assert uv.shape[-1] == 2
33+
assert params.shape[-1] in [0, 1, 2, 4, 8]
34+
35+
if params.shape[-1] == 0:
36+
return uv
37+
elif params.shape[-1] < 8:
38+
params = F.pad(params, (0, 8 - params.shape[-1]), "constant", 0)
39+
assert params.shape[-1] == 8
40+
41+
batch_shape = uv.shape[:-1]
42+
params = torch.broadcast_to(params, batch_shape + (params.shape[-1],))
43+
44+
return _C.opencv_lens_undistortion(
45+
uv.contiguous(), params.contiguous(), eps, iters
46+
)
47+
48+
49+
def opencv_lens_undistortion_fisheye(
50+
uv: Tensor, params: Tensor, eps: float = 1e-6, iters: int = 10
51+
) -> Tensor:
52+
"""Undistort the opencv distortion of {k1, k2, k3, k4}.
53+
54+
Note:
55+
This function is not differentiable to any inputs.
56+
57+
Args:
58+
uv: (..., 2) UV coordinates.
59+
params: (..., 4) or (4) OpenCV distortion parameters.
60+
61+
Returns:
62+
(..., 2) undistorted UV coordinates.
63+
"""
64+
assert uv.shape[-1] == 2
65+
assert params.shape[-1] == 4
66+
batch_shape = uv.shape[:-1]
67+
params = torch.broadcast_to(params, batch_shape + (params.shape[-1],))
68+
69+
return _C.opencv_lens_undistortion_fisheye(
70+
uv.contiguous(), params.contiguous(), eps, iters
71+
)
72+
73+
74+
def _opencv_lens_distortion(uv: Tensor, params: Tensor) -> Tensor:
75+
"""The opencv camera distortion of {k1, k2, p1, p2, k3, k4, k5, k6}.
76+
77+
See https://docs.opencv.org/3.4/d9/d0c/group__calib3d.html for more details.
78+
"""
79+
k1, k2, p1, p2, k3, k4, k5, k6 = torch.unbind(params, dim=-1)
80+
s1, s2, s3, s4 = 0, 0, 0, 0
81+
u, v = torch.unbind(uv, dim=-1)
82+
r2 = u * u + v * v
83+
r4 = r2**2
84+
r6 = r4 * r2
85+
ratial = (1 + k1 * r2 + k2 * r4 + k3 * r6) / (
86+
1 + k4 * r2 + k5 * r4 + k6 * r6
87+
)
88+
fx = 2 * p1 * u * v + p2 * (r2 + 2 * u * u) + s1 * r2 + s2 * r4
89+
fy = 2 * p2 * u * v + p1 * (r2 + 2 * v * v) + s3 * r2 + s4 * r4
90+
return torch.stack([u * ratial + fx, v * ratial + fy], dim=-1)
91+
92+
93+
def _opencv_lens_distortion_fisheye(
94+
uv: Tensor, params: Tensor, eps: float = 1e-10
95+
) -> Tensor:
96+
"""The opencv camera distortion of {k1, k2, k3, p1, p2}.
97+
98+
See https://docs.opencv.org/4.x/db/d58/group__calib3d__fisheye.html for more details.
99+
100+
Args:
101+
uv: (..., 2) UV coordinates.
102+
params: (..., 4) or (4) OpenCV distortion parameters.
103+
104+
Returns:
105+
(..., 2) distorted UV coordinates.
106+
"""
107+
assert params.shape[-1] == 4, f"Invalid params shape: {params.shape}"
108+
k1, k2, k3, k4 = torch.unbind(params, dim=-1)
109+
u, v = torch.unbind(uv, dim=-1)
110+
r = torch.sqrt(u * u + v * v)
111+
theta = torch.atan(r)
112+
theta_d = theta * (
113+
1
114+
+ k1 * theta**2
115+
+ k2 * theta**4
116+
+ k3 * theta**6
117+
+ k4 * theta**8
118+
)
119+
scale = theta_d / torch.clamp(r, min=eps)
120+
return uv * scale[..., None]
121+
122+
123+
@torch.jit.script
124+
def _compute_residual_and_jacobian(
125+
x: Tensor, y: Tensor, xd: Tensor, yd: Tensor, params: Tensor
126+
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
127+
assert params.shape[-1] == 8
128+
129+
k1, k2, p1, p2, k3, k4, k5, k6 = torch.unbind(params, dim=-1)
130+
131+
# let r(x, y) = x^2 + y^2;
132+
# alpha(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3;
133+
# beta(x, y) = 1 + k4 * r(x, y) + k5 * r(x, y) ^2 + k6 * r(x, y)^3;
134+
# d(x, y) = alpha(x, y) / beta(x, y);
135+
r = x * x + y * y
136+
alpha = 1.0 + r * (k1 + r * (k2 + r * k3))
137+
beta = 1.0 + r * (k4 + r * (k5 + r * k6))
138+
d = alpha / beta
139+
140+
# The perfect projection is:
141+
# xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2);
142+
# yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2);
143+
#
144+
# Let's define
145+
#
146+
# fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd;
147+
# fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd;
148+
#
149+
# We are looking for a solution that satisfies
150+
# fx(x, y) = fy(x, y) = 0;
151+
fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd
152+
fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd
153+
154+
# Compute derivative of alpha, beta over r.
155+
alpha_r = k1 + r * (2.0 * k2 + r * (3.0 * k3))
156+
beta_r = k4 + r * (2.0 * k5 + r * (3.0 * k6))
157+
158+
# Compute derivative of d over [x, y]
159+
d_r = (alpha_r * beta - alpha * beta_r) / (beta * beta)
160+
d_x = 2.0 * x * d_r
161+
d_y = 2.0 * y * d_r
162+
163+
# Compute derivative of fx over x and y.
164+
fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x
165+
fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y
166+
167+
# Compute derivative of fy over x and y.
168+
fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x
169+
fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y
170+
171+
return fx, fy, fx_x, fx_y, fy_x, fy_y
172+
173+
174+
@torch.jit.script
175+
def _opencv_lens_undistortion(
176+
uv: Tensor, params: Tensor, eps: float = 1e-6, iters: int = 10
177+
) -> Tensor:
178+
"""Same as opencv_lens_undistortion(), but native PyTorch.
179+
180+
Took from with bug fix and modification.
181+
https://github.com/nerfstudio-project/nerfstudio/blob/ec603634edbd61b13bdf2c598fda8c993370b8f7/nerfstudio/cameras/camera_utils.py
182+
"""
183+
assert uv.shape[-1] == 2
184+
assert params.shape[-1] in [0, 1, 2, 4, 8]
185+
186+
if params.shape[-1] == 0:
187+
return uv
188+
elif params.shape[-1] < 8:
189+
params = F.pad(params, (0, 8 - params.shape[-1]), "constant", 0.0)
190+
assert params.shape[-1] == 8
191+
192+
# Initialize from the distorted point.
193+
x, y = x0, y0 = torch.unbind(uv, dim=-1)
194+
195+
zeros = torch.zeros_like(x)
196+
for _ in range(iters):
197+
fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian(
198+
x=x, y=y, xd=x0, yd=y0, params=params
199+
)
200+
denominator = fy_x * fx_y - fx_x * fy_y
201+
mask = torch.abs(denominator) > eps
202+
203+
x_numerator = fx * fy_y - fy * fx_y
204+
y_numerator = fy * fx_x - fx * fy_x
205+
step_x = torch.where(mask, x_numerator / denominator, zeros)
206+
step_y = torch.where(mask, y_numerator / denominator, zeros)
207+
208+
x = x + step_x
209+
y = y + step_y
210+
211+
return torch.stack([x, y], dim=-1)

nerfacc/cameras2.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
"""
2+
Copyright (c) 2022 Ruilong Li, UC Berkeley.
3+
4+
Seems like both colmap and nerfstudio are based on OpenCV's camera model.
5+
6+
References:
7+
- nerfstudio: https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/cameras/cameras.py
8+
- opencv:
9+
- https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga69f2545a8b62a6b0fc2ee060dc30559d
10+
- https://docs.opencv.org/3.4/d9/d0c/group__calib3d.html
11+
- https://docs.opencv.org/4.x/db/d58/group__calib3d__fisheye.html
12+
- https://github.com/opencv/opencv/blob/master/modules/calib3d/src/fisheye.cpp#L321
13+
- https://github.com/opencv/opencv/blob/17234f82d025e3bbfbf611089637e5aa2038e7b8/modules/calib3d/src/distortion_model.hpp
14+
- https://github.com/opencv/opencv/blob/8d0fbc6a1e9f20c822921e8076551a01e58cd632/modules/calib3d/src/undistort.dispatch.cpp#L578
15+
- colmap: https://github.com/colmap/colmap/blob/dev/src/base/camera_models.h
16+
- calcam: https://euratom-software.github.io/calcam/html/intro_theory.html
17+
- blender:
18+
- https://docs.blender.org/manual/en/latest/render/cycles/object_settings/cameras.html#fisheye-lens-polynomial
19+
- https://github.com/blender/blender/blob/03cc3b94c94c38767802bccac4e9384ab704065a/intern/cycles/kernel/kernel_projection.h
20+
- lensfun: https://lensfun.github.io/manual/v0.3.2/annotated.html
21+
22+
- OpenCV and Blender has different fisheye camera models
23+
- https://stackoverflow.com/questions/73270140/pipeline-for-fisheye-distortion-and-undistortion-with-blender-and-opencv
24+
"""
25+
from typing import Literal, Optional, Tuple
26+
27+
import torch
28+
import torch.nn.functional as F
29+
from torch import Tensor
30+
31+
from . import cuda as _C
32+
33+
34+
def ray_directions_from_uvs(
35+
uvs: Tensor, # [..., 2]
36+
Ks: Tensor, # [..., 3, 3]
37+
params: Optional[Tensor] = None, # [..., M]
38+
) -> Tensor:
39+
"""Create ray directions from uvs and camera parameters in OpenCV format.
40+
41+
Args:
42+
uvs: UV coordinates on image plane. (In pixel unit)
43+
Ks: Camera intrinsics.
44+
params: Camera distortion parameters. See `opencv.undistortPoints` for details.
45+
46+
Returns:
47+
Normalized ray directions in camera space.
48+
"""
49+
u, v = torch.unbind(uvs + 0.5, dim=-1)
50+
fx, fy = Ks[..., 0, 0], Ks[..., 1, 1]
51+
cx, cy = Ks[..., 0, 2], Ks[..., 1, 2]
52+
53+
# undo intrinsics
54+
xys = torch.stack([(u - cx) / fx, (v - cy) / fy], dim=-1) # [..., 2]
55+
56+
# undo lens distortion
57+
if params is not None:
58+
M = params.shape[-1]
59+
60+
if M == 14: # undo tilt projection
61+
R, R_inv = opencv_tilt_projection_matrix(params[..., -2:])
62+
xys_homo = F.pad(xys, (0, 1), value=1.0) # [..., 3]
63+
xys_homo = torch.einsum(
64+
"...ij,...j->...i", R_inv, xys_homo
65+
) # [..., 3]
66+
xys = xys_homo[..., :2]
67+
homo = xys_homo[..., 2:]
68+
xys /= torch.where(homo != 0.0, homo, torch.ones_like(homo))
69+
70+
xys = opencv_lens_undistortion(xys, params) # [..., 2]
71+
72+
# normalized homogeneous coordinates
73+
dirs = F.pad(xys, (0, 1), value=1.0) # [..., 3]
74+
dirs = F.normalize(dirs, dim=-1) # [..., 3]
75+
return dirs
76+
77+
78+
def opencv_lens_undistortion(
79+
uv: Tensor, params: Tensor, eps: float = 1e-6, iters: int = 10
80+
) -> Tensor:
81+
"""Undistort the opencv distortion of {k1, k2, k3, k4, p1, p2}.
82+
83+
Note:
84+
This function is not differentiable to any inputs.
85+
86+
Args:
87+
uv: (..., 2) UV coordinates.
88+
params: (..., 6) or (6) OpenCV distortion parameters.
89+
90+
Returns:
91+
(..., 2) undistorted UV coordinates.
92+
"""
93+
assert uv.shape[-1] == 2
94+
assert params.shape[-1] == 6
95+
batch_shape = uv.shape[:-1]
96+
params = torch.broadcast_to(params, batch_shape + (6,))
97+
98+
return _C.opencv_lens_undistortion(
99+
uv.contiguous(), params.contiguous(), eps, iters
100+
)
101+
102+
103+
def opencv_tilt_projection_matrix(tau: Tensor) -> Tensor:
104+
"""Create a tilt projection matrix.
105+
106+
Reference:
107+
https://docs.opencv.org/3.4/d9/d0c/group__calib3d.html
108+
109+
Args:
110+
tau: (..., 2) tilt angles.
111+
112+
Returns:
113+
(..., 3, 3) tilt projection matrix.
114+
"""
115+
116+
cosx, cosy = torch.unbind(torch.cos(tau), -1)
117+
sinx, siny = torch.unbind(torch.sin(tau), -1)
118+
one = torch.ones_like(tau)
119+
zero = torch.zeros_like(tau)
120+
121+
Rx = torch.stack(
122+
[one, zero, zero, zero, cosx, sinx, zero, -sinx, cosx], -1
123+
).reshape(*tau.shape[:-1], 3, 3)
124+
Ry = torch.stack(
125+
[cosy, zero, -siny, zero, one, zero, siny, zero, cosy], -1
126+
).reshape(*tau.shape[:-1], 3, 3)
127+
Rxy = torch.matmul(Ry, Rx)
128+
Rz = torch.stack(
129+
[
130+
Rxy[..., 2, 2],
131+
zero,
132+
-Rxy[..., 0, 2],
133+
zero,
134+
Rxy[..., 2, 2],
135+
-Rxy[..., 1, 2],
136+
zero,
137+
zero,
138+
one,
139+
],
140+
-1,
141+
).reshape(*tau.shape[:-1], 3, 3)
142+
R = torch.matmul(Rz, Rxy)
143+
144+
inv = 1.0 / Rxy[..., 2, 2]
145+
Rz_inv = torch.stack(
146+
[
147+
inv,
148+
zero,
149+
inv * Rxy[..., 0, 2],
150+
zero,
151+
inv,
152+
inv * Rxy[..., 1, 2],
153+
zero,
154+
zero,
155+
one,
156+
],
157+
-1,
158+
).reshape(*tau.shape[:-1], 3, 3)
159+
R_inv = torch.matmul(Rxy.transpose(-1, -2), Rz_inv)
160+
return R, R_inv

nerfacc/cuda/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,9 @@ def call_cuda(*args, **kwargs):
3838
# pdf
3939
importance_sampling = _make_lazy_cuda_func("importance_sampling")
4040
searchsorted = _make_lazy_cuda_func("searchsorted")
41+
42+
# camera
43+
opencv_lens_undistortion = _make_lazy_cuda_func("opencv_lens_undistortion")
44+
opencv_lens_undistortion_fisheye = _make_lazy_cuda_func(
45+
"opencv_lens_undistortion_fisheye"
46+
)

0 commit comments

Comments
 (0)