Skip to content

Commit 82fd69c

Browse files
Updates on examples (#174)
* move torch data transfer into dataloader * Update README * use args.data_root * Remove redundant check * Fix isort * Fix black
1 parent 15330b4 commit 82fd69c

File tree

7 files changed

+54
-50
lines changed

7 files changed

+54
-50
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ optimizer.step()
132132

133133
## Examples:
134134

135-
Before running those example scripts, please check the script about which dataset it is needed, and download the dataset first.
135+
Before running those example scripts, please check the script about which dataset is needed, and download the dataset first. You could use `--data_root` to specify the path.
136136

137137
```bash
138138
# clone the repo with submodules.

examples/datasets/dnerf_synthetic.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def __init__(
8686
near: float = None,
8787
far: float = None,
8888
batch_over_images: bool = True,
89+
device: str = "cuda:0",
8990
):
9091
super().__init__()
9192
assert split in self.SPLITS, "%s" % split
@@ -106,18 +107,23 @@ def __init__(
106107
self.focal,
107108
self.timestamps,
108109
) = _load_renderings(root_fp, subject_id, split)
109-
self.images = torch.from_numpy(self.images).to(torch.uint8)
110-
self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
111-
self.timestamps = torch.from_numpy(self.timestamps).to(torch.float32)[
112-
:, None
113-
]
110+
self.images = torch.from_numpy(self.images).to(device).to(torch.uint8)
111+
self.camtoworlds = (
112+
torch.from_numpy(self.camtoworlds).to(device).to(torch.float32)
113+
)
114+
self.timestamps = (
115+
torch.from_numpy(self.timestamps)
116+
.to(device)
117+
.to(torch.float32)[:, None]
118+
)
114119
self.K = torch.tensor(
115120
[
116121
[self.focal, 0, self.WIDTH / 2.0],
117122
[0, self.focal, self.HEIGHT / 2.0],
118123
[0, 0, 1],
119124
],
120125
dtype=torch.float32,
126+
device=device,
121127
) # (3, 3)
122128
assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)
123129

examples/datasets/nerf_360_v2.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def __init__(
169169
far: float = None,
170170
batch_over_images: bool = True,
171171
factor: int = 1,
172+
device: str = "cuda:0",
172173
):
173174
super().__init__()
174175
assert split in self.SPLITS, "%s" % split
@@ -186,9 +187,11 @@ def __init__(
186187
self.images, self.camtoworlds, self.K = _load_colmap(
187188
root_fp, subject_id, split, factor
188189
)
189-
self.images = torch.from_numpy(self.images).to(torch.uint8)
190-
self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
191-
self.K = torch.tensor(self.K).to(torch.float32)
190+
self.images = torch.from_numpy(self.images).to(device).to(torch.uint8)
191+
self.camtoworlds = (
192+
torch.from_numpy(self.camtoworlds).to(device).to(torch.float32)
193+
)
194+
self.K = torch.tensor(self.K).to(device).to(torch.float32)
192195
self.height, self.width = self.images.shape[1:3]
193196

194197
def __len__(self):

examples/datasets/nerf_synthetic.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
near: float = None,
8080
far: float = None,
8181
batch_over_images: bool = True,
82+
device: str = "cuda:0",
8283
):
8384
super().__init__()
8485
assert split in self.SPLITS, "%s" % split
@@ -109,15 +110,18 @@ def __init__(
109110
self.images, self.camtoworlds, self.focal = _load_renderings(
110111
root_fp, subject_id, split
111112
)
112-
self.images = torch.from_numpy(self.images).to(torch.uint8)
113-
self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
113+
self.images = torch.from_numpy(self.images).to(device).to(torch.uint8)
114+
self.camtoworlds = (
115+
torch.from_numpy(self.camtoworlds).to(device).to(torch.float32)
116+
)
114117
self.K = torch.tensor(
115118
[
116119
[self.focal, 0, self.WIDTH / 2.0],
117120
[0, self.focal, self.HEIGHT / 2.0],
118121
[0, 0, 1],
119122
],
120123
dtype=torch.float32,
124+
device=device,
121125
) # (3, 3)
122126
assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)
123127

examples/train_mlp_dnerf.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import argparse
66
import math
7-
import os
7+
import pathlib
88
import time
99

1010
import imageio
@@ -24,6 +24,12 @@
2424
set_random_seed(42)
2525

2626
parser = argparse.ArgumentParser()
27+
parser.add_argument(
28+
"--data_root",
29+
type=str,
30+
default=str(pathlib.Path.cwd() / "data/dnerf"),
31+
help="the root dir of the dataset",
32+
)
2733
parser.add_argument(
2834
"--train_split",
2935
type=str,
@@ -91,31 +97,22 @@
9197
gamma=0.33,
9298
)
9399
# setup the dataset
94-
data_root_fp = "/home/ruilongli/data/dnerf/"
95100
target_sample_batch_size = 1 << 16
96101
grid_resolution = 128
97102

98103
train_dataset = SubjectLoader(
99104
subject_id=args.scene,
100-
root_fp=data_root_fp,
105+
root_fp=args.data_root,
101106
split=args.train_split,
102107
num_rays=target_sample_batch_size // render_n_samples,
103108
)
104-
train_dataset.images = train_dataset.images.to(device)
105-
train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
106-
train_dataset.K = train_dataset.K.to(device)
107-
train_dataset.timestamps = train_dataset.timestamps.to(device)
108109

109110
test_dataset = SubjectLoader(
110111
subject_id=args.scene,
111-
root_fp=data_root_fp,
112+
root_fp=args.data_root,
112113
split="test",
113114
num_rays=None,
114115
)
115-
test_dataset.images = test_dataset.images.to(device)
116-
test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
117-
test_dataset.K = test_dataset.K.to(device)
118-
test_dataset.timestamps = test_dataset.timestamps.to(device)
119116

120117
occupancy_grid = OccupancyGrid(
121118
roi_aabb=args.aabb,
@@ -191,7 +188,7 @@
191188
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
192189
)
193190

194-
if step >= 0 and step % max_steps == 0 and step > 0:
191+
if step > 0 and step % max_steps == 0:
195192
# evaluation
196193
radiance_field.eval()
197194

examples/train_mlp_nerf.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import argparse
66
import math
7-
import os
7+
import pathlib
88
import time
99

1010
import imageio
@@ -23,6 +23,12 @@
2323
set_random_seed(42)
2424

2525
parser = argparse.ArgumentParser()
26+
parser.add_argument(
27+
"--data_root",
28+
type=str,
29+
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
30+
help="the root dir of the dataset",
31+
)
2632
parser.add_argument(
2733
"--train_split",
2834
type=str,
@@ -112,40 +118,31 @@
112118
if args.scene == "garden":
113119
from datasets.nerf_360_v2 import SubjectLoader
114120

115-
data_root_fp = "/home/ruilongli/data/360_v2/"
116121
target_sample_batch_size = 1 << 16
117122
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
118123
test_dataset_kwargs = {"factor": 4}
119124
grid_resolution = 128
120125
else:
121126
from datasets.nerf_synthetic import SubjectLoader
122127

123-
data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
124128
target_sample_batch_size = 1 << 16
125129
grid_resolution = 128
126130

127131
train_dataset = SubjectLoader(
128132
subject_id=args.scene,
129-
root_fp=data_root_fp,
133+
root_fp=args.data_root,
130134
split=args.train_split,
131135
num_rays=target_sample_batch_size // render_n_samples,
132136
**train_dataset_kwargs,
133137
)
134138

135-
train_dataset.images = train_dataset.images.to(device)
136-
train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
137-
train_dataset.K = train_dataset.K.to(device)
138-
139139
test_dataset = SubjectLoader(
140140
subject_id=args.scene,
141-
root_fp=data_root_fp,
141+
root_fp=args.data_root,
142142
split="test",
143143
num_rays=None,
144144
**test_dataset_kwargs,
145145
)
146-
test_dataset.images = test_dataset.images.to(device)
147-
test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
148-
test_dataset.K = test_dataset.K.to(device)
149146

150147
occupancy_grid = OccupancyGrid(
151148
roi_aabb=args.aabb,
@@ -217,7 +214,7 @@
217214
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
218215
)
219216

220-
if step >= 0 and step % max_steps == 0 and step > 0:
217+
if step > 0 and step % max_steps == 0:
221218
# evaluation
222219
radiance_field.eval()
223220

examples/train_ngp_nerf.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import argparse
66
import math
7-
import os
7+
import pathlib
88
import time
99

1010
import imageio
@@ -23,6 +23,12 @@
2323
set_random_seed(42)
2424

2525
parser = argparse.ArgumentParser()
26+
parser.add_argument(
27+
"--data_root",
28+
type=str,
29+
default=str(pathlib.Path.cwd() / "data"),
30+
help="the root dir of the dataset",
31+
)
2632
parser.add_argument(
2733
"--train_split",
2834
type=str,
@@ -87,40 +93,31 @@
8793
if args.unbounded:
8894
from datasets.nerf_360_v2 import SubjectLoader
8995

90-
data_root_fp = "/home/ruilongli/data/360_v2/"
9196
target_sample_batch_size = 1 << 20
9297
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
9398
test_dataset_kwargs = {"factor": 4}
9499
grid_resolution = 256
95100
else:
96101
from datasets.nerf_synthetic import SubjectLoader
97102

98-
data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
99103
target_sample_batch_size = 1 << 18
100104
grid_resolution = 128
101105

102106
train_dataset = SubjectLoader(
103107
subject_id=args.scene,
104-
root_fp=data_root_fp,
108+
root_fp=args.data_root,
105109
split=args.train_split,
106110
num_rays=target_sample_batch_size // render_n_samples,
107111
**train_dataset_kwargs,
108112
)
109113

110-
train_dataset.images = train_dataset.images.to(device)
111-
train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
112-
train_dataset.K = train_dataset.K.to(device)
113-
114114
test_dataset = SubjectLoader(
115115
subject_id=args.scene,
116-
root_fp=data_root_fp,
116+
root_fp=args.data_root,
117117
split="test",
118118
num_rays=None,
119119
**test_dataset_kwargs,
120120
)
121-
test_dataset.images = test_dataset.images.to(device)
122-
test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
123-
test_dataset.K = test_dataset.K.to(device)
124121

125122
if args.auto_aabb:
126123
camera_locs = torch.cat(
@@ -260,7 +257,7 @@ def occ_eval_fn(x):
260257
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
261258
)
262259

263-
if step >= 0 and step % max_steps == 0 and step > 0:
260+
if step > 0 and step % max_steps == 0:
264261
# evaluation
265262
radiance_field.eval()
266263

0 commit comments

Comments
 (0)