diff --git a/projects/PTv3/Dockerfile b/projects/PTv3/Dockerfile
new file mode 100644
index 00000000..fac7fc1b
--- /dev/null
+++ b/projects/PTv3/Dockerfile
@@ -0,0 +1,29 @@
+ARG AWML_BASE_IMAGE="autoware-ml"
+FROM ${AWML_BASE_IMAGE}
+
+RUN python3 -m pip --no-cache-dir install \
+    SharedArray \
+    open3d \
+    spconv-cu120 \
+    flash-attn \
+    yapf==0.40.1 \
+    torch-geometric \
+    ftfy \
+    regex \
+    tqdm
+
+RUN python3 -m pip --no-cache-dir install \
+    git+https://github.com/openai/CLIP.git
+
+RUN conda install h5py pyyaml -c anaconda -y
+RUN conda install sharedarray tensorboard tensorboardx yapf addict einops scipy plyfile termcolor -c conda-forge -y
+RUN conda install pytorch-cluster pytorch-scatter pytorch-sparse -c pyg -y
+
+RUN python3 -m pip --no-cache-dir install --force-reinstall \
+    yapf==0.40.1
+
+# For Blackwell, as of 2025/05, some packages need to be compiled or use nightly versions
+# RUN python3 -m pip --no-cache-dir install git+https://github.com/rusty1s/pytorch_scatter.git@2.1.2 -vvv
+# RUN python3 -m pip --no-cache-dir install torch-geometric -f https://data.pyg.org/whl/nightly/torch-2.7.0%2Bcu128.html
+# ENV FLASH_ATTN_CUDA_ARCHS="120"
+# MAX_JOBS=4 python3 -m pip --no-cache-dir install flash-attn --no-build-isolation -vvv
diff --git a/projects/PTv3/LICENSE_Pointcept b/projects/PTv3/LICENSE_Pointcept
new file mode 100644
index 00000000..ee1fac1b
--- /dev/null
+++ b/projects/PTv3/LICENSE_Pointcept
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Pointcept
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/projects/PTv3/README.md b/projects/PTv3/README.md
new file mode 100644
index 00000000..70d96e28
--- /dev/null
+++ b/projects/PTv3/README.md
@@ -0,0 +1,77 @@
+# Point Transformer V3 (PTv3)
+
+PTv3 is a lidar segmentation model.
+AWML's implementation is a port of the [original code](https://github.com/Pointcept/Pointcept) that trims unused parts of the code base and adds support for T4dataset and ONNX export.
+
+## Summary
+
+- ROS package: [Link](https://github.com/autowarefoundation/autoware_universe/pull/10600)
+- Supported datasets
+  - [x] NuScenes
+  - [x] T4dataset
+- Other supported features
+  - [x] ONNX export & TensorRT inference
+
+## Results and models
+
+- TODO
+
+
+## Get started
+### 1. Setup
+
+- This project requires a different docker environment than most other projects. Build its dedicated image:
+
+```sh
+DOCKER_BUILDKIT=1 docker build -t autoware-ml-ptv3 -f projects/PTv3/Dockerfile . --progress=plain
+```
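+
+- (Optional) As a quick sanity check of the image, verify that the compiled CUDA extensions import correctly. The command below is only a suggested check, not part of the official workflow:
+
+```sh
+docker run --rm --gpus all autoware-ml-ptv3 \
+  python3 -c "import flash_attn, spconv, torch_scatter, torch_geometric; print('OK')"
+```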
+
+- Run docker
+
+```sh
+docker run -it --rm --gpus '"device=0"' --shm-size=64g --name awml -p 6006:6006 -v $PWD/:/workspace -v $PWD/data:/workspace/data autoware-ml-ptv3
+```
+
+### 2. Train
+
+To train the model, use the following commands:
+
+```sh
+cd projects/PTv3
+python tools/train.py --config-file configs/semseg-pt-v3m1-0-t4dataset.py --num-gpus 1
+```
+
+To test the model, use the following commands:
+
+```sh
+cd projects/PTv3
+python tools/test.py --config-file configs/semseg-pt-v3m1-0-t4dataset.py --num-gpus 1 \
+    --options \
+    save_path=data/experiment \
+    weight=exp/model/model_best.pth
+```
+
+### 3. Deployment
+
+To deploy the model, a modified version of spconv is required. To use it,
+please add `projects` to the `PYTHONPATH`:
+
+```sh
+export PYTHONPATH=${PYTHONPATH}:/workspace/projects
+```
+
+and then export the model:
+
+```sh
+cd projects/PTv3
+python tools/export.py --config-file configs/semseg-pt-v3m1-0-t4dataset.py --num-gpus 1 \
+    --options \
+    save_path=data/experiment \
+    weight=exp/model/model_best.pth
+```
+
+which will generate a file called `ptv3.onnx`.
+
+## Reference
+
+- [Pointcept's PTv3](https://github.com/Pointcept/Pointcept)
diff --git a/projects/PTv3/configs/_base_/default_runtime.py b/projects/PTv3/configs/_base_/default_runtime.py
new file mode 100644
index 00000000..1ec8bf17
--- /dev/null
+++ b/projects/PTv3/configs/_base_/default_runtime.py
@@ -0,0 +1,39 @@
+weight = None  # path to model weight
+resume = False  # whether to resume training process
+evaluate = True  # evaluate after each training epoch
+test_only = False  # test process
+
+seed = None  # train process will init a random seed and record it
+save_path = "exp/default"
+num_worker = 16  # total workers across all GPUs
+batch_size = 16  # total batch size across all GPUs
+batch_size_val = None  # auto adapt to bs 1 for each gpu
+batch_size_test = None  # auto adapt to bs 1 for each gpu
+epoch = 100  # total epochs, data loop = epoch // eval_epoch
+eval_epoch = 100  # total eval & checkpoint epochs for the scheduler
+clip_grad = None  # disable with None, enable with a float
+
+sync_bn = False
+enable_amp = False
+empty_cache = False
+empty_cache_per_epoch = False
+find_unused_parameters = False
+
+mix_prob = 0
+param_dicts = None  # example: param_dicts = [dict(keyword="block", lr_scale=0.1)]
+
+# hook
+hooks = [
+    dict(type="CheckpointLoader"),
+    dict(type="IterationTimer", warmup_iter=2),
+    dict(type="InformationWriter"),
+    dict(type="SemSegEvaluator"),
+    dict(type="CheckpointSaver", save_freq=None),
+    dict(type="PreciseEvaluator", test_last=False),
+]
+
+# Trainer
+train = dict(type="DefaultTrainer")
+
+# Tester
+test = dict(type="SemSegTester", verbose=True)
diff --git a/projects/PTv3/configs/semseg-pt-v3m1-0-base.py b/projects/PTv3/configs/semseg-pt-v3m1-0-base.py
new file mode 100644
index 00000000..8be9debf
--- /dev/null
+++ b/projects/PTv3/configs/semseg-pt-v3m1-0-base.py
@@ -0,0 +1,215 @@
+_base_ = ["./_base_/default_runtime.py"]
+
+# misc custom setting
+batch_size = 4  # total batch size across all GPUs
+mix_prob = 0.8
+empty_cache = False
+enable_amp = True
+
+# model settings
+model = dict(
+    type="DefaultSegmentorV2",
+    num_classes=16,
+    backbone_out_channels=64,
+    backbone=dict(
+        type="PT-v3m1",
+        in_channels=4,
+        order=["z", "z-trans", "hilbert", "hilbert-trans"],
+        stride=(2, 2, 2, 2),
+        enc_depths=(2, 2, 2, 6, 2),
+        enc_channels=(32, 64, 128, 256, 512),
+        enc_num_head=(2, 4, 8, 16, 32),
+        enc_patch_size=(1024, 1024, 1024, 1024, 1024),
+        dec_depths=(2, 2, 2, 2),
+        dec_channels=(64,
64, 128, 256), + dec_num_head=(4, 4, 8, 16), + dec_patch_size=(1024, 1024, 1024, 1024), + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.3, + shuffle_orders=True, + pre_norm=True, + enable_rpe=False, + enable_flash=True, + upcast_attention=False, + upcast_softmax=False, + cls_mode=False, + pdnorm_bn=False, + pdnorm_ln=False, + pdnorm_decouple=True, + pdnorm_adaptive=False, + pdnorm_affine=True, + pdnorm_conditions=("nuScenes", "SemanticKITTI", "Waymo"), + ), + criteria=[ + dict(type="CrossEntropyLoss", loss_weight=1.0, ignore_index=-1), + dict(type="LovaszLoss", mode="multiclass", loss_weight=1.0, ignore_index=-1), + ], +) + +# scheduler settings +epoch = 50 +eval_epoch = 50 +optimizer = dict(type="AdamW", lr=0.002, weight_decay=0.005) +scheduler = dict( + type="OneCycleLR", + max_lr=[0.002, 0.0002], + pct_start=0.04, + anneal_strategy="cos", + div_factor=10.0, + final_div_factor=100.0, +) +param_dicts = [dict(keyword="block", lr=0.0002)] + +# dataset settings +dataset_type = "NuScenesDataset" +data_root = "data/nuscenes" +ignore_index = -1 +names = [ + "barrier", + "bicycle", + "bus", + "car", + "construction_vehicle", + "motorcycle", + "pedestrian", + "traffic_cone", + "trailer", + "truck", + "driveable_surface", + "other_flat", + "sidewalk", + "terrain", + "manmade", + "vegetation", +] + +data = dict( + num_classes=16, + ignore_index=ignore_index, + names=names, + train=dict( + type=dataset_type, + split="train", + data_root=data_root, + transform=[ + # dict(type="RandomDropout", dropout_ratio=0.2, dropout_application_ratio=0.2), + # dict(type="RandomRotateTargetAngle", angle=(1/2, 1, 3/2), center=[0, 0, 0], axis="z", p=0.75), + dict(type="RandomRotate", angle=[-1, 1], axis="z", center=[0, 0, 0], p=0.5), + # dict(type="RandomRotate", angle=[-1/6, 1/6], axis="x", p=0.5), + # dict(type="RandomRotate", angle=[-1/6, 1/6], axis="y", p=0.5), + dict(type="RandomScale", scale=[0.9, 1.1]), + # dict(type="RandomShift", shift=[0.2, 0.2, 0.2]), + dict(type="RandomFlip", p=0.5), + dict(type="RandomJitter", sigma=0.005, clip=0.02), + # dict(type="ElasticDistortion", distortion_params=[[0.2, 0.4], [0.8, 1.6]]), + dict( + type="GridSample", + grid_size=0.05, + hash_type="fnv", + mode="train", + keys=("coord", "strength", "segment"), + return_grid_coord=True, + ), + # dict(type="SphereCrop", point_max=1000000, mode="random"), + # dict(type="CenterShift", apply_z=False), + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "segment"), + feat_keys=("coord", "strength"), + ), + ], + test_mode=False, + ignore_index=ignore_index, + ), + val=dict( + type=dataset_type, + split="val", + data_root=data_root, + transform=[ + # dict(type="PointClip", point_cloud_range=(-51.2, -51.2, -4, 51.2, 51.2, 2.4)), + dict( + type="GridSample", + grid_size=0.05, + hash_type="fnv", + mode="train", + keys=("coord", "strength", "segment"), + return_grid_coord=True, + ), + # dict(type="SphereCrop", point_max=1000000, mode='center'), + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "segment"), + feat_keys=("coord", "strength"), + ), + ], + test_mode=False, + ignore_index=ignore_index, + ), + test=dict( + type=dataset_type, + split="val", + data_root=data_root, + transform=[ + dict(type="Copy", keys_dict={"segment": "origin_segment"}), + dict( + type="GridSample", + grid_size=0.025, + hash_type="fnv", + mode="train", + keys=("coord", "strength", "segment"), + return_inverse=True, + ), + ], + test_mode=True, + 
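+        # Test-time pipeline: the scan is voxelized into fragments once per
+        # aug_transform entry below, and per-point predictions are accumulated
+        # over all fragments by the tester (test = dict(type="SemSegTester")).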
test_cfg=dict( + voxelize=dict( + type="GridSample", + grid_size=0.05, + hash_type="fnv", + mode="test", + return_grid_coord=True, + keys=("coord", "strength"), + ), + crop=None, + post_transform=[ + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "index"), + feat_keys=("coord", "strength"), + ), + ], + aug_transform=[ + [dict(type="RandomScale", scale=[0.9, 0.9])], + [dict(type="RandomScale", scale=[0.95, 0.95])], + [dict(type="RandomScale", scale=[1, 1])], + [dict(type="RandomScale", scale=[1.05, 1.05])], + [dict(type="RandomScale", scale=[1.1, 1.1])], + [ + dict(type="RandomScale", scale=[0.9, 0.9]), + dict(type="RandomFlip", p=1), + ], + [ + dict(type="RandomScale", scale=[0.95, 0.95]), + dict(type="RandomFlip", p=1), + ], + [dict(type="RandomScale", scale=[1, 1]), dict(type="RandomFlip", p=1)], + [ + dict(type="RandomScale", scale=[1.05, 1.05]), + dict(type="RandomFlip", p=1), + ], + [ + dict(type="RandomScale", scale=[1.1, 1.1]), + dict(type="RandomFlip", p=1), + ], + ], + ), + ignore_index=ignore_index, + ), +) diff --git a/projects/PTv3/configs/semseg-pt-v3m1-0-t4dataset.py b/projects/PTv3/configs/semseg-pt-v3m1-0-t4dataset.py new file mode 100644 index 00000000..65f95d36 --- /dev/null +++ b/projects/PTv3/configs/semseg-pt-v3m1-0-t4dataset.py @@ -0,0 +1,219 @@ +_base_ = ["./_base_/default_runtime.py"] + +# misc custom setting +batch_size = 1 # bs: total bs in all gpus +num_worker = 16 # total worker in all gpu +mix_prob = 0.8 +empty_cache = False +enable_amp = True + +grid_size = 0.1 # original is 0.05 +num_classes = 6 + +point_cloud_range = [-76.8, -76.8, -4, 76.8, 76.8, 8] + +# model settings +model = dict( + type="DefaultSegmentorV2", + num_classes=num_classes, + backbone_out_channels=64, + backbone=dict( + type="PT-v3m1", + in_channels=4, + order=["z", "z-trans", "hilbert", "hilbert-trans"], + stride=(2, 2, 2, 2), + enc_depths=(2, 2, 2, 6, 2), + enc_channels=(32, 64, 128, 256, 512), + enc_num_head=(2, 4, 8, 16, 32), + enc_patch_size=(1024, 1024, 1024, 1024, 1024), + dec_depths=(2, 2, 2, 2), + dec_channels=(64, 64, 128, 256), + dec_num_head=(4, 4, 8, 16), + dec_patch_size=(1024, 1024, 1024, 1024), + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.3, + shuffle_orders=True, + pre_norm=True, + enable_rpe=False, + enable_flash=True, + upcast_attention=False, + upcast_softmax=False, + cls_mode=False, + pdnorm_bn=False, + pdnorm_ln=False, + pdnorm_decouple=True, + pdnorm_adaptive=False, + pdnorm_affine=True, + pdnorm_conditions=("nuScenes", "SemanticKITTI", "Waymo"), + ), + criteria=[ + dict(type="CrossEntropyLoss", loss_weight=1.0, ignore_index=-1), + dict(type="LovaszLoss", mode="multiclass", loss_weight=1.0, ignore_index=-1), + ], +) + +# scheduler settings +epoch = 50 +eval_epoch = 50 +optimizer = dict(type="AdamW", lr=0.002, weight_decay=0.005) +scheduler = dict( + type="OneCycleLR", + max_lr=[0.002, 0.0002], + pct_start=0.04, + anneal_strategy="cos", + div_factor=10.0, + final_div_factor=100.0, +) +param_dicts = [dict(keyword="block", lr=0.0002)] + +# dataset settings +dataset_type = "T4Dataset" +data_root = "data/t4dataset" +ignore_index = -1 +names = [ + "vehicle", + "bicycle", + "pedestrian", + "road", + "vegetation", + "obstacle", +] + +data = dict( + num_classes=num_classes, + ignore_index=ignore_index, + names=names, + train=dict( + type=dataset_type, + split="train", + data_root=data_root, + transform=[ + # dict(type="RandomDropout", dropout_ratio=0.2, dropout_application_ratio=0.2), + # 
dict(type="RandomRotateTargetAngle", angle=(1/2, 1, 3/2), center=[0, 0, 0], axis="z", p=0.75), + dict(type="RandomRotate", angle=[-1, 1], axis="z", center=[0, 0, 0], p=0.5), + # dict(type="RandomRotate", angle=[-1/6, 1/6], axis="x", p=0.5), + # dict(type="RandomRotate", angle=[-1/6, 1/6], axis="y", p=0.5), + dict(type="RandomScale", scale=[0.9, 1.1]), + dict( + type="PointClip", + point_cloud_range=point_cloud_range, + ), + # dict(type="RandomShift", shift=[0.2, 0.2, 0.2]), + dict(type="RandomFlip", p=0.5), + dict(type="RandomJitter", sigma=0.005, clip=0.02), + # dict(type="ElasticDistortion", distortion_params=[[0.2, 0.4], [0.8, 1.6]]), + dict( + type="GridSample", + grid_size=grid_size, + hash_type="fnv", + mode="train", + keys=("coord", "strength", "segment"), + return_grid_coord=True, + ), + dict(type="SphereCrop", point_max=128000, mode="random"), + # dict(type="CenterShift", apply_z=False), + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "segment"), + feat_keys=("coord", "strength"), + ), + ], + test_mode=False, + ignore_index=ignore_index, + ), + val=dict( + type=dataset_type, + split="val", + data_root=data_root, + transform=[ + # dict(type="PointClip", point_cloud_range=(-51.2, -51.2, -4, 51.2, 51.2, 2.4)), + dict( + type="PointClip", + point_cloud_range=point_cloud_range, + ), + dict( + type="GridSample", + grid_size=grid_size, + hash_type="fnv", + mode="train", + keys=("coord", "strength", "segment"), + return_grid_coord=True, + ), + # dict(type="SphereCrop", point_max=1000000, mode='center'), + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "segment"), + feat_keys=("coord", "strength"), + ), + ], + test_mode=False, + ignore_index=ignore_index, + ), + test=dict( + type=dataset_type, + split="val", + data_root=data_root, + transform=[ + dict(type="Copy", keys_dict={"segment": "origin_segment"}), + dict( + type="GridSample", + grid_size=grid_size, + hash_type="fnv", + mode="train", + keys=("coord", "strength", "segment"), + return_inverse=True, + ), + ], + test_mode=True, + test_cfg=dict( + voxelize=dict( + type="GridSample", + grid_size=grid_size, + hash_type="fnv", + mode="test", + return_grid_coord=True, + keys=("coord", "strength"), + ), + crop=None, + post_transform=[ + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "index"), + feat_keys=("coord", "strength"), + ), + ], + aug_transform=[ + [dict(type="RandomScale", scale=[0.9, 0.9])], + [dict(type="RandomScale", scale=[0.95, 0.95])], + [dict(type="RandomScale", scale=[1, 1])], + [dict(type="RandomScale", scale=[1.05, 1.05])], + [dict(type="RandomScale", scale=[1.1, 1.1])], + [ + dict(type="RandomScale", scale=[0.9, 0.9]), + dict(type="RandomFlip", p=1), + ], + [ + dict(type="RandomScale", scale=[0.95, 0.95]), + dict(type="RandomFlip", p=1), + ], + [dict(type="RandomScale", scale=[1, 1]), dict(type="RandomFlip", p=1)], + [ + dict(type="RandomScale", scale=[1.05, 1.05]), + dict(type="RandomFlip", p=1), + ], + [ + dict(type="RandomScale", scale=[1.1, 1.1]), + dict(type="RandomFlip", p=1), + ], + ], + ), + ignore_index=ignore_index, + ), +) diff --git a/projects/PTv3/data b/projects/PTv3/data new file mode 120000 index 00000000..a44ed8fa --- /dev/null +++ b/projects/PTv3/data @@ -0,0 +1 @@ +../Pointcept/data \ No newline at end of file diff --git a/projects/PTv3/datasets/__init__.py b/projects/PTv3/datasets/__init__.py new file mode 100644 index 00000000..ddaf5047 --- /dev/null +++ b/projects/PTv3/datasets/__init__.py @@ -0,0 +1,10 
@@ +from .builder import build_dataset + +# dataloader +from .dataloader import MultiDatasetDataloader +from .defaults import ConcatDataset, DefaultDataset + +# outdoor scene +from .nuscenes import NuScenesDataset +from .t4dataset import T4Dataset +from .utils import collate_fn, point_collate_fn diff --git a/projects/PTv3/datasets/builder.py b/projects/PTv3/datasets/builder.py new file mode 100644 index 00000000..4c609d54 --- /dev/null +++ b/projects/PTv3/datasets/builder.py @@ -0,0 +1,15 @@ +""" +Dataset Builder + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +from utils.registry import Registry + +DATASETS = Registry("datasets") + + +def build_dataset(cfg): + """Build datasets.""" + return DATASETS.build(cfg) diff --git a/projects/PTv3/datasets/dataloader.py b/projects/PTv3/datasets/dataloader.py new file mode 100644 index 00000000..5becc4b6 --- /dev/null +++ b/projects/PTv3/datasets/dataloader.py @@ -0,0 +1,104 @@ +import weakref +from functools import partial + +import torch +import torch.utils.data +import utils.comm as comm +from datasets import ConcatDataset +from datasets.utils import point_collate_fn +from utils.env import set_seed + + +class MultiDatasetDummySampler: + def __init__(self): + self.dataloader = None + + def set_epoch(self, epoch): + if comm.get_world_size() > 1: + for dataloader in self.dataloader.dataloaders: + dataloader.sampler.set_epoch(epoch) + return + + +class MultiDatasetDataloader: + """ + Multiple Datasets Dataloader, batch data from a same dataset and mix up ratio determined by loop of each sub dataset. + The overall length is determined by the main dataset (first) and loop of concat dataset. + """ + + def __init__( + self, + concat_dataset: ConcatDataset, + batch_size_per_gpu: int, + num_worker_per_gpu: int, + mix_prob=0, + seed=None, + ): + self.datasets = concat_dataset.datasets + self.ratios = [dataset.loop for dataset in self.datasets] + # reset data loop, original loop serve as ratios + for dataset in self.datasets: + dataset.loop = 1 + # determine union training epoch by main dataset + self.datasets[0].loop = concat_dataset.loop + # build sub-dataloaders + num_workers = num_worker_per_gpu // len(self.datasets) + self.dataloaders = [] + for dataset_id, dataset in enumerate(self.datasets): + if comm.get_world_size() > 1: + sampler = torch.utils.data.distributed.DistributedSampler(dataset) + else: + sampler = None + + init_fn = ( + partial( + self._worker_init_fn, + dataset_id=dataset_id, + num_workers=num_workers, + num_datasets=len(self.datasets), + rank=comm.get_rank(), + seed=seed, + ) + if seed is not None + else None + ) + self.dataloaders.append( + torch.utils.data.DataLoader( + dataset, + batch_size=batch_size_per_gpu, + shuffle=(sampler is None), + num_workers=num_worker_per_gpu, + sampler=sampler, + collate_fn=partial(point_collate_fn, mix_prob=mix_prob), + pin_memory=True, + worker_init_fn=init_fn, + drop_last=True, + persistent_workers=True, + ) + ) + self.sampler = MultiDatasetDummySampler() + self.sampler.dataloader = weakref.proxy(self) + + def __iter__(self): + iterator = [iter(dataloader) for dataloader in self.dataloaders] + while True: + for i in range(len(self.ratios)): + for _ in range(self.ratios[i]): + try: + batch = next(iterator[i]) + except StopIteration: + if i == 0: + return + else: + iterator[i] = iter(self.dataloaders[i]) + batch = next(iterator[i]) + yield batch + + def __len__(self): + main_data_loader_length = len(self.dataloaders[0]) + return 
main_data_loader_length // self.ratios[0] * sum(self.ratios) + main_data_loader_length % self.ratios[0] + + @staticmethod + def _worker_init_fn(worker_id, num_workers, dataset_id, num_datasets, rank, seed): + worker_seed = num_workers * num_datasets * rank + num_workers * dataset_id + worker_id + seed + set_seed(worker_seed) diff --git a/projects/PTv3/datasets/defaults.py b/projects/PTv3/datasets/defaults.py new file mode 100644 index 00000000..776de283 --- /dev/null +++ b/projects/PTv3/datasets/defaults.py @@ -0,0 +1,194 @@ +""" +Default Datasets + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import glob +import os +from collections.abc import Sequence +from copy import deepcopy + +import numpy as np +import torch +from torch.utils.data import Dataset +from utils.cache import shared_dict +from utils.logger import get_root_logger + +from .builder import DATASETS, build_dataset +from .transform import TRANSFORMS, Compose + + +@DATASETS.register_module() +class DefaultDataset(Dataset): + VALID_ASSETS = [ + "coord", + "color", + "normal", + "strength", + "segment", + "instance", + "pose", + ] + + def __init__( + self, + split="train", + data_root="data/dataset", + transform=None, + test_mode=False, + test_cfg=None, + cache=False, + ignore_index=-1, + loop=1, + ): + super(DefaultDataset, self).__init__() + self.data_root = data_root + self.split = split + self.transform = Compose(transform) + self.cache = cache + self.ignore_index = ignore_index + self.loop = loop if not test_mode else 1 # force make loop = 1 while in test mode + self.test_mode = test_mode + self.test_cfg = test_cfg if test_mode else None + + if test_mode: + self.test_voxelize = TRANSFORMS.build(self.test_cfg.voxelize) + self.test_crop = TRANSFORMS.build(self.test_cfg.crop) if self.test_cfg.crop else None + self.post_transform = Compose(self.test_cfg.post_transform) + self.aug_transform = [Compose(aug) for aug in self.test_cfg.aug_transform] + + self.data_list = self.get_data_list() + logger = get_root_logger() + logger.info("Totally {} x {} samples in {} set.".format(len(self.data_list), self.loop, split)) + + def get_data_list(self): + if isinstance(self.split, str): + data_list = glob.glob(os.path.join(self.data_root, self.split, "*")) + elif isinstance(self.split, Sequence): + data_list = [] + for split in self.split: + data_list += glob.glob(os.path.join(self.data_root, split, "*")) + else: + raise NotImplementedError + return data_list + + def get_data(self, idx): + data_path = self.data_list[idx % len(self.data_list)] + name = self.get_data_name(idx) + if self.cache: + cache_name = f"pointcept-{name}" + return shared_dict(cache_name) + + data_dict = {} + assets = os.listdir(data_path) + for asset in assets: + if not asset.endswith(".npy"): + continue + if asset[:-4] not in self.VALID_ASSETS: + continue + data_dict[asset[:-4]] = np.load(os.path.join(data_path, asset)) + data_dict["name"] = name + + if "coord" in data_dict.keys(): + data_dict["coord"] = data_dict["coord"].astype(np.float32) + + if "color" in data_dict.keys(): + data_dict["color"] = data_dict["color"].astype(np.float32) + + if "normal" in data_dict.keys(): + data_dict["normal"] = data_dict["normal"].astype(np.float32) + + if "segment" in data_dict.keys(): + data_dict["segment"] = data_dict["segment"].reshape([-1]).astype(np.int32) + else: + data_dict["segment"] = np.ones(data_dict["coord"].shape[0], dtype=np.int32) * -1 + + if "instance" in data_dict.keys(): + data_dict["instance"] = 
data_dict["instance"].reshape([-1]).astype(np.int32) + else: + data_dict["instance"] = np.ones(data_dict["coord"].shape[0], dtype=np.int32) * -1 + return data_dict + + def get_data_name(self, idx): + return os.path.basename(self.data_list[idx % len(self.data_list)]) + + def prepare_train_data(self, idx): + # load data + data_dict = self.get_data(idx) + data_dict = self.transform(data_dict) + return data_dict + + def prepare_test_data(self, idx): + # load data + data_dict = self.get_data(idx) + data_dict = self.transform(data_dict) + result_dict = dict(segment=data_dict.pop("segment"), name=data_dict.pop("name")) + if "origin_segment" in data_dict: + assert "inverse" in data_dict + result_dict["origin_segment"] = data_dict.pop("origin_segment") + result_dict["inverse"] = data_dict.pop("inverse") + + data_dict_list = [] + for aug in self.aug_transform: + data_dict_list.append(aug(deepcopy(data_dict))) + + fragment_list = [] + for data in data_dict_list: + if self.test_voxelize is not None: + data_part_list = self.test_voxelize(data) + else: + data["index"] = np.arange(data["coord"].shape[0]) + data_part_list = [data] + for data_part in data_part_list: + if self.test_crop is not None: + data_part = self.test_crop(data_part) + else: + data_part = [data_part] + fragment_list += data_part + + for i in range(len(fragment_list)): + fragment_list[i] = self.post_transform(fragment_list[i]) + result_dict["fragment_list"] = fragment_list + return result_dict + + def __getitem__(self, idx): + if self.test_mode: + return self.prepare_test_data(idx) + else: + return self.prepare_train_data(idx) + + def __len__(self): + return len(self.data_list) * self.loop + + +@DATASETS.register_module() +class ConcatDataset(Dataset): + def __init__(self, datasets, loop=1): + super(ConcatDataset, self).__init__() + self.datasets = [build_dataset(dataset) for dataset in datasets] + self.loop = loop + self.data_list = self.get_data_list() + logger = get_root_logger() + logger.info("Totally {} x {} samples in the concat set.".format(len(self.data_list), self.loop)) + + def get_data_list(self): + data_list = [] + for i in range(len(self.datasets)): + data_list.extend(zip(np.ones(len(self.datasets[i])) * i, np.arange(len(self.datasets[i])))) + return data_list + + def get_data(self, idx): + dataset_idx, data_idx = self.data_list[idx % len(self.data_list)] + return self.datasets[dataset_idx][data_idx] + + def get_data_name(self, idx): + dataset_idx, data_idx = self.data_list[idx % len(self.data_list)] + return self.datasets[dataset_idx].get_data_name(data_idx) + + def __getitem__(self, idx): + return self.get_data(idx) + + def __len__(self): + return len(self.data_list) * self.loop diff --git a/projects/PTv3/datasets/nuscenes.py b/projects/PTv3/datasets/nuscenes.py new file mode 100644 index 00000000..a0fc41c8 --- /dev/null +++ b/projects/PTv3/datasets/nuscenes.py @@ -0,0 +1,112 @@ +""" +nuScenes Dataset + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Zheng Zhang +Please cite our work if the code is helpful to you. 
+""" + +import os +import pickle +from collections.abc import Sequence + +import numpy as np + +from .builder import DATASETS +from .defaults import DefaultDataset + + +@DATASETS.register_module() +class NuScenesDataset(DefaultDataset): + def __init__(self, sweeps=10, ignore_index=-1, **kwargs): + self.sweeps = sweeps + self.ignore_index = ignore_index + self.learning_map = self.get_learning_map(ignore_index) + super().__init__(ignore_index=ignore_index, **kwargs) + + def get_info_path(self, split): + assert split in ["train", "val", "test"] + if split == "train": + return os.path.join(self.data_root, "info", f"nuscenes_infos_{self.sweeps}sweeps_train.pkl") + elif split == "val": + return os.path.join(self.data_root, "info", f"nuscenes_infos_{self.sweeps}sweeps_val.pkl") + elif split == "test": + return os.path.join(self.data_root, "info", f"nuscenes_infos_{self.sweeps}sweeps_test.pkl") + else: + raise NotImplementedError + + def get_data_list(self): + if isinstance(self.split, str): + info_paths = [self.get_info_path(self.split)] + elif isinstance(self.split, Sequence): + info_paths = [self.get_info_path(s) for s in self.split] + else: + raise NotImplementedError + data_list = [] + for info_path in info_paths: + with open(info_path, "rb") as f: + info = pickle.load(f) + data_list.extend(info) + return data_list + + def get_data(self, idx): + data = self.data_list[idx % len(self.data_list)] + lidar_path = os.path.join(self.data_root, "raw", data["lidar_path"]) + points = np.fromfile(str(lidar_path), dtype=np.float32, count=-1).reshape([-1, 5]) + coord = points[:, :3] + strength = points[:, 3].reshape([-1, 1]) / 255 # scale strength to [0, 1] + + if "gt_segment_path" in data.keys(): + gt_segment_path = os.path.join(self.data_root, "raw", data["gt_segment_path"]) + segment = np.fromfile(str(gt_segment_path), dtype=np.uint8, count=-1).reshape([-1]) + segment = np.vectorize(self.learning_map.__getitem__)(segment).astype(np.int64) + else: + segment = np.ones((points.shape[0],), dtype=np.int64) * self.ignore_index + data_dict = dict( + coord=coord, + strength=strength, + segment=segment, + name=self.get_data_name(idx), + ) + return data_dict + + def get_data_name(self, idx): + # return data name for lidar seg, optimize the code when need to support detection + return self.data_list[idx % len(self.data_list)]["lidar_token"] + + @staticmethod + def get_learning_map(ignore_index): + learning_map = { + 0: ignore_index, + 1: ignore_index, + 2: 6, + 3: 6, + 4: 6, + 5: ignore_index, + 6: 6, + 7: ignore_index, + 8: ignore_index, + 9: 0, + 10: ignore_index, + 11: ignore_index, + 12: 7, + 13: ignore_index, + 14: 1, + 15: 2, + 16: 2, + 17: 3, + 18: 4, + 19: ignore_index, + 20: ignore_index, + 21: 5, + 22: 8, + 23: 9, + 24: 10, + 25: 11, + 26: 12, + 27: 13, + 28: 14, + 29: ignore_index, + 30: 15, + 31: ignore_index, + } + return learning_map diff --git a/projects/PTv3/datasets/preprocessing/nuscenes/preprocess_nuscenes_info.py b/projects/PTv3/datasets/preprocessing/nuscenes/preprocess_nuscenes_info.py new file mode 100644 index 00000000..86a2bc5d --- /dev/null +++ b/projects/PTv3/datasets/preprocessing/nuscenes/preprocess_nuscenes_info.py @@ -0,0 +1,560 @@ +""" +Preprocessing Script for nuScenes Informantion +modified from OpenPCDet (https://github.com/open-mmlab/OpenPCDet) + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. 
+""" + +import argparse +import os +import pickle +from functools import reduce +from pathlib import Path + +import numpy as np +import tqdm +from nuscenes.nuscenes import NuScenes +from nuscenes.utils import splits +from nuscenes.utils.geometry_utils import transform_matrix +from pyquaternion import Quaternion + +map_name_from_general_to_detection = { + "human.pedestrian.adult": "pedestrian", + "human.pedestrian.child": "pedestrian", + "human.pedestrian.wheelchair": "ignore", + "human.pedestrian.stroller": "ignore", + "human.pedestrian.personal_mobility": "ignore", + "human.pedestrian.police_officer": "pedestrian", + "human.pedestrian.construction_worker": "pedestrian", + "animal": "ignore", + "vehicle.car": "car", + "vehicle.motorcycle": "motorcycle", + "vehicle.bicycle": "bicycle", + "vehicle.bus.bendy": "bus", + "vehicle.bus.rigid": "bus", + "vehicle.truck": "truck", + "vehicle.construction": "construction_vehicle", + "vehicle.emergency.ambulance": "ignore", + "vehicle.emergency.police": "ignore", + "vehicle.trailer": "trailer", + "movable_object.barrier": "barrier", + "movable_object.trafficcone": "traffic_cone", + "movable_object.pushable_pullable": "ignore", + "movable_object.debris": "ignore", + "static_object.bicycle_rack": "ignore", +} + + +cls_attr_dist = { + "barrier": { + "cycle.with_rider": 0, + "cycle.without_rider": 0, + "pedestrian.moving": 0, + "pedestrian.sitting_lying_down": 0, + "pedestrian.standing": 0, + "vehicle.moving": 0, + "vehicle.parked": 0, + "vehicle.stopped": 0, + }, + "bicycle": { + "cycle.with_rider": 2791, + "cycle.without_rider": 8946, + "pedestrian.moving": 0, + "pedestrian.sitting_lying_down": 0, + "pedestrian.standing": 0, + "vehicle.moving": 0, + "vehicle.parked": 0, + "vehicle.stopped": 0, + }, + "bus": { + "cycle.with_rider": 0, + "cycle.without_rider": 0, + "pedestrian.moving": 0, + "pedestrian.sitting_lying_down": 0, + "pedestrian.standing": 0, + "vehicle.moving": 9092, + "vehicle.parked": 3294, + "vehicle.stopped": 3881, + }, + "car": { + "cycle.with_rider": 0, + "cycle.without_rider": 0, + "pedestrian.moving": 0, + "pedestrian.sitting_lying_down": 0, + "pedestrian.standing": 0, + "vehicle.moving": 114304, + "vehicle.parked": 330133, + "vehicle.stopped": 46898, + }, + "construction_vehicle": { + "cycle.with_rider": 0, + "cycle.without_rider": 0, + "pedestrian.moving": 0, + "pedestrian.sitting_lying_down": 0, + "pedestrian.standing": 0, + "vehicle.moving": 882, + "vehicle.parked": 11549, + "vehicle.stopped": 2102, + }, + "ignore": { + "cycle.with_rider": 307, + "cycle.without_rider": 73, + "pedestrian.moving": 0, + "pedestrian.sitting_lying_down": 0, + "pedestrian.standing": 0, + "vehicle.moving": 165, + "vehicle.parked": 400, + "vehicle.stopped": 102, + }, + "motorcycle": { + "cycle.with_rider": 4233, + "cycle.without_rider": 8326, + "pedestrian.moving": 0, + "pedestrian.sitting_lying_down": 0, + "pedestrian.standing": 0, + "vehicle.moving": 0, + "vehicle.parked": 0, + "vehicle.stopped": 0, + }, + "pedestrian": { + "cycle.with_rider": 0, + "cycle.without_rider": 0, + "pedestrian.moving": 157444, + "pedestrian.sitting_lying_down": 13939, + "pedestrian.standing": 46530, + "vehicle.moving": 0, + "vehicle.parked": 0, + "vehicle.stopped": 0, + }, + "traffic_cone": { + "cycle.with_rider": 0, + "cycle.without_rider": 0, + "pedestrian.moving": 0, + "pedestrian.sitting_lying_down": 0, + "pedestrian.standing": 0, + "vehicle.moving": 0, + "vehicle.parked": 0, + "vehicle.stopped": 0, + }, + "trailer": { + "cycle.with_rider": 0, + "cycle.without_rider": 0, + 
"pedestrian.moving": 0, + "pedestrian.sitting_lying_down": 0, + "pedestrian.standing": 0, + "vehicle.moving": 3421, + "vehicle.parked": 19224, + "vehicle.stopped": 1895, + }, + "truck": { + "cycle.with_rider": 0, + "cycle.without_rider": 0, + "pedestrian.moving": 0, + "pedestrian.sitting_lying_down": 0, + "pedestrian.standing": 0, + "vehicle.moving": 21339, + "vehicle.parked": 55626, + "vehicle.stopped": 11097, + }, +} + + +def get_available_scenes(nusc): + available_scenes = [] + for scene in nusc.scene: + scene_token = scene["token"] + scene_rec = nusc.get("scene", scene_token) + sample_rec = nusc.get("sample", scene_rec["first_sample_token"]) + sd_rec = nusc.get("sample_data", sample_rec["data"]["LIDAR_TOP"]) + has_more_frames = True + scene_not_exist = False + while has_more_frames: + lidar_path, boxes, _ = nusc.get_sample_data(sd_rec["token"]) + if not Path(lidar_path).exists(): + scene_not_exist = True + break + else: + break + if scene_not_exist: + continue + available_scenes.append(scene) + return available_scenes + + +def get_sample_data(nusc, sample_data_token, selected_anntokens=None): + """ + Returns the data path as well as all annotations related to that sample_data. + Note that the boxes are transformed into the current sensor"s coordinate frame. + Args: + nusc: + sample_data_token: Sample_data token. + selected_anntokens: If provided only return the selected annotation. + + Returns: + + """ + # Retrieve sensor & pose records + sd_record = nusc.get("sample_data", sample_data_token) + cs_record = nusc.get("calibrated_sensor", sd_record["calibrated_sensor_token"]) + sensor_record = nusc.get("sensor", cs_record["sensor_token"]) + pose_record = nusc.get("ego_pose", sd_record["ego_pose_token"]) + + data_path = nusc.get_sample_data_path(sample_data_token) + + if sensor_record["modality"] == "camera": + cam_intrinsic = np.array(cs_record["camera_intrinsic"]) + else: + cam_intrinsic = None + + # Retrieve all sample annotations and map to sensor coordinate system. + if selected_anntokens is not None: + boxes = list(map(nusc.get_box, selected_anntokens)) + else: + boxes = nusc.get_boxes(sample_data_token) + + # Make list of Box objects including coord system transforms. + box_list = [] + for box in boxes: + box.velocity = nusc.box_velocity(box.token) + # Move box to ego vehicle coord system + box.translate(-np.array(pose_record["translation"])) + box.rotate(Quaternion(pose_record["rotation"]).inverse) + + # Move box to sensor coord system + box.translate(-np.array(cs_record["translation"])) + box.rotate(Quaternion(cs_record["rotation"]).inverse) + + box_list.append(box) + + return data_path, box_list, cam_intrinsic + + +def quaternion_yaw(q: Quaternion) -> float: + """ + Calculate the yaw angle from a quaternion. + Note that this only works for a quaternion that represents a box in lidar or global coordinate frame. + It does not work for a box in the camera frame. + :param q: Quaternion of interest. + :return: Yaw angle in radians. + """ + + # Project into xy plane. + v = np.dot(q.rotation_matrix, np.array([1, 0, 0])) + + # Measure yaw using arctan. + yaw = np.arctan2(v[1], v[0]) + + return yaw + + +def obtain_sensor2top(nusc, sensor_token, l2e_t, l2e_r_mat, e2g_t, e2g_r_mat, sensor_type="lidar"): + """Obtain the info with RT matric from general sensor to Top LiDAR. + + Args: + nusc (class): Dataset class in the nuScenes dataset. + sensor_token (str): Sample data token corresponding to the + specific sensor type. + l2e_t (np.ndarray): Translation from lidar to ego in shape (1, 3). 
+ l2e_r_mat (np.ndarray): Rotation matrix from lidar to ego + in shape (3, 3). + e2g_t (np.ndarray): Translation from ego to global in shape (1, 3). + e2g_r_mat (np.ndarray): Rotation matrix from ego to global + in shape (3, 3). + sensor_type (str): Sensor to calibrate. Default: "lidar". + + Returns: + sweep (dict): Sweep information after transformation. + """ + sd_rec = nusc.get("sample_data", sensor_token) + cs_record = nusc.get("calibrated_sensor", sd_rec["calibrated_sensor_token"]) + pose_record = nusc.get("ego_pose", sd_rec["ego_pose_token"]) + data_path = str(nusc.get_sample_data_path(sd_rec["token"])) + # if os.getcwd() in data_path: # path from lyftdataset is absolute path + # data_path = data_path.split(f"{os.getcwd()}/")[-1] # relative path + sweep = { + "data_path": data_path, + "type": sensor_type, + "sample_data_token": sd_rec["token"], + "sensor2ego_translation": cs_record["translation"], + "sensor2ego_rotation": cs_record["rotation"], + "ego2global_translation": pose_record["translation"], + "ego2global_rotation": pose_record["rotation"], + "timestamp": sd_rec["timestamp"], + } + l2e_r_s = sweep["sensor2ego_rotation"] + l2e_t_s = sweep["sensor2ego_translation"] + e2g_r_s = sweep["ego2global_rotation"] + e2g_t_s = sweep["ego2global_translation"] + + # obtain the RT from sensor to Top LiDAR + # sweep->ego->global->ego'->lidar + l2e_r_s_mat = Quaternion(l2e_r_s).rotation_matrix + e2g_r_s_mat = Quaternion(e2g_r_s).rotation_matrix + R = (l2e_r_s_mat.T @ e2g_r_s_mat.T) @ (np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T) + T = (l2e_t_s @ e2g_r_s_mat.T + e2g_t_s) @ (np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T) + T -= ( + e2g_t @ (np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T) + l2e_t @ np.linalg.inv(l2e_r_mat).T + ).squeeze(0) + sweep["sensor2lidar_rotation"] = R.T # points @ R.T + T + sweep["sensor2lidar_translation"] = T + return sweep + + +def fill_trainval_infos(data_path, nusc, train_scenes, test=False, max_sweeps=10, with_camera=False): + train_nusc_infos = [] + val_nusc_infos = [] + progress_bar = tqdm.tqdm(total=len(nusc.sample), desc="create_info", dynamic_ncols=True) + + ref_chan = "LIDAR_TOP" # The radar channel from which we track back n sweeps to aggregate the point cloud. + chan = "LIDAR_TOP" # The reference channel of the current sample_rec that the point clouds are mapped to. 
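+
+    # Each keyframe stores ref_from_car and car_from_global; composed with each
+    # sweep's global_from_car and car_from_current, they map sweep points into
+    # the reference lidar frame (the transform_matrix computed below).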
+ + for index, sample in enumerate(nusc.sample): + progress_bar.update() + + ref_sd_token = sample["data"][ref_chan] + ref_sd_rec = nusc.get("sample_data", ref_sd_token) + ref_cs_rec = nusc.get("calibrated_sensor", ref_sd_rec["calibrated_sensor_token"]) + ref_pose_rec = nusc.get("ego_pose", ref_sd_rec["ego_pose_token"]) + ref_time = 1e-6 * ref_sd_rec["timestamp"] + + ref_lidar_path, ref_boxes, _ = get_sample_data(nusc, ref_sd_token) + + ref_cam_front_token = sample["data"]["CAM_FRONT"] + ref_cam_path, _, ref_cam_intrinsic = nusc.get_sample_data(ref_cam_front_token) + + # Homogeneous transform from ego car frame to reference frame + ref_from_car = transform_matrix(ref_cs_rec["translation"], Quaternion(ref_cs_rec["rotation"]), inverse=True) + + # Homogeneous transformation matrix from global to _current_ ego car frame + car_from_global = transform_matrix( + ref_pose_rec["translation"], + Quaternion(ref_pose_rec["rotation"]), + inverse=True, + ) + info = { + "lidar_path": Path(ref_lidar_path).relative_to(data_path).__str__(), + "lidar_token": ref_sd_token, + "cam_front_path": Path(ref_cam_path).relative_to(data_path).__str__(), + "cam_intrinsic": ref_cam_intrinsic, + "token": sample["token"], + "sweeps": [], + "ref_from_car": ref_from_car, + "car_from_global": car_from_global, + "timestamp": ref_time, + } + if with_camera: + info["cams"] = dict() + l2e_r = ref_cs_rec["rotation"] + l2e_t = (ref_cs_rec["translation"],) + e2g_r = ref_pose_rec["rotation"] + e2g_t = ref_pose_rec["translation"] + l2e_r_mat = Quaternion(l2e_r).rotation_matrix + e2g_r_mat = Quaternion(e2g_r).rotation_matrix + + # obtain 6 image's information per frame + camera_types = [ + "CAM_FRONT", + "CAM_FRONT_RIGHT", + "CAM_FRONT_LEFT", + "CAM_BACK", + "CAM_BACK_LEFT", + "CAM_BACK_RIGHT", + ] + for cam in camera_types: + cam_token = sample["data"][cam] + cam_path, _, camera_intrinsics = nusc.get_sample_data(cam_token) + cam_info = obtain_sensor2top(nusc, cam_token, l2e_t, l2e_r_mat, e2g_t, e2g_r_mat, cam) + cam_info["data_path"] = Path(cam_info["data_path"]).relative_to(data_path).__str__() + cam_info.update(camera_intrinsics=camera_intrinsics) + info["cams"].update({cam: cam_info}) + + sample_data_token = sample["data"][chan] + curr_sd_rec = nusc.get("sample_data", sample_data_token) + sweeps = [] + while len(sweeps) < max_sweeps - 1: + if curr_sd_rec["prev"] == "": + if len(sweeps) == 0: + sweep = { + "lidar_path": Path(ref_lidar_path).relative_to(data_path).__str__(), + "sample_data_token": curr_sd_rec["token"], + "transform_matrix": None, + "time_lag": curr_sd_rec["timestamp"] * 0, + } + sweeps.append(sweep) + else: + sweeps.append(sweeps[-1]) + else: + curr_sd_rec = nusc.get("sample_data", curr_sd_rec["prev"]) + + # Get past pose + current_pose_rec = nusc.get("ego_pose", curr_sd_rec["ego_pose_token"]) + global_from_car = transform_matrix( + current_pose_rec["translation"], + Quaternion(current_pose_rec["rotation"]), + inverse=False, + ) + + # Homogeneous transformation matrix from sensor coordinate frame to ego car frame. 
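+                # (car_from_current, the last factor of the chain composed into tm)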
+ current_cs_rec = nusc.get("calibrated_sensor", curr_sd_rec["calibrated_sensor_token"]) + car_from_current = transform_matrix( + current_cs_rec["translation"], + Quaternion(current_cs_rec["rotation"]), + inverse=False, + ) + + tm = reduce( + np.dot, + [ref_from_car, car_from_global, global_from_car, car_from_current], + ) + + lidar_path = nusc.get_sample_data_path(curr_sd_rec["token"]) + + time_lag = ref_time - 1e-6 * curr_sd_rec["timestamp"] + + sweep = { + "lidar_path": Path(lidar_path).relative_to(data_path).__str__(), + "sample_data_token": curr_sd_rec["token"], + "transform_matrix": tm, + "global_from_car": global_from_car, + "car_from_current": car_from_current, + "time_lag": time_lag, + } + sweeps.append(sweep) + + info["sweeps"] = sweeps + + assert len(info["sweeps"]) == max_sweeps - 1, ( + f"sweep {curr_sd_rec['token']} only has {len(info['sweeps'])} sweeps, " + f"you should duplicate to sweep num {max_sweeps - 1}" + ) + + if not test: + # processing gt bbox + annotations = [nusc.get("sample_annotation", token) for token in sample["anns"]] + + # the filtering gives 0.5~1 map improvement + num_lidar_pts = np.array([anno["num_lidar_pts"] for anno in annotations]) + num_radar_pts = np.array([anno["num_radar_pts"] for anno in annotations]) + mask = num_lidar_pts + num_radar_pts > 0 + + locs = np.array([b.center for b in ref_boxes]).reshape(-1, 3) + dims = np.array([b.wlh for b in ref_boxes]).reshape(-1, 3)[:, [1, 0, 2]] # wlh == > dxdydz (lwh) + velocity = np.array([b.velocity for b in ref_boxes]).reshape(-1, 3) + rots = np.array([quaternion_yaw(b.orientation) for b in ref_boxes]).reshape(-1, 1) + names = np.array([b.name for b in ref_boxes]) + tokens = np.array([b.token for b in ref_boxes]) + gt_boxes = np.concatenate([locs, dims, rots, velocity[:, :2]], axis=1) + + assert len(annotations) == len(gt_boxes) == len(velocity) + + info["gt_boxes"] = gt_boxes[mask, :] + info["gt_boxes_velocity"] = velocity[mask, :] + info["gt_names"] = np.array([map_name_from_general_to_detection[name] for name in names])[mask] + info["gt_boxes_token"] = tokens[mask] + info["num_lidar_pts"] = num_lidar_pts[mask] + info["num_radar_pts"] = num_radar_pts[mask] + + # processing gt segment + segment_path = nusc.get("lidarseg", ref_sd_token)["filename"] + info["gt_segment_path"] = segment_path + + if sample["scene_token"] in train_scenes: + train_nusc_infos.append(info) + else: + val_nusc_infos.append(info) + + progress_bar.close() + return train_nusc_infos, val_nusc_infos + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset_root", required=True, help="Path to the nuScenes dataset.") + parser.add_argument( + "--output_root", + required=True, + help="Output path where processed information located.", + ) + parser.add_argument("--max_sweeps", default=10, type=int, help="Max number of sweeps. 
Default: 10.") + parser.add_argument( + "--with_camera", + action="store_true", + default=False, + help="Whether use camera or not.", + ) + config = parser.parse_args() + + print(f"Loading nuScenes tables for version v1.0-trainval...") + nusc_trainval = NuScenes(version="v1.0-trainval", dataroot=config.dataset_root, verbose=False) + available_scenes_trainval = get_available_scenes(nusc_trainval) + available_scene_names_trainval = [s["name"] for s in available_scenes_trainval] + print("total scene num:", len(nusc_trainval.scene)) + print("exist scene num:", len(available_scenes_trainval)) + assert len(available_scenes_trainval) == len(nusc_trainval.scene) == 850 + + print(f"Loading nuScenes tables for version v1.0-test...") + nusc_test = NuScenes(version="v1.0-test", dataroot=config.dataset_root, verbose=False) + available_scenes_test = get_available_scenes(nusc_test) + available_scene_names_test = [s["name"] for s in available_scenes_test] + print("total scene num:", len(nusc_test.scene)) + print("exist scene num:", len(available_scenes_test)) + assert len(available_scenes_test) == len(nusc_test.scene) == 150 + + train_scenes = splits.train + train_scenes = set( + [available_scenes_trainval[available_scene_names_trainval.index(s)]["token"] for s in train_scenes] + ) + test_scenes = splits.test + test_scenes = set([available_scenes_test[available_scene_names_test.index(s)]["token"] for s in test_scenes]) + print(f"Filling trainval information...") + train_nusc_infos, val_nusc_infos = fill_trainval_infos( + config.dataset_root, + nusc_trainval, + train_scenes, + test=False, + max_sweeps=config.max_sweeps, + with_camera=config.with_camera, + ) + print(f"Filling test information...") + test_nusc_infos, _ = fill_trainval_infos( + config.dataset_root, + nusc_test, + test_scenes, + test=True, + max_sweeps=config.max_sweeps, + with_camera=config.with_camera, + ) + + print(f"Saving nuScenes information...") + os.makedirs(os.path.join(config.output_root, "info"), exist_ok=True) + print( + f"train sample: {len(train_nusc_infos)}, val sample: {len(val_nusc_infos)}, test sample: {len(test_nusc_infos)}" + ) + with open( + os.path.join( + config.output_root, + "info", + f"nuscenes_infos_{config.max_sweeps}sweeps_train.pkl", + ), + "wb", + ) as f: + pickle.dump(train_nusc_infos, f) + with open( + os.path.join( + config.output_root, + "info", + f"nuscenes_infos_{config.max_sweeps}sweeps_val.pkl", + ), + "wb", + ) as f: + pickle.dump(val_nusc_infos, f) + with open( + os.path.join( + config.output_root, + "info", + f"nuscenes_infos_{config.max_sweeps}sweeps_test.pkl", + ), + "wb", + ) as f: + pickle.dump(test_nusc_infos, f) diff --git a/projects/PTv3/datasets/t4dataset.py b/projects/PTv3/datasets/t4dataset.py new file mode 100644 index 00000000..073cb4c8 --- /dev/null +++ b/projects/PTv3/datasets/t4dataset.py @@ -0,0 +1,89 @@ +""" +nuScenes Dataset + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Zheng Zhang +Please cite our work if the code is helpful to you. 
+""" + +import os +import pickle +from collections.abc import Sequence +from pathlib import Path + +import numpy as np + +from .builder import DATASETS +from .defaults import DefaultDataset + + +@DATASETS.register_module() +class T4Dataset(DefaultDataset): + def __init__(self, sweeps=10, ignore_index=-1, **kwargs): + self.sweeps = sweeps + self.ignore_index = ignore_index + self.learning_map = self.get_learning_map(ignore_index) + super().__init__(ignore_index=ignore_index, **kwargs) + + def get_info_path(self, split): + assert split in ["train", "val", "test"] + if split == "train": + return os.path.join(self.data_root, "info/kenzo_all", f"t4dataset_xx1_infos_train.pkl") + elif split == "val": + return os.path.join(self.data_root, "info/kenzo_all", f"t4dataset_xx1_infos_val.pkl") + elif split == "test": + return os.path.join(self.data_root, "info", f"t4dataset_xx1_infos_test.pkl") + else: + raise NotImplementedError + + def get_data_list(self): + if isinstance(self.split, str): + info_paths = [self.get_info_path(self.split)] + elif isinstance(self.split, Sequence): + info_paths = [self.get_info_path(s) for s in self.split] + else: + raise NotImplementedError + data_list = [] + for info_path in info_paths: + with open(info_path, "rb") as f: + info = pickle.load(f) + data_list.extend(info["data_list"]) + return data_list + + def get_data(self, idx): + data = self.data_list[idx % len(self.data_list)] + lidar_path = os.path.join(self.data_root, data["lidar_points"]["lidar_path"]) + points = np.fromfile(str(lidar_path), dtype=np.float32, count=-1).reshape([-1, 5]) + coord = points[:, :3] + strength = points[:, 3].reshape([-1, 1]) / 255 # scale strength to [0, 1] + + lidar_path = Path(lidar_path) + basename = lidar_path.name.split(".")[0] + seg_path = lidar_path.parent / f"{basename}_seg.npy" + + segment = np.load(str(seg_path)).reshape([-1]) + segment = np.vectorize(self.learning_map.__getitem__)(segment).astype(np.int64) + + data_dict = dict( + coord=coord, + strength=strength, + segment=segment, + name=self.get_data_name(idx), + ) + return data_dict + + def get_data_name(self, idx): + # return data name for lidar seg, optimize the code when need to support detection + return self.data_list[idx % len(self.data_list)]["token"] + + @staticmethod + def get_learning_map(ignore_index): + learning_map = { + 0: 0, + 1: 1, + 2: 2, + 3: 3, + 4: 4, + 5: 5, + 255: ignore_index, + } + return learning_map diff --git a/projects/PTv3/datasets/transform.py b/projects/PTv3/datasets/transform.py new file mode 100644 index 00000000..fb911adc --- /dev/null +++ b/projects/PTv3/datasets/transform.py @@ -0,0 +1,477 @@ +""" +3D Point Cloud Augmentation + +Inspirited by chrischoy/SpatioTemporalSegmentation + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import copy +import numbers +import random +from collections.abc import Mapping, Sequence + +import numpy as np +import torch +from utils.registry import Registry + +TRANSFORMS = Registry("transforms") + + +@TRANSFORMS.register_module() +class Collect(object): + def __init__(self, keys, offset_keys_dict=None, **kwargs): + """ + e.g. 
Collect(keys=[coord], feat_keys=[coord, color]) + """ + if offset_keys_dict is None: + offset_keys_dict = dict(offset="coord") + self.keys = keys + self.offset_keys = offset_keys_dict + self.kwargs = kwargs + + def __call__(self, data_dict): + data = dict() + if isinstance(self.keys, str): + self.keys = [self.keys] + for key in self.keys: + data[key] = data_dict[key] + for key, value in self.offset_keys.items(): + data[key] = torch.tensor([data_dict[value].shape[0]]) + for name, keys in self.kwargs.items(): + name = name.replace("_keys", "") + assert isinstance(keys, Sequence) + data[name] = torch.cat([data_dict[key].float() for key in keys], dim=1) + return data + + +@TRANSFORMS.register_module() +class Copy(object): + def __init__(self, keys_dict=None): + if keys_dict is None: + keys_dict = dict(coord="origin_coord", segment="origin_segment") + self.keys_dict = keys_dict + + def __call__(self, data_dict): + for key, value in self.keys_dict.items(): + if isinstance(data_dict[key], np.ndarray): + data_dict[value] = data_dict[key].copy() + elif isinstance(data_dict[key], torch.Tensor): + data_dict[value] = data_dict[key].clone().detach() + else: + data_dict[value] = copy.deepcopy(data_dict[key]) + return data_dict + + +@TRANSFORMS.register_module() +class ToTensor(object): + def __call__(self, data): + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, str): + # note that str is also a kind of sequence, judgement should before sequence + return data + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, bool): + return torch.from_numpy(data) + elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.integer): + return torch.from_numpy(data).long() + elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.floating): + return torch.from_numpy(data).float() + elif isinstance(data, Mapping): + result = {sub_key: self(item) for sub_key, item in data.items()} + return result + elif isinstance(data, Sequence): + result = [self(item) for item in data] + return result + else: + raise TypeError(f"type {type(data)} cannot be converted to tensor.") + + +@TRANSFORMS.register_module() +class PointClip(object): + def __init__(self, point_cloud_range=(-80, -80, -3, 80, 80, 1)): + self.point_cloud_range = point_cloud_range + + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + data_dict["coord"] = np.clip( + data_dict["coord"], + a_min=self.point_cloud_range[:3], + a_max=self.point_cloud_range[3:], + ) + return data_dict + + +@TRANSFORMS.register_module() +class RandomDropout(object): + def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5): + """ + upright_axis: axis index among x,y,z, i.e. 2 for z + """ + self.dropout_ratio = dropout_ratio + self.dropout_application_ratio = dropout_application_ratio + + def __call__(self, data_dict): + if random.random() < self.dropout_application_ratio: + n = len(data_dict["coord"]) + idx = np.random.choice(n, int(n * (1 - self.dropout_ratio)), replace=False) + if "sampled_index" in data_dict: + # for ScanNet data efficient, we need to make sure labeled point is sampled. 
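+                # (sampled_index holds indices of the labeled subset; it is remapped
+                # to positions within the kept points a few lines below)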
+ idx = np.unique(np.append(idx, data_dict["sampled_index"])) + mask = np.zeros_like(data_dict["segment"]).astype(bool) + mask[data_dict["sampled_index"]] = True + data_dict["sampled_index"] = np.where(mask[idx])[0] + if "coord" in data_dict.keys(): + data_dict["coord"] = data_dict["coord"][idx] + if "color" in data_dict.keys(): + data_dict["color"] = data_dict["color"][idx] + if "normal" in data_dict.keys(): + data_dict["normal"] = data_dict["normal"][idx] + if "strength" in data_dict.keys(): + data_dict["strength"] = data_dict["strength"][idx] + if "segment" in data_dict.keys(): + data_dict["segment"] = data_dict["segment"][idx] + if "instance" in data_dict.keys(): + data_dict["instance"] = data_dict["instance"][idx] + return data_dict + + +@TRANSFORMS.register_module() +class RandomRotate(object): + def __init__(self, angle=None, center=None, axis="z", always_apply=False, p=0.5): + self.angle = [-1, 1] if angle is None else angle + self.axis = axis + self.always_apply = always_apply + self.p = p if not self.always_apply else 1 + self.center = center + + def __call__(self, data_dict): + if random.random() > self.p: + return data_dict + angle = np.random.uniform(self.angle[0], self.angle[1]) * np.pi + rot_cos, rot_sin = np.cos(angle), np.sin(angle) + if self.axis == "x": + rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]]) + elif self.axis == "y": + rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]]) + elif self.axis == "z": + rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]]) + else: + raise NotImplementedError + if "coord" in data_dict.keys(): + if self.center is None: + x_min, y_min, z_min = data_dict["coord"].min(axis=0) + x_max, y_max, z_max = data_dict["coord"].max(axis=0) + center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2] + else: + center = self.center + data_dict["coord"] -= center + data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t)) + data_dict["coord"] += center + if "normal" in data_dict.keys(): + data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t)) + return data_dict + + +@TRANSFORMS.register_module() +class RandomScale(object): + def __init__(self, scale=None, anisotropic=False): + self.scale = scale if scale is not None else [0.95, 1.05] + self.anisotropic = anisotropic + + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + scale = np.random.uniform(self.scale[0], self.scale[1], 3 if self.anisotropic else 1) + data_dict["coord"] *= scale + return data_dict + + +@TRANSFORMS.register_module() +class RandomFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, data_dict): + if np.random.rand() < self.p: + if "coord" in data_dict.keys(): + data_dict["coord"][:, 0] = -data_dict["coord"][:, 0] + if "normal" in data_dict.keys(): + data_dict["normal"][:, 0] = -data_dict["normal"][:, 0] + if np.random.rand() < self.p: + if "coord" in data_dict.keys(): + data_dict["coord"][:, 1] = -data_dict["coord"][:, 1] + if "normal" in data_dict.keys(): + data_dict["normal"][:, 1] = -data_dict["normal"][:, 1] + return data_dict + + +@TRANSFORMS.register_module() +class RandomJitter(object): + def __init__(self, sigma=0.01, clip=0.05): + assert clip > 0 + self.sigma = sigma + self.clip = clip + + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + jitter = np.clip( + self.sigma * np.random.randn(data_dict["coord"].shape[0], 3), + -self.clip, + self.clip, + ) + data_dict["coord"] += jitter + return data_dict + 
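+
+# A note on GridSample (defined below): it voxelizes by hashing integer grid
+# coordinates (FNV-1a by default). In "train" mode it keeps one randomly chosen
+# point per occupied voxel; in "test" mode it emits count.max() overlapping
+# parts so that every point is covered by at least one part.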
+ +@TRANSFORMS.register_module() +class GridSample(object): + def __init__( + self, + grid_size=0.05, + hash_type="fnv", + mode="train", + keys=("coord", "color", "normal", "segment"), + return_inverse=False, + return_grid_coord=False, + return_min_coord=False, + return_displacement=False, + project_displacement=False, + ): + self.grid_size = grid_size + self.hash = self.fnv_hash_vec if hash_type == "fnv" else self.ravel_hash_vec + assert mode in ["train", "test"] + self.mode = mode + self.keys = keys + self.return_inverse = return_inverse + self.return_grid_coord = return_grid_coord + self.return_min_coord = return_min_coord + self.return_displacement = return_displacement + self.project_displacement = project_displacement + + def __call__(self, data_dict): + assert "coord" in data_dict.keys() + scaled_coord = data_dict["coord"].astype(np.float32) / np.array(self.grid_size).astype(np.float32) + grid_coord = np.floor(scaled_coord).astype(int) + min_coord = grid_coord.min(0) + grid_coord -= min_coord + scaled_coord -= min_coord + min_coord = min_coord * np.array(self.grid_size) + key = self.hash(grid_coord) + idx_sort = np.argsort(key) + key_sort = key[idx_sort] + _, inverse, count = np.unique(key_sort, return_inverse=True, return_counts=True) + if self.mode == "train": # train mode + idx_select = ( + np.cumsum(np.insert(count, 0, 0)[0:-1]) + np.random.randint(0, count.max(), count.size) % count + ) + idx_unique = idx_sort[idx_select] + if "sampled_index" in data_dict: + # for ScanNet data efficient, we need to make sure labeled point is sampled. + idx_unique = np.unique(np.append(idx_unique, data_dict["sampled_index"])) + mask = np.zeros_like(data_dict["segment"]).astype(bool) + mask[data_dict["sampled_index"]] = True + data_dict["sampled_index"] = np.where(mask[idx_unique])[0] + if self.return_inverse: + data_dict["inverse"] = np.zeros_like(inverse) + data_dict["inverse"][idx_sort] = inverse + if self.return_grid_coord: + data_dict["grid_coord"] = grid_coord[idx_unique] + if self.return_min_coord: + data_dict["min_coord"] = min_coord.reshape([1, 3]) + if self.return_displacement: + displacement = scaled_coord - grid_coord - 0.5 # [0, 1] -> [-0.5, 0.5] displacement to center + if self.project_displacement: + displacement = np.sum(displacement * data_dict["normal"], axis=-1, keepdims=True) + data_dict["displacement"] = displacement[idx_unique] + for key in self.keys: + data_dict[key] = data_dict[key][idx_unique] + return data_dict + + elif self.mode == "test": # test mode + data_part_list = [] + for i in range(count.max()): + idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + i % count + idx_part = idx_sort[idx_select] + data_part = dict(index=idx_part) + if self.return_inverse: + data_dict["inverse"] = np.zeros_like(inverse) + data_dict["inverse"][idx_sort] = inverse + if self.return_grid_coord: + data_part["grid_coord"] = grid_coord[idx_part] + if self.return_min_coord: + data_part["min_coord"] = min_coord.reshape([1, 3]) + if self.return_displacement: + displacement = scaled_coord - grid_coord - 0.5 # [0, 1] -> [-0.5, 0.5] displacement to center + if self.project_displacement: + displacement = np.sum(displacement * data_dict["normal"], axis=-1, keepdims=True) + data_dict["displacement"] = displacement[idx_part] + for key in data_dict.keys(): + if key in self.keys: + data_part[key] = data_dict[key][idx_part] + else: + data_part[key] = data_dict[key] + data_part_list.append(data_part) + return data_part_list + else: + raise NotImplementedError + + @staticmethod + def 
ravel_hash_vec(arr): + """ + Ravel the coordinates after subtracting the min coordinates. + """ + assert arr.ndim == 2 + arr = arr.copy() + arr -= arr.min(0) + arr = arr.astype(np.uint64, copy=False) + arr_max = arr.max(0).astype(np.uint64) + 1 + + keys = np.zeros(arr.shape[0], dtype=np.uint64) + # Fortran style indexing + for j in range(arr.shape[1] - 1): + keys += arr[:, j] + keys *= arr_max[j + 1] + keys += arr[:, -1] + return keys + + @staticmethod + def fnv_hash_vec(arr): + """ + FNV64-1A + """ + assert arr.ndim == 2 + # Floor first for negative coordinates + arr = arr.copy() + arr = arr.astype(np.uint64, copy=False) + hashed_arr = np.uint64(14695981039346656037) * np.ones(arr.shape[0], dtype=np.uint64) + for j in range(arr.shape[1]): + hashed_arr *= np.uint64(1099511628211) + hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j]) + return hashed_arr + + +@TRANSFORMS.register_module() +class SphereCrop(object): + def __init__(self, point_max=80000, sample_rate=None, mode="random"): + self.point_max = point_max + self.sample_rate = sample_rate + assert mode in ["random", "center", "all"] + self.mode = mode + + def __call__(self, data_dict): + point_max = ( + int(self.sample_rate * data_dict["coord"].shape[0]) if self.sample_rate is not None else self.point_max + ) + + assert "coord" in data_dict.keys() + if self.mode == "all": + # TODO: Optimize + if "index" not in data_dict.keys(): + data_dict["index"] = np.arange(data_dict["coord"].shape[0]) + data_part_list = [] + # coord_list, color_list, dist2_list, idx_list, offset_list = [], [], [], [], [] + if data_dict["coord"].shape[0] > point_max: + coord_p, idx_uni = np.random.rand(data_dict["coord"].shape[0]) * 1e-3, np.array([]) + while idx_uni.size != data_dict["index"].shape[0]: + init_idx = np.argmin(coord_p) + dist2 = np.sum( + np.power(data_dict["coord"] - data_dict["coord"][init_idx], 2), + 1, + ) + idx_crop = np.argsort(dist2)[:point_max] + + data_crop_dict = dict() + if "coord" in data_dict.keys(): + data_crop_dict["coord"] = data_dict["coord"][idx_crop] + if "grid_coord" in data_dict.keys(): + data_crop_dict["grid_coord"] = data_dict["grid_coord"][idx_crop] + if "normal" in data_dict.keys(): + data_crop_dict["normal"] = data_dict["normal"][idx_crop] + if "color" in data_dict.keys(): + data_crop_dict["color"] = data_dict["color"][idx_crop] + if "displacement" in data_dict.keys(): + data_crop_dict["displacement"] = data_dict["displacement"][idx_crop] + if "strength" in data_dict.keys(): + data_crop_dict["strength"] = data_dict["strength"][idx_crop] + data_crop_dict["weight"] = dist2[idx_crop] + data_crop_dict["index"] = data_dict["index"][idx_crop] + data_part_list.append(data_crop_dict) + + delta = np.square(1 - data_crop_dict["weight"] / np.max(data_crop_dict["weight"])) + coord_p[idx_crop] += delta + idx_uni = np.unique(np.concatenate((idx_uni, data_crop_dict["index"]))) + else: + data_crop_dict = data_dict.copy() + data_crop_dict["weight"] = np.zeros(data_dict["coord"].shape[0]) + data_crop_dict["index"] = data_dict["index"] + data_part_list.append(data_crop_dict) + return data_part_list + # mode is "random" or "center" + elif data_dict["coord"].shape[0] > point_max: + if self.mode == "random": + center = data_dict["coord"][np.random.randint(data_dict["coord"].shape[0])] + elif self.mode == "center": + center = data_dict["coord"][data_dict["coord"].shape[0] // 2] + else: + raise NotImplementedError + idx_crop = np.argsort(np.sum(np.square(data_dict["coord"] - center), 1))[:point_max] + if "coord" in data_dict.keys(): + 
data_dict["coord"] = data_dict["coord"][idx_crop] + if "origin_coord" in data_dict.keys(): + data_dict["origin_coord"] = data_dict["origin_coord"][idx_crop] + if "grid_coord" in data_dict.keys(): + data_dict["grid_coord"] = data_dict["grid_coord"][idx_crop] + if "color" in data_dict.keys(): + data_dict["color"] = data_dict["color"][idx_crop] + if "normal" in data_dict.keys(): + data_dict["normal"] = data_dict["normal"][idx_crop] + if "segment" in data_dict.keys(): + data_dict["segment"] = data_dict["segment"][idx_crop] + if "instance" in data_dict.keys(): + data_dict["instance"] = data_dict["instance"][idx_crop] + if "displacement" in data_dict.keys(): + data_dict["displacement"] = data_dict["displacement"][idx_crop] + if "strength" in data_dict.keys(): + data_dict["strength"] = data_dict["strength"][idx_crop] + return data_dict + + +@TRANSFORMS.register_module() +class ShufflePoint(object): + def __call__(self, data_dict): + assert "coord" in data_dict.keys() + shuffle_index = np.arange(data_dict["coord"].shape[0]) + np.random.shuffle(shuffle_index) + if "coord" in data_dict.keys(): + data_dict["coord"] = data_dict["coord"][shuffle_index] + if "grid_coord" in data_dict.keys(): + data_dict["grid_coord"] = data_dict["grid_coord"][shuffle_index] + if "displacement" in data_dict.keys(): + data_dict["displacement"] = data_dict["displacement"][shuffle_index] + if "color" in data_dict.keys(): + data_dict["color"] = data_dict["color"][shuffle_index] + if "normal" in data_dict.keys(): + data_dict["normal"] = data_dict["normal"][shuffle_index] + if "segment" in data_dict.keys(): + data_dict["segment"] = data_dict["segment"][shuffle_index] + if "instance" in data_dict.keys(): + data_dict["instance"] = data_dict["instance"][shuffle_index] + return data_dict + + +class Compose(object): + def __init__(self, cfg=None): + self.cfg = cfg if cfg is not None else [] + self.transforms = [] + for t_cfg in self.cfg: + self.transforms.append(TRANSFORMS.build(t_cfg)) + + def __call__(self, data_dict): + for t in self.transforms: + data_dict = t(data_dict) + return data_dict diff --git a/projects/PTv3/datasets/utils.py b/projects/PTv3/datasets/utils.py new file mode 100644 index 00000000..ea8a27a6 --- /dev/null +++ b/projects/PTv3/datasets/utils.py @@ -0,0 +1,56 @@ +""" +Utils for Datasets + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. 
+"""
+
+import random
+from collections.abc import Mapping, Sequence
+
+import numpy as np
+import torch
+from torch.utils.data.dataloader import default_collate
+
+
+def collate_fn(batch):
+    """
+    Collate function for point clouds; supports dicts and lists.
+    'coord' is required to determine 'offset'.
+    """
+    if not isinstance(batch, Sequence):
+        raise TypeError(f"{type(batch)} is not supported.")
+
+    if isinstance(batch[0], torch.Tensor):
+        return torch.cat(list(batch))
+    elif isinstance(batch[0], str):
+        # str is also a Sequence, so this check must come before the Sequence branch
+        return list(batch)
+    elif isinstance(batch[0], Sequence):
+        for data in batch:
+            data.append(torch.tensor([data[0].shape[0]]))
+        batch = [collate_fn(samples) for samples in zip(*batch)]
+        batch[-1] = torch.cumsum(batch[-1], dim=0).int()
+        return batch
+    elif isinstance(batch[0], Mapping):
+        batch = {key: collate_fn([d[key] for d in batch]) for key in batch[0]}
+        for key in batch.keys():
+            if "offset" in key:
+                batch[key] = torch.cumsum(batch[key], dim=0)
+        return batch
+    else:
+        return default_collate(batch)
+
+
+def point_collate_fn(batch, mix_prob=0):
+    assert isinstance(batch[0], Mapping)  # currently, only input_dict is supported, not input_list
+    batch = collate_fn(batch)
+    if "offset" in batch.keys():
+        # Mix3d (https://arxiv.org/pdf/2110.02210.pdf)
+        if random.random() < mix_prob:
+            batch["offset"] = torch.cat([batch["offset"][1:-1:2], batch["offset"][-1].unsqueeze(0)], dim=0)
+    return batch
+
+
+def gaussian_kernel(dist2: np.ndarray, a: float = 1, c: float = 5):
+    return a * np.exp(-dist2 / (2 * c**2))
diff --git a/projects/PTv3/engines/__init__.py b/projects/PTv3/engines/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/projects/PTv3/engines/defaults.py b/projects/PTv3/engines/defaults.py
new file mode 100644
index 00000000..b1af27ec
--- /dev/null
+++ b/projects/PTv3/engines/defaults.py
@@ -0,0 +1,139 @@
+"""
+Default training/testing logic
+
+modified from detectron2 (https://github.com/facebookresearch/detectron2)
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
+Please cite our work if the code is helpful to you.
+"""
+
+import argparse
+import multiprocessing as mp
+import os
+import sys
+
+import utils.comm as comm
+from torch.nn.parallel import DistributedDataParallel
+from utils.config import Config, DictAction
+from utils.env import get_random_seed, set_seed
+
+
+def create_ddp_model(model, *, fp16_compression=False, **kwargs):
+    """
+    Create a DistributedDataParallel model if there are >1 processes.
+    Args:
+        model: a torch.nn.Module
+        fp16_compression: add fp16 compression hooks to the ddp object.
+            See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
+        kwargs: other arguments of :class:`torch.nn.parallel.DistributedDataParallel`.
+    """
+    if comm.get_world_size() == 1:
+        return model
+    # kwargs['find_unused_parameters'] = True
+    if "device_ids" not in kwargs:
+        kwargs["device_ids"] = [comm.get_local_rank()]
+    if "output_device" not in kwargs:
+        kwargs["output_device"] = [comm.get_local_rank()]
+    ddp = DistributedDataParallel(model, **kwargs)
+    if fp16_compression:
+        from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
+
+        ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
+    return ddp
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+    """Worker init func for dataloader.
+ + The seed of each worker equals to num_worker * rank + worker_id + user_seed + + Args: + worker_id (int): Worker id. + num_workers (int): Number of workers. + rank (int): The rank of current process. + seed (int): The random seed to use. + """ + + worker_seed = num_workers * rank + worker_id + seed + set_seed(worker_seed) + + +def default_argument_parser(epilog=None): + parser = argparse.ArgumentParser( + epilog=epilog + or f""" + Examples: + Run on single machine: + $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml + Change some config options: + $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001 + Run on multiple machines: + (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url [--other-flags] + (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url [--other-flags] + """, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") + parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*") + parser.add_argument("--num-machines", type=int, default=1, help="total number of machines") + parser.add_argument( + "--machine-rank", + type=int, + default=0, + help="the rank of this machine (unique per machine)", + ) + # PyTorch still may leave orphan processes in multi-gpu training. + # Therefore we use a deterministic way to obtain port, + # so that users are aware of orphan processes by seeing the port occupied. + # port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14 + parser.add_argument( + "--dist-url", + # default="tcp://127.0.0.1:{}".format(port), + default="auto", + help="initialization URL for pytorch distributed backend. 
See " + "https://pytorch.org/docs/stable/distributed.html for details.", + ) + parser.add_argument("--options", nargs="+", action=DictAction, help="custom options") + return parser + + +def default_config_parser(file_path, options): + # config name protocol: dataset_name/model_name-exp_name + if os.path.isfile(file_path): + cfg = Config.fromfile(file_path) + else: + sep = file_path.find("-") + cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :])) + + if options is not None: + cfg.merge_from_dict(options) + + if cfg.seed is None: + cfg.seed = get_random_seed() + + cfg.data.train.loop = cfg.epoch // cfg.eval_epoch + + os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True) + if not cfg.resume: + cfg.dump(os.path.join(cfg.save_path, "config.py")) + return cfg + + +def default_setup(cfg): + # scalar by world size + world_size = comm.get_world_size() + cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count() + cfg.num_worker_per_gpu = cfg.num_worker // world_size + assert cfg.batch_size % world_size == 0 + assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0 + assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0 + cfg.batch_size_per_gpu = cfg.batch_size // world_size + cfg.batch_size_val_per_gpu = cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1 + cfg.batch_size_test_per_gpu = cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1 + # update data loop + assert cfg.epoch % cfg.eval_epoch == 0 + # settle random seed + rank = comm.get_rank() + seed = None if cfg.seed is None else cfg.seed * cfg.num_worker_per_gpu + rank + set_seed(seed) + return cfg diff --git a/projects/PTv3/engines/hooks/__init__.py b/projects/PTv3/engines/hooks/__init__.py new file mode 100644 index 00000000..21d23cb7 --- /dev/null +++ b/projects/PTv3/engines/hooks/__init__.py @@ -0,0 +1,4 @@ +from .builder import build_hooks +from .default import HookBase +from .evaluator import * +from .misc import * diff --git a/projects/PTv3/engines/hooks/builder.py b/projects/PTv3/engines/hooks/builder.py new file mode 100644 index 00000000..1bd8e912 --- /dev/null +++ b/projects/PTv3/engines/hooks/builder.py @@ -0,0 +1,17 @@ +""" +Hook Builder + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +from utils.registry import Registry + +HOOKS = Registry("hooks") + + +def build_hooks(cfg): + hooks = [] + for hook_cfg in cfg: + hooks.append(HOOKS.build(hook_cfg)) + return hooks diff --git a/projects/PTv3/engines/hooks/default.py b/projects/PTv3/engines/hooks/default.py new file mode 100644 index 00000000..87a64415 --- /dev/null +++ b/projects/PTv3/engines/hooks/default.py @@ -0,0 +1,32 @@ +""" +Default Hook + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + + +class HookBase: + """ + Base class for hooks that can be registered with :class:`TrainerBase`. + """ + + trainer = None # A weak reference to the trainer object. 
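+
+    # The trainer drives registered hooks in this order (see TrainerBase.train):
+    #   before_train -> (before_epoch -> (before_step -> run_step -> after_step)* -> after_epoch)* -> after_train
+    # A custom hook overrides only the stages it needs, e.g. (illustrative sketch,
+    # not a hook shipped in this PR):
+    #
+    #   @HOOKS.register_module()
+    #   class EmptyCudaCache(HookBase):
+    #       def after_epoch(self):
+    #           torch.cuda.empty_cache()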
+ + def before_train(self): + pass + + def before_epoch(self): + pass + + def before_step(self): + pass + + def after_step(self): + pass + + def after_epoch(self): + pass + + def after_train(self): + pass diff --git a/projects/PTv3/engines/hooks/evaluator.py b/projects/PTv3/engines/hooks/evaluator.py new file mode 100644 index 00000000..f72546f6 --- /dev/null +++ b/projects/PTv3/engines/hooks/evaluator.py @@ -0,0 +1,91 @@ +""" +Evaluate Hook + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import numpy as np +import torch +import torch.distributed as dist +import utils.comm as comm +from utils.misc import intersection_and_union_gpu + +from .builder import HOOKS +from .default import HookBase + + +@HOOKS.register_module() +class SemSegEvaluator(HookBase): + def after_epoch(self): + if self.trainer.cfg.evaluate: + self.eval() + + def eval(self): + self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>") + self.trainer.model.eval() + for i, input_dict in enumerate(self.trainer.val_loader): + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + with torch.no_grad(): + output_dict = self.trainer.model(input_dict) + output = output_dict["seg_logits"] + loss = output_dict["loss"] + pred = output.max(1)[1] + segment = input_dict["segment"] + intersection, union, target = intersection_and_union_gpu( + pred, + segment, + self.trainer.cfg.data.num_classes, + self.trainer.cfg.data.ignore_index, + ) + if comm.get_world_size() > 1: + dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target) + intersection, union, target = ( + intersection.cpu().numpy(), + union.cpu().numpy(), + target.cpu().numpy(), + ) + # Here there is no need to sync since sync happened in dist.all_reduce + self.trainer.storage.put_scalar("val_intersection", intersection) + self.trainer.storage.put_scalar("val_union", union) + self.trainer.storage.put_scalar("val_target", target) + self.trainer.storage.put_scalar("val_loss", loss.item()) + info = "Test: [{iter}/{max_iter}] ".format(iter=i + 1, max_iter=len(self.trainer.val_loader)) + if "origin_coord" in input_dict.keys(): + info = "Interp. 
" + info + self.trainer.logger.info( + info + "Loss {loss:.4f} ".format(iter=i + 1, max_iter=len(self.trainer.val_loader), loss=loss.item()) + ) + loss_avg = self.trainer.storage.history("val_loss").avg + intersection = self.trainer.storage.history("val_intersection").total + union = self.trainer.storage.history("val_union").total + target = self.trainer.storage.history("val_target").total + iou_class = intersection / (union + 1e-10) + acc_class = intersection / (target + 1e-10) + m_iou = np.mean(iou_class) + m_acc = np.mean(acc_class) + all_acc = sum(intersection) / (sum(target) + 1e-10) + self.trainer.logger.info("Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.".format(m_iou, m_acc, all_acc)) + for i in range(self.trainer.cfg.data.num_classes): + self.trainer.logger.info( + "Class_{idx}-{name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format( + idx=i, + name=self.trainer.cfg.data.names[i], + iou=iou_class[i], + accuracy=acc_class[i], + ) + ) + current_epoch = self.trainer.epoch + 1 + if self.trainer.writer is not None: + self.trainer.writer.add_scalar("val/loss", loss_avg, current_epoch) + self.trainer.writer.add_scalar("val/mIoU", m_iou, current_epoch) + self.trainer.writer.add_scalar("val/mAcc", m_acc, current_epoch) + self.trainer.writer.add_scalar("val/allAcc", all_acc, current_epoch) + self.trainer.logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<") + self.trainer.comm_info["current_metric_value"] = m_iou # save for saver + self.trainer.comm_info["current_metric_name"] = "mIoU" # save for saver + + def after_train(self): + self.trainer.logger.info("Best {}: {:.4f}".format("mIoU", self.trainer.best_metric_value)) diff --git a/projects/PTv3/engines/hooks/misc.py b/projects/PTv3/engines/hooks/misc.py new file mode 100644 index 00000000..4c0f88a1 --- /dev/null +++ b/projects/PTv3/engines/hooks/misc.py @@ -0,0 +1,248 @@ +""" +Misc Hook + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. 
+""" + +import os +import shutil +import sys +import time +from collections import OrderedDict + +import torch +import utils.comm as comm +from engines.test import TESTERS +from utils.comm import is_main_process +from utils.timer import Timer + +from .builder import HOOKS +from .default import HookBase + + +@HOOKS.register_module() +class IterationTimer(HookBase): + def __init__(self, warmup_iter=1): + self._warmup_iter = warmup_iter + self._start_time = time.perf_counter() + self._iter_timer = Timer() + self._remain_iter = 0 + + def before_train(self): + self._start_time = time.perf_counter() + self._remain_iter = self.trainer.max_epoch * len(self.trainer.train_loader) + + def before_epoch(self): + self._iter_timer.reset() + + def before_step(self): + data_time = self._iter_timer.seconds() + self.trainer.storage.put_scalar("data_time", data_time) + + def after_step(self): + batch_time = self._iter_timer.seconds() + self._iter_timer.reset() + self.trainer.storage.put_scalar("batch_time", batch_time) + self._remain_iter -= 1 + remain_time = self._remain_iter * self.trainer.storage.history("batch_time").avg + t_m, t_s = divmod(remain_time, 60) + t_h, t_m = divmod(t_m, 60) + remain_time = "{:02d}:{:02d}:{:02d}".format(int(t_h), int(t_m), int(t_s)) + if "iter_info" in self.trainer.comm_info.keys(): + info = ( + "Data {data_time_val:.3f} ({data_time_avg:.3f}) " + "Batch {batch_time_val:.3f} ({batch_time_avg:.3f}) " + "Remain {remain_time} ".format( + data_time_val=self.trainer.storage.history("data_time").val, + data_time_avg=self.trainer.storage.history("data_time").avg, + batch_time_val=self.trainer.storage.history("batch_time").val, + batch_time_avg=self.trainer.storage.history("batch_time").avg, + remain_time=remain_time, + ) + ) + self.trainer.comm_info["iter_info"] += info + if self.trainer.comm_info["iter"] <= self._warmup_iter: + self.trainer.storage.history("data_time").reset() + self.trainer.storage.history("batch_time").reset() + + +@HOOKS.register_module() +class InformationWriter(HookBase): + def __init__(self): + self.curr_iter = 0 + self.model_output_keys = [] + + def before_train(self): + self.trainer.comm_info["iter_info"] = "" + self.curr_iter = self.trainer.start_epoch * len(self.trainer.train_loader) + + def before_step(self): + self.curr_iter += 1 + # MSC pretrain do not have offset information. 
Comment the code for support MSC + # info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] " \ + # "Scan {batch_size} ({points_num}) ".format( + # epoch=self.trainer.epoch + 1, max_epoch=self.trainer.max_epoch, + # iter=self.trainer.comm_info["iter"], max_iter=len(self.trainer.train_loader), + # batch_size=len(self.trainer.comm_info["input_dict"]["offset"]), + # points_num=self.trainer.comm_info["input_dict"]["offset"][-1] + # ) + info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] ".format( + epoch=self.trainer.epoch + 1, + max_epoch=self.trainer.max_epoch, + iter=self.trainer.comm_info["iter"] + 1, + max_iter=len(self.trainer.train_loader), + ) + self.trainer.comm_info["iter_info"] += info + + def after_step(self): + if "model_output_dict" in self.trainer.comm_info.keys(): + model_output_dict = self.trainer.comm_info["model_output_dict"] + self.model_output_keys = model_output_dict.keys() + for key in self.model_output_keys: + self.trainer.storage.put_scalar(key, model_output_dict[key].item()) + + for key in self.model_output_keys: + self.trainer.comm_info["iter_info"] += "{key}: {value:.4f} ".format( + key=key, value=self.trainer.storage.history(key).val + ) + lr = self.trainer.optimizer.state_dict()["param_groups"][0]["lr"] + self.trainer.comm_info["iter_info"] += "Lr: {lr:.5f}".format(lr=lr) + self.trainer.logger.info(self.trainer.comm_info["iter_info"]) + self.trainer.comm_info["iter_info"] = "" # reset iter info + if self.trainer.writer is not None: + self.trainer.writer.add_scalar("lr", lr, self.curr_iter) + for key in self.model_output_keys: + self.trainer.writer.add_scalar( + "train_batch/" + key, + self.trainer.storage.history(key).val, + self.curr_iter, + ) + + def after_epoch(self): + epoch_info = "Train result: " + for key in self.model_output_keys: + epoch_info += "{key}: {value:.4f} ".format(key=key, value=self.trainer.storage.history(key).avg) + self.trainer.logger.info(epoch_info) + if self.trainer.writer is not None: + for key in self.model_output_keys: + self.trainer.writer.add_scalar( + "train/" + key, + self.trainer.storage.history(key).avg, + self.trainer.epoch + 1, + ) + + +@HOOKS.register_module() +class CheckpointSaver(HookBase): + def __init__(self, save_freq=None): + self.save_freq = save_freq # None or int, None indicate only save model last + + def after_epoch(self): + if is_main_process(): + is_best = False + if self.trainer.cfg.evaluate: + current_metric_value = self.trainer.comm_info["current_metric_value"] + current_metric_name = self.trainer.comm_info["current_metric_name"] + if current_metric_value > self.trainer.best_metric_value: + self.trainer.best_metric_value = current_metric_value + is_best = True + self.trainer.logger.info( + "Best validation {} updated to: {:.4f}".format(current_metric_name, current_metric_value) + ) + self.trainer.logger.info( + "Currently Best {}: {:.4f}".format(current_metric_name, self.trainer.best_metric_value) + ) + + filename = os.path.join(self.trainer.cfg.save_path, "model", "model_last.pth") + self.trainer.logger.info("Saving checkpoint to: " + filename) + torch.save( + { + "epoch": self.trainer.epoch + 1, + "state_dict": self.trainer.model.state_dict(), + "optimizer": self.trainer.optimizer.state_dict(), + "scheduler": self.trainer.scheduler.state_dict(), + "scaler": (self.trainer.scaler.state_dict() if self.trainer.cfg.enable_amp else None), + "best_metric_value": self.trainer.best_metric_value, + }, + filename + ".tmp", + ) + os.replace(filename + ".tmp", filename) + if is_best: + shutil.copyfile( + filename, + 
os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth"),
+                )
+            if self.save_freq and (self.trainer.epoch + 1) % self.save_freq == 0:
+                shutil.copyfile(
+                    filename,
+                    os.path.join(
+                        self.trainer.cfg.save_path,
+                        "model",
+                        f"epoch_{self.trainer.epoch + 1}.pth",
+                    ),
+                )
+
+
+@HOOKS.register_module()
+class CheckpointLoader(HookBase):
+    def __init__(self, keywords="", replacement=None, strict=False):
+        self.keywords = keywords
+        self.replacement = replacement if replacement is not None else keywords
+        self.strict = strict
+
+    def before_train(self):
+        self.trainer.logger.info("=> Loading checkpoint & weight ...")
+        if self.trainer.cfg.weight and os.path.isfile(self.trainer.cfg.weight):
+            self.trainer.logger.info(f"Loading weight at: {self.trainer.cfg.weight}")
+            checkpoint = torch.load(
+                self.trainer.cfg.weight,
+                map_location=lambda storage, loc: storage.cuda(),
+                weights_only=False,
+            )
+            self.trainer.logger.info(
+                f"Loading layer weights with keyword: {self.keywords}, " f"replace keyword with: {self.replacement}"
+            )
+            weight = OrderedDict()
+            for key, value in checkpoint["state_dict"].items():
+                if not key.startswith("module."):
+                    key = "module." + key  # xxx.xxx -> module.xxx.xxx
+                # Now all keys contain "module." no matter DDP or not.
+                if self.keywords in key:
+                    key = key.replace(self.keywords, self.replacement)
+                if comm.get_world_size() == 1:
+                    key = key[7:]  # module.xxx.xxx -> xxx.xxx
+                weight[key] = value
+            load_state_info = self.trainer.model.load_state_dict(weight, strict=self.strict)
+            self.trainer.logger.info(f"Missing keys: {load_state_info[0]}")
+            if self.trainer.cfg.resume:
+                self.trainer.logger.info(f"Resuming train at eval epoch: {checkpoint['epoch']}")
+                self.trainer.start_epoch = checkpoint["epoch"]
+                self.trainer.best_metric_value = checkpoint["best_metric_value"]
+                self.trainer.optimizer.load_state_dict(checkpoint["optimizer"])
+                self.trainer.scheduler.load_state_dict(checkpoint["scheduler"])
+                if self.trainer.cfg.enable_amp:
+                    self.trainer.scaler.load_state_dict(checkpoint["scaler"])
+        else:
+            self.trainer.logger.info(f"No weight found at: {self.trainer.cfg.weight}")
+
+
+@HOOKS.register_module()
+class PreciseEvaluator(HookBase):
+    def __init__(self, test_last=False):
+        self.test_last = test_last
+
+    def after_train(self):
+        self.trainer.logger.info(">>>>>>>>>>>>>>>> Start Precise Evaluation >>>>>>>>>>>>>>>>")
+        torch.cuda.empty_cache()
+        cfg = self.trainer.cfg
+        tester = TESTERS.build(dict(type=cfg.test.type, cfg=cfg, model=self.trainer.model))
+        if self.test_last:
+            self.trainer.logger.info("=> Testing on model_last ...")
+        else:
+            self.trainer.logger.info("=> Testing on model_best ...")
+            best_path = os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth")
+            # weights_only is a torch.load flag, not a load_state_dict argument
+            checkpoint = torch.load(best_path, weights_only=False)
+            state_dict = checkpoint["state_dict"]
+            tester.model.load_state_dict(state_dict, strict=True)
+        tester.test()
diff --git a/projects/PTv3/engines/launch.py b/projects/PTv3/engines/launch.py
new file mode 100644
index 00000000..df048525
--- /dev/null
+++ b/projects/PTv3/engines/launch.py
@@ -0,0 +1,129 @@
+"""
+Launcher
+
+modified from detectron2 (https://github.com/facebookresearch/detectron2)
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
+Please cite our work if the code is helpful to you.
+""" + +import logging +import os +from datetime import timedelta + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from utils import comm + +__all__ = ["DEFAULT_TIMEOUT", "launch"] + +DEFAULT_TIMEOUT = timedelta(minutes=60) + + +def _find_free_port(): + import socket + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. + return port + + +def launch( + main_func, + num_gpus_per_machine, + num_machines=1, + machine_rank=0, + dist_url=None, + cfg=(), + timeout=DEFAULT_TIMEOUT, +): + """ + Launch multi-gpu or distributed training. + This function must be called on all machines involved in the training. + It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine. + Args: + main_func: a function that will be called by `main_func(*args)` + num_gpus_per_machine (int): number of GPUs per machine + num_machines (int): the total number of machines + machine_rank (int): the rank of this machine + dist_url (str): url to connect to for distributed jobs, including protocol + e.g. "tcp://127.0.0.1:8686". + Can be set to "auto" to automatically select a free port on localhost + timeout (timedelta): timeout of the distributed workers + args (tuple): arguments passed to main_func + """ + world_size = num_machines * num_gpus_per_machine + if world_size > 1: + if dist_url == "auto": + assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs." + port = _find_free_port() + dist_url = f"tcp://127.0.0.1:{port}" + if num_machines > 1 and dist_url.startswith("file://"): + logger = logging.getLogger(__name__) + logger.warning("file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://") + + mp.spawn( + _distributed_worker, + nprocs=num_gpus_per_machine, + args=( + main_func, + world_size, + num_gpus_per_machine, + machine_rank, + dist_url, + cfg, + timeout, + ), + daemon=False, + ) + else: + main_func(*cfg) + + +def _distributed_worker( + local_rank, + main_func, + world_size, + num_gpus_per_machine, + machine_rank, + dist_url, + cfg, + timeout=DEFAULT_TIMEOUT, +): + assert torch.cuda.is_available(), "cuda is not available. Please check your installation." 
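+    # One spawned process owns one GPU: its global rank is its machine's base rank
+    # (machine_rank * num_gpus_per_machine) plus its local rank on that machine.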
+ global_rank = machine_rank * num_gpus_per_machine + local_rank + try: + dist.init_process_group( + backend="NCCL", + init_method=dist_url, + world_size=world_size, + rank=global_rank, + timeout=timeout, + ) + except Exception as e: + logger = logging.getLogger(__name__) + logger.error("Process group URL: {}".format(dist_url)) + raise e + + # Setup the local process group (which contains ranks within the same machine) + assert comm._LOCAL_PROCESS_GROUP is None + num_machines = world_size // num_gpus_per_machine + for i in range(num_machines): + ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)) + pg = dist.new_group(ranks_on_i) + if i == machine_rank: + comm._LOCAL_PROCESS_GROUP = pg + + assert num_gpus_per_machine <= torch.cuda.device_count() + torch.cuda.set_device(local_rank) + + # synchronize is needed here to prevent a possible timeout after calling init_process_group + # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172 + comm.synchronize() + + main_func(*cfg) diff --git a/projects/PTv3/engines/test.py b/projects/PTv3/engines/test.py new file mode 100644 index 00000000..de5e8f8d --- /dev/null +++ b/projects/PTv3/engines/test.py @@ -0,0 +1,283 @@ +""" +Tester + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import os +import time +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.data +import utils.comm as comm +from datasets import build_dataset, collate_fn +from models import build_model +from utils.logger import get_root_logger +from utils.misc import ( + AverageMeter, + intersection_and_union, + make_dirs, +) +from utils.registry import Registry + +from .defaults import create_ddp_model + +TESTERS = Registry("testers") + + +class TesterBase: + def __init__(self, cfg, model=None, test_loader=None, verbose=False) -> None: + torch.multiprocessing.set_sharing_strategy("file_system") + self.logger = get_root_logger( + log_file=os.path.join(cfg.save_path, "test.log"), + file_mode="a" if cfg.resume else "w", + ) + self.logger.info("=> Loading config ...") + self.cfg = cfg + self.verbose = verbose + if self.verbose: + self.logger.info(f"Save path: {cfg.save_path}") + self.logger.info(f"Config:\n{cfg.pretty_text}") + if model is None: + self.logger.info("=> Building model ...") + self.model = self.build_model() + else: + self.model = model + if test_loader is None: + self.logger.info("=> Building test dataset & dataloader ...") + self.test_loader = self.build_test_loader() + else: + self.test_loader = test_loader + + def build_model(self): + model = build_model(self.cfg.model) + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + self.logger.info(f"Num params: {n_parameters}") + model = create_ddp_model( + model.cuda(), + broadcast_buffers=False, + find_unused_parameters=self.cfg.find_unused_parameters, + ) + if os.path.isfile(self.cfg.weight): + self.logger.info(f"Loading weight at: {self.cfg.weight}") + checkpoint = torch.load(self.cfg.weight, weights_only=False) + weight = OrderedDict() + for key, value in checkpoint["state_dict"].items(): + if key.startswith("module."): + if comm.get_world_size() == 1: + key = key[7:] # module.xxx.xxx -> xxx.xxx + else: + if comm.get_world_size() > 1: + key = "module." 
+ key  # xxx.xxx -> module.xxx.xxx
+                weight[key] = value
+            model.load_state_dict(weight, strict=True)
+            self.logger.info("=> Loaded weight '{}' (epoch {})".format(self.cfg.weight, checkpoint["epoch"]))
+        else:
+            raise RuntimeError("=> No checkpoint found at '{}'".format(self.cfg.weight))
+        return model
+
+    def build_test_loader(self):
+        test_dataset = build_dataset(self.cfg.data.test)
+        if comm.get_world_size() > 1:
+            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
+        else:
+            test_sampler = None
+        test_loader = torch.utils.data.DataLoader(
+            test_dataset,
+            batch_size=self.cfg.batch_size_test_per_gpu,
+            shuffle=False,
+            num_workers=self.cfg.batch_size_test_per_gpu,
+            pin_memory=True,
+            sampler=test_sampler,
+            collate_fn=self.__class__.collate_fn,
+        )
+        return test_loader
+
+    def test(self):
+        raise NotImplementedError
+
+    @staticmethod
+    def collate_fn(batch):
+        return collate_fn(batch)
+
+
+@TESTERS.register_module()
+class SemSegTester(TesterBase):
+    def test(self):
+        assert self.test_loader.batch_size == 1
+        logger = get_root_logger()
+        logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
+
+        batch_time = AverageMeter()
+        intersection_meter = AverageMeter()
+        union_meter = AverageMeter()
+        target_meter = AverageMeter()
+        self.model.eval()
+
+        save_path = os.path.join(self.cfg.save_path, "result")
+        make_dirs(save_path)
+        # create submit folder only on main process
+        if self.cfg.data.test.type == "NuScenesDataset" and comm.is_main_process():
+            import json
+
+            make_dirs(os.path.join(save_path, "submit", "lidarseg", "test"))
+            make_dirs(os.path.join(save_path, "submit", "test"))
+            submission = dict(
+                meta=dict(
+                    use_camera=False,
+                    use_lidar=True,
+                    use_radar=False,
+                    use_map=False,
+                    use_external=False,
+                )
+            )
+            with open(os.path.join(save_path, "submit", "test", "submission.json"), "w") as f:
+                json.dump(submission, f, indent=4)
+        comm.synchronize()
+        record = {}
+        # fragment inference
+        for idx, data_dict in enumerate(self.test_loader):
+            end = time.time()
+            data_dict = data_dict[0]  # currently assumes batch size is 1
+            fragment_list = data_dict.pop("fragment_list")
+            segment = data_dict.pop("segment")
+            data_name = data_dict.pop("name")
+            pred_save_path = os.path.join(save_path, "{}_pred.npy".format(data_name))
+            feat_save_path = os.path.join(save_path, "{}_feat.npy".format(data_name))
+            result_save_path = os.path.join(save_path, "{}_{}_pred.npz".format(idx, data_name))
+            if os.path.isfile(pred_save_path):
+                logger.info("{}/{}: {}, loaded pred and label.".format(idx + 1, len(self.test_loader), data_name))
+                pred = np.load(pred_save_path)
+                if "origin_segment" in data_dict.keys():
+                    segment = data_dict["origin_segment"]
+            else:
+                pred = torch.zeros((segment.size, self.cfg.data.num_classes)).cuda()
+                feat = torch.zeros((segment.size, 4)).cuda()
+                for i in range(len(fragment_list)):
+                    fragment_batch_size = 1
+                    s_i, e_i = i * fragment_batch_size, min((i + 1) * fragment_batch_size, len(fragment_list))
+                    input_dict = collate_fn(fragment_list[s_i:e_i])
+                    for key in input_dict.keys():
+                        if isinstance(input_dict[key], torch.Tensor):
+                            input_dict[key] = input_dict[key].cuda(non_blocking=True)
+                    idx_part = input_dict["index"]
+                    with torch.no_grad():
+                        pred_part = self.model(input_dict)["seg_logits"]  # (n, k)
+                        pred_part = F.softmax(pred_part, -1)
+                        if self.cfg.empty_cache:
+                            torch.cuda.empty_cache()
+                        bs = 0
+                        for be in input_dict["offset"]:
+                            pred[idx_part[bs:be], :] += pred_part[bs:be]
+                            feat[idx_part[bs:be], :] = input_dict["feat"][bs:be]
+                            bs = be
+
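+                    # Class scores for every point are accumulated across the
+                    # grid-sampled fragments of this scan; the final label is the
+                    # argmax of the summed softmax scores, i.e. a vote across fragments.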
logger.info( + "Test: {}/{}-{data_name}, Batch: {batch_idx}/{batch_num}".format( + idx + 1, + len(self.test_loader), + data_name=data_name, + batch_idx=i, + batch_num=len(fragment_list), + ) + ) + pred = pred.max(1)[1].data.cpu().numpy() + + if "origin_segment" in data_dict.keys(): + assert "inverse" in data_dict.keys() + pred = pred[data_dict["inverse"]] + feat = feat[data_dict["inverse"]] + segment = data_dict["origin_segment"] + # np.save(pred_save_path, pred) + # np.save(feat_save_path, feat.cpu().numpy()) + np.savez_compressed(result_save_path, pred=pred, feat=feat.cpu().numpy()) + if self.cfg.data.test.type == "NuScenesDataset": + np.array(pred + 1).astype(np.uint8).tofile( + os.path.join( + save_path, + "submit", + "lidarseg", + "test", + "{}_lidarseg.bin".format(data_name), + ) + ) + + intersection, union, target = intersection_and_union( + pred, segment, self.cfg.data.num_classes, self.cfg.data.ignore_index + ) + intersection_meter.update(intersection) + union_meter.update(union) + target_meter.update(target) + record[data_name] = dict(intersection=intersection, union=union, target=target) + + mask = union != 0 + iou_class = intersection / (union + 1e-10) + iou = np.mean(iou_class[mask]) + acc = sum(intersection) / (sum(target) + 1e-10) + + m_iou = np.mean(intersection_meter.sum / (union_meter.sum + 1e-10)) + m_acc = np.mean(intersection_meter.sum / (target_meter.sum + 1e-10)) + + batch_time.update(time.time() - end) + logger.info( + "Test: {} [{}/{}]-{} " + "Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) " + "Accuracy {acc:.4f} ({m_acc:.4f}) " + "mIoU {iou:.4f} ({m_iou:.4f})".format( + data_name, + idx + 1, + len(self.test_loader), + segment.size, + batch_time=batch_time, + acc=acc, + m_acc=m_acc, + iou=iou, + m_iou=m_iou, + ) + ) + + logger.info("Syncing ...") + comm.synchronize() + record_sync = comm.gather(record, dst=0) + + if comm.is_main_process(): + record = {} + for _ in range(len(record_sync)): + r = record_sync.pop() + record.update(r) + del r + intersection = np.sum([meters["intersection"] for _, meters in record.items()], axis=0) + union = np.sum([meters["union"] for _, meters in record.items()], axis=0) + target = np.sum([meters["target"] for _, meters in record.items()], axis=0) + + if self.cfg.data.test.type == "S3DISDataset": + torch.save( + dict(intersection=intersection, union=union, target=target), + os.path.join(save_path, f"{self.test_loader.dataset.split}.pth"), + ) + + iou_class = intersection / (union + 1e-10) + accuracy_class = intersection / (target + 1e-10) + mIoU = np.mean(iou_class) + mAcc = np.mean(accuracy_class) + allAcc = sum(intersection) / (sum(target) + 1e-10) + + logger.info("Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}".format(mIoU, mAcc, allAcc)) + for i in range(self.cfg.data.num_classes): + logger.info( + "Class_{idx} - {name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format( + idx=i, + name=self.cfg.data.names[i], + iou=iou_class[i], + accuracy=accuracy_class[i], + ) + ) + logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<") + + @staticmethod + def collate_fn(batch): + return batch diff --git a/projects/PTv3/engines/train.py b/projects/PTv3/engines/train.py new file mode 100644 index 00000000..d8cd5e5c --- /dev/null +++ b/projects/PTv3/engines/train.py @@ -0,0 +1,315 @@ +""" +Trainer + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. 
+""" + +import os +import sys +import weakref +from functools import partial + +import torch +import torch.nn as nn +import torch.utils.data + +if sys.version_info >= (3, 10): + from collections.abc import Iterator +else: + from collections import Iterator + +import utils.comm as comm +from datasets import build_dataset, collate_fn, point_collate_fn +from models import build_model +from tensorboardX import SummaryWriter +from utils.events import EventStorage, ExceptionWriter +from utils.logger import get_root_logger +from utils.optimizer import build_optimizer +from utils.registry import Registry +from utils.scheduler import build_scheduler + +from .defaults import create_ddp_model, worker_init_fn +from .hooks import HookBase, build_hooks + +TRAINERS = Registry("trainers") + + +class TrainerBase: + def __init__(self) -> None: + self.hooks = [] + self.epoch = 0 + self.start_epoch = 0 + self.max_epoch = 0 + self.max_iter = 0 + self.comm_info = dict() + self.data_iterator: Iterator = enumerate([]) + self.storage: EventStorage + self.writer: SummaryWriter + + def register_hooks(self, hooks) -> None: + hooks = build_hooks(hooks) + for h in hooks: + assert isinstance(h, HookBase) + # To avoid circular reference, hooks and trainer cannot own each other. + # This normally does not matter, but will cause memory leak if the + # involved objects contain __del__: + # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/ + h.trainer = weakref.proxy(self) + self.hooks.extend(hooks) + + def train(self): + with EventStorage() as self.storage: + # => before train + self.before_train() + for self.epoch in range(self.start_epoch, self.max_epoch): + # => before epoch + self.before_epoch() + # => run_epoch + for ( + self.comm_info["iter"], + self.comm_info["input_dict"], + ) in self.data_iterator: + # => before_step + self.before_step() + # => run_step + self.run_step() + # => after_step + self.after_step() + # => after epoch + self.after_epoch() + # => after train + self.after_train() + + def before_train(self): + for h in self.hooks: + h.before_train() + + def before_epoch(self): + for h in self.hooks: + h.before_epoch() + + def before_step(self): + for h in self.hooks: + h.before_step() + + def run_step(self): + raise NotImplementedError + + def after_step(self): + for h in self.hooks: + h.after_step() + + def after_epoch(self): + for h in self.hooks: + h.after_epoch() + self.storage.reset_histories() + + def after_train(self): + # Sync GPU before running train hooks + comm.synchronize() + for h in self.hooks: + h.after_train() + if comm.is_main_process(): + self.writer.close() + + +@TRAINERS.register_module("DefaultTrainer") +class Trainer(TrainerBase): + def __init__(self, cfg): + super(Trainer, self).__init__() + self.epoch = 0 + self.start_epoch = 0 + self.max_epoch = cfg.eval_epoch + self.best_metric_value = -torch.inf + self.logger = get_root_logger( + log_file=os.path.join(cfg.save_path, "train.log"), + file_mode="a" if cfg.resume else "w", + ) + self.logger.info("=> Loading config ...") + self.cfg = cfg + self.logger.info(f"Save path: {cfg.save_path}") + self.logger.info(f"Config:\n{cfg.pretty_text}") + self.logger.info("=> Building model ...") + self.model = self.build_model() + self.logger.info("=> Building writer ...") + self.writer = self.build_writer() + self.logger.info("=> Building train dataset & dataloader ...") + self.train_loader = self.build_train_loader() + self.logger.info("=> Building val dataset & dataloader ...") + self.val_loader = 
self.build_val_loader() + self.logger.info("=> Building optimize, scheduler, scaler(amp) ...") + self.optimizer = self.build_optimizer() + self.scheduler = self.build_scheduler() + self.scaler = self.build_scaler() + self.logger.info("=> Building hooks ...") + self.register_hooks(self.cfg.hooks) + + def train(self): + with EventStorage() as self.storage, ExceptionWriter(): + # => before train + self.before_train() + self.logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>") + for self.epoch in range(self.start_epoch, self.max_epoch): + # => before epoch + # TODO: optimize to iteration based + if comm.get_world_size() > 1: + self.train_loader.sampler.set_epoch(self.epoch) + self.model.train() + self.data_iterator = enumerate(self.train_loader) + self.before_epoch() + # => run_epoch + for ( + self.comm_info["iter"], + self.comm_info["input_dict"], + ) in self.data_iterator: + # => before_step + self.before_step() + # => run_step + self.run_step() + # => after_step + self.after_step() + # => after epoch + self.after_epoch() + # => after train + self.after_train() + + def run_step(self): + input_dict = self.comm_info["input_dict"] + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + with torch.cuda.amp.autocast(enabled=self.cfg.enable_amp): + output_dict = self.model(input_dict) + loss = output_dict["loss"] + self.optimizer.zero_grad() + if self.cfg.enable_amp: + self.scaler.scale(loss).backward() + self.scaler.unscale_(self.optimizer) + if self.cfg.clip_grad is not None: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.clip_grad) + self.scaler.step(self.optimizer) + + # When enable amp, optimizer.step call are skipped if the loss scaling factor is too large. + # Fix torch warning scheduler step before optimizer step. 
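+            # GradScaler.update() lowers the scale when inf/NaN gradients caused
+            # optimizer.step() to be skipped; if the scale did not shrink, the step
+            # actually ran and the scheduler may advance with it.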
+ scaler = self.scaler.get_scale() + self.scaler.update() + if scaler <= self.scaler.get_scale(): + self.scheduler.step() + else: + loss.backward() + if self.cfg.clip_grad is not None: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.clip_grad) + self.optimizer.step() + self.scheduler.step() + if self.cfg.empty_cache: + torch.cuda.empty_cache() + self.comm_info["model_output_dict"] = output_dict + + def after_epoch(self): + for h in self.hooks: + h.after_epoch() + self.storage.reset_histories() + if self.cfg.empty_cache_per_epoch: + torch.cuda.empty_cache() + + def build_model(self): + model = build_model(self.cfg.model) + if self.cfg.sync_bn: + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + # logger.info(f"Model: \n{self.model}") + self.logger.info(f"Num params: {n_parameters}") + model = create_ddp_model( + model.cuda(), + broadcast_buffers=False, + find_unused_parameters=self.cfg.find_unused_parameters, + ) + return model + + def build_writer(self): + writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None + self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}") + return writer + + def build_train_loader(self): + train_data = build_dataset(self.cfg.data.train) + + if comm.get_world_size() > 1: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_data) + else: + train_sampler = None + + init_fn = ( + partial( + worker_init_fn, + num_workers=self.cfg.num_worker_per_gpu, + rank=comm.get_rank(), + seed=self.cfg.seed, + ) + if self.cfg.seed is not None + else None + ) + + train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=self.cfg.batch_size_per_gpu, + shuffle=(train_sampler is None), + num_workers=self.cfg.num_worker_per_gpu, + sampler=train_sampler, + collate_fn=partial(point_collate_fn, mix_prob=self.cfg.mix_prob), + pin_memory=True, + worker_init_fn=init_fn, + drop_last=True, + persistent_workers=True, + ) + return train_loader + + def build_val_loader(self): + val_loader = None + if self.cfg.evaluate: + val_data = build_dataset(self.cfg.data.val) + if comm.get_world_size() > 1: + val_sampler = torch.utils.data.distributed.DistributedSampler(val_data) + else: + val_sampler = None + val_loader = torch.utils.data.DataLoader( + val_data, + batch_size=self.cfg.batch_size_val_per_gpu, + shuffle=False, + num_workers=self.cfg.num_worker_per_gpu, + pin_memory=True, + sampler=val_sampler, + collate_fn=collate_fn, + ) + return val_loader + + def build_optimizer(self): + return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts) + + def build_scheduler(self): + assert hasattr(self, "optimizer") + assert hasattr(self, "train_loader") + self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch + return build_scheduler(self.cfg.scheduler, self.optimizer) + + def build_scaler(self): + scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None + return scaler + + +@TRAINERS.register_module("MultiDatasetTrainer") +class MultiDatasetTrainer(Trainer): + def build_train_loader(self): + from pointcept.datasets import MultiDatasetDataloader + + train_data = build_dataset(self.cfg.data.train) + train_loader = MultiDatasetDataloader( + train_data, + self.cfg.batch_size_per_gpu, + self.cfg.num_worker_per_gpu, + self.cfg.mix_prob, + self.cfg.seed, + ) + self.comm_info["iter_per_epoch"] = len(train_loader) + return train_loader diff --git a/projects/PTv3/models/__init__.py 
b/projects/PTv3/models/__init__.py new file mode 100644 index 00000000..6dd1f71b --- /dev/null +++ b/projects/PTv3/models/__init__.py @@ -0,0 +1,8 @@ +from .builder import build_model +from .default import * + +# Pretraining +from .point_prompt_training import * + +# Backbones +from .point_transformer_v3 import * diff --git a/projects/PTv3/models/builder.py b/projects/PTv3/models/builder.py new file mode 100644 index 00000000..da6a97f7 --- /dev/null +++ b/projects/PTv3/models/builder.py @@ -0,0 +1,16 @@ +""" +Model Builder + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +from utils.registry import Registry + +MODELS = Registry("models") +MODULES = Registry("modules") + + +def build_model(cfg): + """Build models.""" + return MODELS.build(cfg) diff --git a/projects/PTv3/models/default.py b/projects/PTv3/models/default.py new file mode 100644 index 00000000..8494f488 --- /dev/null +++ b/projects/PTv3/models/default.py @@ -0,0 +1,42 @@ +import torch.nn as nn +from models.losses import build_criteria +from models.utils.structure import Point + +from .builder import MODELS, build_model + + +@MODELS.register_module() +class DefaultSegmentorV2(nn.Module): + def __init__( + self, + num_classes, + backbone_out_channels, + backbone=None, + criteria=None, + ): + super().__init__() + self.seg_head = nn.Linear(backbone_out_channels, num_classes) if num_classes > 0 else nn.Identity() + self.backbone = build_model(backbone) + self.criteria = build_criteria(criteria) + + def forward(self, input_dict): + point = Point(input_dict) + point = self.backbone(point) + # Backbone added after v1.5.0 return Point instead of feat and use DefaultSegmentorV2 + # TODO: remove this part after make all backbone return Point only. + if isinstance(point, Point): + feat = point.feat + else: + feat = point + seg_logits = self.seg_head(feat) + # train + if self.training: + loss = self.criteria(seg_logits, input_dict["segment"]) + return dict(loss=loss) + # eval + elif "segment" in input_dict.keys(): + loss = self.criteria(seg_logits, input_dict["segment"]) + return dict(loss=loss, seg_logits=seg_logits) + # test + else: + return dict(seg_logits=seg_logits) diff --git a/projects/PTv3/models/losses/__init__.py b/projects/PTv3/models/losses/__init__.py new file mode 100644 index 00000000..f3c252c3 --- /dev/null +++ b/projects/PTv3/models/losses/__init__.py @@ -0,0 +1,3 @@ +from .builder import build_criteria +from .lovasz import LovaszLoss +from .misc import CrossEntropyLoss diff --git a/projects/PTv3/models/losses/builder.py b/projects/PTv3/models/losses/builder.py new file mode 100644 index 00000000..b9041763 --- /dev/null +++ b/projects/PTv3/models/losses/builder.py @@ -0,0 +1,31 @@ +""" +Criteria Builder + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. 
+""" + +from utils.registry import Registry + +LOSSES = Registry("losses") + + +class Criteria(object): + def __init__(self, cfg=None): + self.cfg = cfg if cfg is not None else [] + self.criteria = [] + for loss_cfg in self.cfg: + self.criteria.append(LOSSES.build(cfg=loss_cfg)) + + def __call__(self, pred, target): + if len(self.criteria) == 0: + # loss computation occur in model + return pred + loss = 0 + for c in self.criteria: + loss += c(pred, target) + return loss + + +def build_criteria(cfg): + return Criteria(cfg) diff --git a/projects/PTv3/models/losses/lovasz.py b/projects/PTv3/models/losses/lovasz.py new file mode 100644 index 00000000..9595ef97 --- /dev/null +++ b/projects/PTv3/models/losses/lovasz.py @@ -0,0 +1,189 @@ +""" +Lovasz Loss +refer https://arxiv.org/abs/1705.08790 + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +from itertools import filterfalse +from typing import Optional + +import torch +from torch.nn.modules.loss import _Loss + +from .builder import LOSSES + +BINARY_MODE: str = "binary" +MULTICLASS_MODE: str = "multiclass" +MULTILABEL_MODE: str = "multilabel" + + +def _lovasz_grad(gt_sorted): + """Compute gradient of the Lovasz extension w.r.t sorted errors + See Alg. 1 in paper + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1.0 - intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def _lovasz_softmax(probas, labels, classes="present", class_seen=None, per_image=False, ignore=None): + """Multi-class Lovasz-Softmax loss + Args: + @param probas: [B, C, H, W] Class probabilities at each prediction (between 0 and 1). + Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. + @param labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) + @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. + @param per_image: compute the loss per image instead of per batch + @param ignore: void class labels + """ + if per_image: + loss = mean( + _lovasz_softmax_flat(*_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) + for prob, lab in zip(probas, labels) + ) + else: + loss = _lovasz_softmax_flat(*_flatten_probas(probas, labels, ignore), classes=classes, class_seen=class_seen) + return loss + + +def _lovasz_softmax_flat(probas, labels, classes="present", class_seen=None): + """Multi-class Lovasz-Softmax loss + Args: + @param probas: [P, C] Class probabilities at each prediction (between 0 and 1) + @param labels: [P] Tensor, ground truth labels (between 0 and C - 1) + @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 
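+    @param class_seen: optional collection of class indices; when given, only these classes contribute to the loss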
+ """ + if probas.numel() == 0: + # only void pixels, the gradients should be 0 + return probas * 0.0 + C = probas.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ["all", "present"] else classes + # for c in class_to_sum: + for c in labels.unique(): + if class_seen is None: + fg = (labels == c).type_as(probas) # foreground for class c + if classes == "present" and fg.sum() == 0: + continue + if C == 1: + if len(classes) > 1: + raise ValueError("Sigmoid output possible only with 1 class") + class_pred = probas[:, 0] + else: + class_pred = probas[:, c] + errors = (fg - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted))) + else: + if c in class_seen: + fg = (labels == c).type_as(probas) # foreground for class c + if classes == "present" and fg.sum() == 0: + continue + if C == 1: + if len(classes) > 1: + raise ValueError("Sigmoid output possible only with 1 class") + class_pred = probas[:, 0] + else: + class_pred = probas[:, c] + errors = (fg - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted))) + return mean(losses) + + +def _flatten_probas(probas, labels, ignore=None): + """Flattens predictions in the batch""" + if probas.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probas.size() + probas = probas.view(B, 1, H, W) + + C = probas.size(1) + probas = torch.movedim(probas, 1, -1) # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C] + probas = probas.contiguous().view(-1, C) # [P, C] + + labels = labels.view(-1) + if ignore is None: + return probas, labels + valid = labels != ignore + vprobas = probas[valid] + vlabels = labels[valid] + return vprobas, vlabels + + +def mean(values, ignore_nan=False, empty=0): + """Nan-mean compatible with generators.""" + values = iter(values) + if ignore_nan: + values = filterfalse(isnan, values) + try: + n = 1 + acc = next(values) + except StopIteration: + if empty == "raise": + raise ValueError("Empty mean") + return empty + for n, v in enumerate(values, 2): + acc += v + if n == 1: + return acc + return acc / n + + +@LOSSES.register_module() +class LovaszLoss(_Loss): + def __init__( + self, + mode: str, + class_seen: Optional[int] = None, + per_image: bool = False, + ignore_index: Optional[int] = None, + loss_weight: float = 1.0, + ): + """Lovasz loss for segmentation task. 
+        The reference implementation supports binary, multiclass and multilabel cases;
+        this trimmed port only implements the multiclass (softmax) path.
+        Args:
+            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
+            ignore_index: Label that indicates ignored pixels (does not contribute to loss)
+            per_image: If True, the loss is computed per image and then averaged; otherwise it is computed over the whole batch
+        Shape
+            - **y_pred** - torch.Tensor of shape (N, C, H, W)
+            - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)
+        Reference
+            https://github.com/BloodAxe/pytorch-toolbelt
+        """
+        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
+        super().__init__()
+
+        self.mode = mode
+        self.ignore_index = ignore_index
+        self.per_image = per_image
+        self.class_seen = class_seen
+        self.loss_weight = loss_weight
+
+    def forward(self, y_pred, y_true):
+        if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
+            # NOTE: the Lovasz-hinge helper (_lovasz_hinge) was trimmed from this port,
+            # so the binary/multilabel modes are unavailable; see the reference
+            # implementation (pytorch-toolbelt) if they are needed.
+            raise NotImplementedError("Lovasz hinge (binary/multilabel) is not available in this trimmed port.")
+        elif self.mode == MULTICLASS_MODE:
+            y_pred = y_pred.softmax(dim=1)
+            loss = _lovasz_softmax(
+                y_pred,
+                y_true,
+                class_seen=self.class_seen,
+                per_image=self.per_image,
+                ignore=self.ignore_index,
+            )
+        else:
+            raise ValueError("Wrong mode {}.".format(self.mode))
+        return loss * self.loss_weight
diff --git a/projects/PTv3/models/losses/misc.py b/projects/PTv3/models/losses/misc.py
new file mode 100644
index 00000000..c60a9802
--- /dev/null
+++ b/projects/PTv3/models/losses/misc.py
@@ -0,0 +1,39 @@
+"""
+Misc Losses
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
+Please cite our work if the code is helpful to you.
+"""
+
+import torch
+import torch.nn as nn
+
+from .builder import LOSSES
+
+
+@LOSSES.register_module()
+class CrossEntropyLoss(nn.Module):
+    def __init__(
+        self,
+        weight=None,
+        size_average=None,
+        reduce=None,
+        reduction="mean",
+        label_smoothing=0.0,
+        loss_weight=1.0,
+        ignore_index=-1,
+    ):
+        super(CrossEntropyLoss, self).__init__()
+        weight = torch.tensor(weight).cuda() if weight is not None else None
+        self.loss_weight = loss_weight
+        self.loss = nn.CrossEntropyLoss(
+            weight=weight,
+            size_average=size_average,
+            ignore_index=ignore_index,
+            reduce=reduce,
+            reduction=reduction,
+            label_smoothing=label_smoothing,
+        )
+
+    def forward(self, pred, target):
+        return self.loss(pred, target) * self.loss_weight
diff --git a/projects/PTv3/models/modules.py b/projects/PTv3/models/modules.py
new file mode 100644
index 00000000..851376e9
--- /dev/null
+++ b/projects/PTv3/models/modules.py
@@ -0,0 +1,82 @@
+import sys
+from collections import OrderedDict
+
+import spconv.pytorch as spconv
+import torch.nn as nn
+from models.utils.structure import Point
+
+
+class PointModule(nn.Module):
+    r"""PointModule
+    Placeholder; every module subclassing this takes a Point when used in PointSequential.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+class PointSequential(PointModule):
+    r"""A sequential container.
+    Modules will be added to it in the order they are passed in the constructor.
+    Alternatively, an ordered dict of modules can also be passed in.
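+
+    Example (illustrative only)::
+
+        seq = PointSequential(nn.Linear(8, 8), nn.GELU())
+        point = seq(point)  # forward() dispatches on Point / spconv / plain nn modules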
+ """ + + def __init__(self, *args, **kwargs): + super().__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for idx, module in enumerate(args): + self.add_module(str(idx), module) + for name, module in kwargs.items(): + if sys.version_info < (3, 6): + raise ValueError("kwargs only supported in py36+") + if name in self._modules: + raise ValueError("name exists.") + self.add_module(name, module) + + def __getitem__(self, idx): + if not (-len(self) <= idx < len(self)): + raise IndexError("index {} is out of range".format(idx)) + if idx < 0: + idx += len(self) + it = iter(self._modules.values()) + for i in range(idx): + next(it) + return next(it) + + def __len__(self): + return len(self._modules) + + def add(self, module, name=None): + if name is None: + name = str(len(self._modules)) + if name in self._modules: + raise KeyError("name exists") + self.add_module(name, module) + + def forward(self, input): + for k, module in self._modules.items(): + # Point module + if isinstance(module, PointModule): + input = module(input) + # Spconv module + elif spconv.modules.is_spconv_module(module): + if isinstance(input, Point): + input.sparse_conv_feat = module(input.sparse_conv_feat) + input.feat = input.sparse_conv_feat.features + else: + input = module(input) + # PyTorch module + else: + if isinstance(input, Point): + input.feat = module(input.feat) + if "sparse_conv_feat" in input.keys(): + input.sparse_conv_feat = input.sparse_conv_feat.replace_feature(input.feat) + elif isinstance(input, spconv.SparseConvTensor): + if input.indices.shape[0] != 0: + input = input.replace_feature(module(input.features)) + else: + input = module(input) + return input diff --git a/projects/PTv3/models/point_prompt_training/__init__.py b/projects/PTv3/models/point_prompt_training/__init__.py new file mode 100644 index 00000000..d1f24f30 --- /dev/null +++ b/projects/PTv3/models/point_prompt_training/__init__.py @@ -0,0 +1 @@ +from .prompt_driven_normalization import PDNorm diff --git a/projects/PTv3/models/point_prompt_training/prompt_driven_normalization.py b/projects/PTv3/models/point_prompt_training/prompt_driven_normalization.py new file mode 100644 index 00000000..72dad0d5 --- /dev/null +++ b/projects/PTv3/models/point_prompt_training/prompt_driven_normalization.py @@ -0,0 +1,44 @@ +import torch.nn as nn +from models.builder import MODULES +from models.modules import PointModule + + +@MODULES.register_module() +class PDNorm(PointModule): + def __init__( + self, + num_features, + norm_layer, + context_channels=256, + conditions=("ScanNet", "S3DIS", "Structured3D"), + decouple=True, + adaptive=False, + ): + super().__init__() + self.conditions = conditions + self.decouple = decouple + self.adaptive = adaptive + if self.decouple: + self.norm = nn.ModuleList([norm_layer(num_features) for _ in conditions]) + else: + self.norm = norm_layer + if self.adaptive: + self.modulation = nn.Sequential(nn.SiLU(), nn.Linear(context_channels, 2 * num_features, bias=True)) + + def forward(self, point): + assert {"feat", "condition"}.issubset(point.keys()) + if isinstance(point.condition, str): + condition = point.condition + else: + condition = point.condition[0] + if self.decouple: + assert condition in self.conditions + norm = self.norm[self.conditions.index(condition)] + else: + norm = self.norm + point.feat = norm(point.feat) + if self.adaptive: + assert "context" in point.keys() + shift, scale = 
self.modulation(point.context).chunk(2, dim=1) + point.feat = point.feat * (1.0 + scale) + shift + return point diff --git a/projects/PTv3/models/point_transformer_v3/__init__.py b/projects/PTv3/models/point_transformer_v3/__init__.py new file mode 100644 index 00000000..5fe25f32 --- /dev/null +++ b/projects/PTv3/models/point_transformer_v3/__init__.py @@ -0,0 +1 @@ +from .point_transformer_v3m1_base import * diff --git a/projects/PTv3/models/point_transformer_v3/point_transformer_v3m1_base.py b/projects/PTv3/models/point_transformer_v3/point_transformer_v3m1_base.py new file mode 100644 index 00000000..df7fe447 --- /dev/null +++ b/projects/PTv3/models/point_transformer_v3/point_transformer_v3m1_base.py @@ -0,0 +1,822 @@ +""" +Point Transformer - V3 Mode1 + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import math +from functools import partial + +import torch +import torch.nn as nn +import torch_scatter +from addict import Dict + +try: + import flash_attn +except ImportError: + flash_attn = None + +from models.builder import MODELS +from models.modules import PointModule, PointSequential +from models.point_prompt_training import PDNorm +from models.scatter.functional import argsort, segment_csr, unique +from models.utils.misc import offset2bincount +from models.utils.structure import Point + +# NOTE(knzo25): hack to use exportable spconv when available + +try: + + from SparseConvolution.sparse_conv import SubMConv3d + + print("Using spconv2.0 with export support") + +except ImportError: + + from spconv.pytorch import SubMConv3d + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + Copied from timm https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py. + """ + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def drop_path(self, x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. 
+ + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + def forward(self, x): + return self.drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" + + +class RPE(torch.nn.Module): + def __init__(self, patch_size, num_heads): + super().__init__() + self.patch_size = patch_size + self.num_heads = num_heads + self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2) + self.rpe_num = 2 * self.pos_bnd + 1 + self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads)) + torch.nn.init.trunc_normal_(self.rpe_table, std=0.02) + + def forward(self, coord): + idx = ( + coord.clamp(-self.pos_bnd, self.pos_bnd) # clamp into bnd + + self.pos_bnd # relative position to positive index + + torch.arange(3, device=coord.device) * self.rpe_num # x, y, z stride + ) + out = self.rpe_table.index_select(0, idx.reshape(-1)) + out = out.view(idx.shape + (-1,)).sum(3) + out = out.permute(0, 3, 1, 2) # (N, K, K, H) -> (N, H, K, K) + return out + + +class SerializedAttention(PointModule): + def __init__( + self, + channels, + num_heads, + patch_size, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + order_index=0, + enable_rpe=False, + enable_flash=True, + upcast_attention=True, + upcast_softmax=True, + export_mode=False, + ): + super().__init__() + assert channels % num_heads == 0 + self.channels = channels + self.num_heads = num_heads + self.scale = qk_scale or (channels // num_heads) ** -0.5 + self.order_index = order_index + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.enable_rpe = enable_rpe + self.enable_flash = enable_flash and not export_mode + self.export_mode = export_mode + if self.enable_flash: + assert enable_rpe is False, "Set enable_rpe to False when enable Flash Attention" + assert upcast_attention is False, "Set upcast_attention to False when enable Flash Attention" + assert upcast_softmax is False, "Set upcast_softmax to False when enable Flash Attention" + assert flash_attn is not None, "Make sure flash_attn is installed." 
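+            # The flash path consumes the packed varlen qkv layout directly, which is
+            # why RPE and fp32 upcasting are rejected by the asserts above.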
+ self.patch_size = patch_size + self.attn_drop = attn_drop + else: + # when disable flash attention, we still don't want to use mask + # consequently, patch size will auto set to the + # min number of patch_size_max and number of points + self.patch_size_max = patch_size + self.patch_size = 0 + self.attn_drop = torch.nn.Dropout(attn_drop) + + self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias) + self.proj = torch.nn.Linear(channels, channels) + self.proj_drop = torch.nn.Dropout(proj_drop) + self.softmax = torch.nn.Softmax(dim=-1) + self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None + + @torch.no_grad() + def get_rel_pos(self, point, order): + K = self.patch_size + rel_pos_key = f"rel_pos_{self.order_index}" + if rel_pos_key not in point.keys(): + grid_coord = point.grid_coord[order] + grid_coord = grid_coord.reshape(-1, K, 3) + point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1) + return point[rel_pos_key] + + @torch.no_grad() + def get_padding_and_inverse(self, point): + pad_key = "pad" + unpad_key = "unpad" + cu_seqlens_key = "cu_seqlens_key" + if pad_key not in point.keys() or unpad_key not in point.keys() or cu_seqlens_key not in point.keys(): + offset = point.offset + bincount = offset2bincount(offset) + bincount_pad = ( + torch.maximum( + torch.div( + bincount + self.patch_size - 1, + self.patch_size, + rounding_mode="trunc", + ), + torch.tensor(1, device=bincount.device), + ) + * self.patch_size + ) + # only pad point when num of points larger than patch_size + mask_pad = bincount > self.patch_size + bincount_pad = (1 - mask_pad.int()) * bincount + mask_pad.int() * bincount_pad + + if not self.export_mode: + _offset = nn.functional.pad(offset, (1, 0)) + _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0)) + + pad = torch.arange(_offset_pad[-1], device=offset.device) + unpad = torch.arange(_offset[-1], device=offset.device) + cu_seqlens = [] + for i in range(len(offset)): + unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i] + if bincount[i] != bincount_pad[i]: + pad[ + _offset_pad[i + 1] - self.patch_size + (bincount[i] % self.patch_size) : _offset_pad[i + 1] + ] = pad[ + _offset_pad[i + 1] + - 2 * self.patch_size + + (bincount[i] % self.patch_size) : _offset_pad[i + 1] + - self.patch_size + ] + pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i] + cu_seqlens.append( + torch.arange( + _offset_pad[i], + _offset_pad[i + 1], + step=self.patch_size, + dtype=torch.int32, + device=offset.device, + ) + ) + point[pad_key] = pad + point[unpad_key] = unpad + point[cu_seqlens_key] = nn.functional.pad(torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]) + else: + # NOTE(knzo25): needed due to tensorrt reasons + assert len(offset) == 1 + + # pad_orig = pad + # unpad_orig = unpad + + pad = torch.arange(bincount_pad[0], device=offset.device) + unpad = torch.arange(offset[0], device=offset.device) + cu_seqlens = [] + + pad[bincount_pad[0] - self.patch_size + (bincount[0] % self.patch_size) : bincount_pad[0]] = pad[ + bincount_pad[0] + - 2 * self.patch_size + + (bincount[0] % self.patch_size) : bincount_pad[0] + - self.patch_size + ] + + cu_seqlens.append( + torch.arange( + 0, + bincount_pad[0], + step=self.patch_size, + dtype=torch.int32, + device=offset.device, + ) + ) + + point[pad_key] = pad + point[unpad_key] = unpad + point[cu_seqlens_key] = nn.functional.pad(torch.concat(cu_seqlens), (0, 1), value=bincount_pad[0]) + + return point[pad_key], point[unpad_key], point[cu_seqlens_key] + + 
def forward(self, point): + if not self.enable_flash: + assert offset2bincount(point.offset).min() >= self.patch_size_max # NOTE(knzo25): assumed for deployment + self.patch_size = self.patch_size_max + + H = self.num_heads + K = self.patch_size + C = self.channels + + pad, unpad, cu_seqlens = self.get_padding_and_inverse(point) + + order = point.serialized_order[self.order_index][pad] + inverse = unpad[point.serialized_inverse[self.order_index]] + + # padding and reshape feat and batch for serialized point patch + qkv = self.qkv(point.feat)[order] + + if not self.enable_flash: + # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C') + q, k, v = qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0) + # attn + if self.upcast_attention: + q = q.float() + k = k.float() + attn = (q * self.scale) @ k.transpose(-2, -1) # (N', H, K, K) + if self.enable_rpe: + attn = attn + self.rpe(self.get_rel_pos(point, order)) + if self.upcast_softmax: + attn = attn.float() + attn = self.softmax(attn) + attn = self.attn_drop(attn).to(qkv.dtype) + feat = (attn @ v).transpose(1, 2).reshape(-1, C) + else: + feat = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv.half().reshape(-1, 3, H, C // H), + cu_seqlens, + max_seqlen=self.patch_size, + dropout_p=self.attn_drop if self.training else 0, + softmax_scale=self.scale, + ).reshape(-1, C) + feat = feat.to(qkv.dtype) + feat = feat[inverse] + + # ffn + feat = self.proj(feat) + feat = self.proj_drop(feat) + point.feat = feat + return point + + +class MLP(nn.Module): + def __init__( + self, + in_channels, + hidden_channels=None, + out_channels=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_channels = out_channels or in_channels + hidden_channels = hidden_channels or in_channels + self.fc1 = nn.Linear(in_channels, hidden_channels) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_channels, out_channels) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Block(PointModule): + def __init__( + self, + channels, + num_heads, + patch_size=48, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + pre_norm=True, + order_index=0, + cpe_indice_key=None, + enable_rpe=False, + enable_flash=True, + upcast_attention=True, + upcast_softmax=True, + export_mode=False, + ): + super().__init__() + self.channels = channels + self.pre_norm = pre_norm + self.export_mode = export_mode + + self.cpe = PointSequential( + SubMConv3d( + channels, + channels, + kernel_size=3, + bias=True, + indice_key=cpe_indice_key, + ), + nn.Linear(channels, channels), + norm_layer(channels), + ) + + self.norm1 = PointSequential(norm_layer(channels)) + self.attn = SerializedAttention( + channels=channels, + patch_size=patch_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + order_index=order_index, + enable_rpe=enable_rpe, + enable_flash=enable_flash, + upcast_attention=upcast_attention, + upcast_softmax=upcast_softmax, + export_mode=self.export_mode, + ) + self.norm2 = PointSequential(norm_layer(channels)) + self.mlp = PointSequential( + MLP( + in_channels=channels, + hidden_channels=int(channels * mlp_ratio), + out_channels=channels, + act_layer=act_layer, + drop=proj_drop, + ) + ) + self.drop_path = PointSequential(DropPath(drop_path) if drop_path > 0.0 else nn.Identity()) 
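
To make the patch-wise reshape in `SerializedAttention.forward` above concrete, here is a standalone toy sketch (illustrative sizes only, not part of the module): after gathering by the serialized order, attention is computed independently within each patch of `K` consecutive points.

```python
import torch

N, H, C, K = 8, 2, 16, 4       # padded point count, heads, channels, patch size
qkv = torch.randn(N, 3 * C)    # fused qkv projection, already gathered by `order`
# (N, 3C) -> (patches, K, 3, H, C/H) -> (3, patches, H, K, C/H)
q, k, v = qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
attn = (q * (C // H) ** -0.5) @ k.transpose(-2, -1)  # (patches, H, K, K): scores never cross patches
feat = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(-1, C)
print(feat.shape)  # torch.Size([8, 16]): one output feature per padded point
```
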
+ + def forward(self, point: Point): + shortcut = point.feat + point = self.cpe(point) + point.feat = shortcut + point.feat + shortcut = point.feat + if self.pre_norm: + point = self.norm1(point) + point = self.drop_path(self.attn(point)) + point.feat = shortcut + point.feat + if not self.pre_norm: + point = self.norm1(point) + + shortcut = point.feat + if self.pre_norm: + point = self.norm2(point) + point = self.drop_path(self.mlp(point)) + point.feat = shortcut + point.feat + if not self.pre_norm: + point = self.norm2(point) + point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat) + return point + + +class SerializedPooling(PointModule): + + def __init__( + self, + in_channels, + out_channels, + stride=2, + norm_layer=None, + act_layer=None, + reduce="max", + shuffle_orders=True, + traceable=True, # record parent and cluster + export_mode=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.export_mode = export_mode + + assert stride == 2 ** (math.ceil(stride) - 1).bit_length() # 2, 4, 8 + # TODO: add support to grid pool (any stride) + self.stride = stride + assert reduce in ["sum", "mean", "min", "max"] + self.reduce = reduce + self.shuffle_orders = shuffle_orders + self.traceable = traceable + + self.proj = nn.Linear(in_channels, out_channels) + if norm_layer is not None: + self.norm = PointSequential(norm_layer(out_channels)) + if act_layer is not None: + self.act = PointSequential(act_layer()) + + def forward(self, point: Point): + pooling_depth = (math.ceil(self.stride) - 1).bit_length() + if pooling_depth > point.serialized_depth: + pooling_depth = 0 + assert { + "serialized_code", + "serialized_order", + "serialized_inverse", + "serialized_depth", + "sparse_shape", + }.issubset(point.keys()), "Run point.serialization() point cloud before SerializedPooling" + + sparse_shape = point.sparse_shape + + code = point.serialized_code >> pooling_depth * 3 + + if not self.export_mode: + code_, cluster, counts = torch.unique( + code[0], + sorted=True, + return_inverse=True, + return_counts=True, + ) + + _, indices = torch.sort(cluster) + else: + code_, cluster, counts, num_unique = unique(code[0]) + indices = argsort(cluster) + + # indices of point sorted by cluster, for torch_scatter.segment_csr + + # index pointer for sorted point, for torch_scatter.segment_csr + idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)]) + # head_indices of each cluster, for reduce attr e.g. 
code, batch + head_indices = indices[idx_ptr[:-1]] + # generate down code, order, inverse + code = code[:, head_indices] + + if not self.export_mode: + order = torch.argsort(code) + else: + order = torch.stack([argsort(code[i]) for i in range(len(code))], dim=0) + inverse = torch.zeros_like(order).scatter_( + dim=1, + index=order, + src=torch.arange(0, code.shape[1], device=order.device).repeat(code.shape[0], 1), + ) + + if self.shuffle_orders: + perm = torch.randperm(code.shape[0]) + code = code[perm] + order = order[perm] + inverse = inverse[perm] + + # collect information + + if not self.export_mode: + scatter_feat = torch_scatter.segment_csr(self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce) + scatter_coord = torch_scatter.segment_csr(point.coord[indices], idx_ptr, reduce="mean") + else: + scatter_feat = segment_csr(self.proj(point.feat)[indices], idx_ptr, self.reduce) + scatter_coord = segment_csr(point.coord[indices], idx_ptr, "mean") + + point_dict = Dict( + feat=scatter_feat, + coord=scatter_coord, + grid_coord=point.grid_coord[head_indices] >> pooling_depth, + serialized_code=code, + serialized_order=order, + serialized_inverse=inverse, + serialized_depth=point.serialized_depth - pooling_depth, + batch=point.batch[head_indices], + sparse_shape=sparse_shape >> pooling_depth, + ) + + if "condition" in point.keys(): + point_dict["condition"] = point.condition + if "context" in point.keys(): + point_dict["context"] = point.context + + if self.traceable: + point_dict["pooling_inverse"] = cluster + point_dict["pooling_parent"] = point + point = Point(point_dict) + if self.norm is not None: + point = self.norm(point) + if self.act is not None: + point = self.act(point) + point.sparsify() + return point + + +class SerializedUnpooling(PointModule): + def __init__( + self, + in_channels, + skip_channels, + out_channels, + norm_layer=None, + act_layer=None, + traceable=False, # record parent and cluster + ): + super().__init__() + self.proj = PointSequential(nn.Linear(in_channels, out_channels)) + self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels)) + + if norm_layer is not None: + self.proj.add(norm_layer(out_channels)) + self.proj_skip.add(norm_layer(out_channels)) + + if act_layer is not None: + self.proj.add(act_layer()) + self.proj_skip.add(act_layer()) + + self.traceable = traceable + + def forward(self, point): + assert "pooling_parent" in point.keys() + assert "pooling_inverse" in point.keys() + parent = point.pop("pooling_parent") + inverse = point.pop("pooling_inverse") + point = self.proj(point) + parent = self.proj_skip(parent) + parent.feat = parent.feat + point.feat[inverse] + + if self.traceable: + parent["unpooling_parent"] = point + return parent + + +class Embedding(PointModule): + def __init__( + self, + in_channels, + embed_channels, + norm_layer=None, + act_layer=None, + ): + super().__init__() + self.in_channels = in_channels + self.embed_channels = embed_channels + + # TODO: check remove spconv + self.stem = PointSequential( + conv=SubMConv3d( + in_channels, + embed_channels, + kernel_size=5, + padding=1, + bias=False, + indice_key="stem", + ) + ) + if norm_layer is not None: + self.stem.add(norm_layer(embed_channels), name="norm") + if act_layer is not None: + self.stem.add(act_layer(), name="act") + + def forward(self, point: Point): + point = self.stem(point) + return point + + +@MODELS.register_module("PT-v3m1") +class PointTransformerV3(PointModule): + + def __init__( + self, + in_channels=6, + order=("z", "z-trans"), + stride=(2, 2, 
2, 2), + enc_depths=(2, 2, 2, 6, 2), + enc_channels=(32, 64, 128, 256, 512), + enc_num_head=(2, 4, 8, 16, 32), + enc_patch_size=(48, 48, 48, 48, 48), + dec_depths=(2, 2, 2, 2), + dec_channels=(64, 64, 128, 256), + dec_num_head=(4, 4, 8, 16), + dec_patch_size=(48, 48, 48, 48), + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.3, + pre_norm=True, + shuffle_orders=True, + enable_rpe=False, + enable_flash=True, + upcast_attention=False, + upcast_softmax=False, + cls_mode=False, + pdnorm_bn=False, + pdnorm_ln=False, + pdnorm_decouple=True, + pdnorm_adaptive=False, + pdnorm_affine=True, + pdnorm_conditions=("ScanNet", "S3DIS", "Structured3D"), + export_mode=False, + ): + super().__init__() + self.num_stages = len(enc_depths) + self.order = [order] if isinstance(order, str) else order + self.cls_mode = cls_mode + self.shuffle_orders = shuffle_orders + self.export_mode = export_mode + + assert self.num_stages == len(stride) + 1 + assert self.num_stages == len(enc_depths) + assert self.num_stages == len(enc_channels) + assert self.num_stages == len(enc_num_head) + assert self.num_stages == len(enc_patch_size) + assert self.cls_mode or self.num_stages == len(dec_depths) + 1 + assert self.cls_mode or self.num_stages == len(dec_channels) + 1 + assert self.cls_mode or self.num_stages == len(dec_num_head) + 1 + assert self.cls_mode or self.num_stages == len(dec_patch_size) + 1 + + # norm layers + if pdnorm_bn: + bn_layer = partial( + PDNorm, + norm_layer=partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01, affine=pdnorm_affine), + conditions=pdnorm_conditions, + decouple=pdnorm_decouple, + adaptive=pdnorm_adaptive, + ) + else: + bn_layer = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01) + if pdnorm_ln: + ln_layer = partial( + PDNorm, + norm_layer=partial(nn.LayerNorm, elementwise_affine=pdnorm_affine), + conditions=pdnorm_conditions, + decouple=pdnorm_decouple, + adaptive=pdnorm_adaptive, + ) + else: + ln_layer = nn.LayerNorm + # activation layers + act_layer = nn.GELU + + self.embedding = Embedding( + in_channels=in_channels, + embed_channels=enc_channels[0], + norm_layer=bn_layer, + act_layer=act_layer, + ) + + # encoder + enc_drop_path = [x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))] + self.enc = PointSequential() + for s in range(self.num_stages): + enc_drop_path_ = enc_drop_path[sum(enc_depths[:s]) : sum(enc_depths[: s + 1])] + enc = PointSequential() + if s > 0: + enc.add( + SerializedPooling( + in_channels=enc_channels[s - 1], + out_channels=enc_channels[s], + stride=stride[s - 1], + norm_layer=bn_layer, + act_layer=act_layer, + shuffle_orders=shuffle_orders, + export_mode=self.export_mode, + ), + name="down", + ) + for i in range(enc_depths[s]): + enc.add( + Block( + channels=enc_channels[s], + num_heads=enc_num_head[s], + patch_size=enc_patch_size[s], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + drop_path=enc_drop_path_[i], + norm_layer=ln_layer, + act_layer=act_layer, + pre_norm=pre_norm, + order_index=i % len(self.order), + cpe_indice_key=f"stage{s}", + enable_rpe=enable_rpe, + enable_flash=enable_flash, + upcast_attention=upcast_attention, + upcast_softmax=upcast_softmax, + export_mode=export_mode, + ), + name=f"block{i}", + ) + if len(enc) != 0: + self.enc.add(module=enc, name=f"enc{s}") + + # decoder + if not self.cls_mode: + dec_drop_path = [x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))] + self.dec = PointSequential() + dec_channels = 
list(dec_channels) + [enc_channels[-1]] + for s in reversed(range(self.num_stages - 1)): + dec_drop_path_ = dec_drop_path[sum(dec_depths[:s]) : sum(dec_depths[: s + 1])] + dec_drop_path_.reverse() + dec = PointSequential() + dec.add( + SerializedUnpooling( + in_channels=dec_channels[s + 1], + skip_channels=enc_channels[s], + out_channels=dec_channels[s], + norm_layer=bn_layer, + act_layer=act_layer, + ), + name="up", + ) + for i in range(dec_depths[s]): + dec.add( + Block( + channels=dec_channels[s], + num_heads=dec_num_head[s], + patch_size=dec_patch_size[s], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + drop_path=dec_drop_path_[i], + norm_layer=ln_layer, + act_layer=act_layer, + pre_norm=pre_norm, + order_index=i % len(self.order), + cpe_indice_key=f"stage{s}", + enable_rpe=enable_rpe, + enable_flash=enable_flash, + upcast_attention=upcast_attention, + upcast_softmax=upcast_softmax, + export_mode=export_mode, + ), + name=f"block{i}", + ) + self.dec.add(module=dec, name=f"dec{s}") + + def forward(self, data_dict): + point = Point(data_dict) + point.serialization(order=self.order, shuffle_orders=self.shuffle_orders) + point.sparsify() + + point = self.embedding(point) + point = self.enc(point) + + if not self.cls_mode: + point = self.dec(point) + # else: + # point.feat = torch_scatter.segment_csr( + # src=point.feat, + # indptr=nn.functional.pad(point.offset, (1, 0)), + # reduce="mean", + # ) + return point + + def export_forward(self, data_dict): + point = Point(data_dict) + # point.serialization(order=self.order, shuffle_orders=self.shuffle_orders) + point["serialized_depth"] = data_dict["serialized_depth"] + point["serialized_code"] = data_dict["serialized_code"] + point["serialized_order"] = data_dict["serialized_order"] + point["serialized_inverse"] = data_dict["serialized_inverse"] + point["sparse_shape"] = data_dict["sparse_shape"] + point.sparsify() + + point = self.embedding(point) + + point = self.enc(point) + + if not self.cls_mode: + point = self.dec(point) + + return point diff --git a/projects/PTv3/models/utils/__init__.py b/projects/PTv3/models/utils/__init__.py new file mode 100644 index 00000000..df332701 --- /dev/null +++ b/projects/PTv3/models/utils/__init__.py @@ -0,0 +1,4 @@ +from .checkpoint import checkpoint +from .misc import batch2offset, off_diagonal, offset2batch, offset2bincount +from .serialization import decode, encode +from .structure import Point diff --git a/projects/PTv3/models/utils/checkpoint.py b/projects/PTv3/models/utils/checkpoint.py new file mode 100644 index 00000000..58820352 --- /dev/null +++ b/projects/PTv3/models/utils/checkpoint.py @@ -0,0 +1,57 @@ +""" +Checkpoint Utils for Models + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import torch + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
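+            # view_as(x) returns a fresh tensor object sharing x's storage, so an
+            # in-place op inside run_function cannot clobber the detached inputs
+            # recorded for this recomputation.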
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) diff --git a/projects/PTv3/models/utils/misc.py b/projects/PTv3/models/utils/misc.py new file mode 100644 index 00000000..73e829e7 --- /dev/null +++ b/projects/PTv3/models/utils/misc.py @@ -0,0 +1,42 @@ +""" +General Utils for Models + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import torch + + +@torch.inference_mode() +def offset2bincount(offset): + # NOTE(knzo25): hack to avoid unsupported ops in export mode + if len(offset) == 1: + return offset + return torch.diff(offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long)) + + +@torch.inference_mode() +def offset2batch(offset, coords=None): + + # NOTE(knzo25): hack to avoid unsupported ops in export mode + if offset.size(0) == 1 and coords is not None: + return torch.zeros((coords.shape[0]), device=coords.device, dtype=torch.long) + + bincount = offset2bincount(offset) + return torch.arange(len(bincount), device=offset.device, dtype=torch.long).repeat_interleave(bincount) + + +@torch.inference_mode() +def batch2offset(batch): + # NOTE(knzo25): hack to avoid unsupported ops in export mode + if batch.detach().cpu().numpy().max() == 0: + return torch.Tensor([batch.size(0)]).to(device=batch.device).type(batch.dtype) + return torch.cumsum(batch.bincount(), dim=0).long() + + +def off_diagonal(x): + # return a flattened view of the off-diagonal elements of a square matrix + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() diff --git a/projects/PTv3/models/utils/serialization/__init__.py b/projects/PTv3/models/utils/serialization/__init__.py new file mode 100644 index 00000000..c1ad0b5f --- /dev/null +++ b/projects/PTv3/models/utils/serialization/__init__.py @@ -0,0 +1,8 @@ +from .default import ( + decode, + encode, + hilbert_decode, + hilbert_encode, + z_order_decode, + z_order_encode, +) diff --git a/projects/PTv3/models/utils/serialization/default.py b/projects/PTv3/models/utils/serialization/default.py new file mode 100644 index 00000000..4ac74af4 --- /dev/null +++ b/projects/PTv3/models/utils/serialization/default.py @@ -0,0 +1,60 @@ +import torch + +from .hilbert import decode as hilbert_decode_ +from .hilbert import encode as hilbert_encode_ +from .z_order import key2xyz as z_order_decode_ +from .z_order import xyz2key as z_order_encode_ + + +@torch.inference_mode() +def encode(grid_coord, batch=None, depth=16, order="z"): + assert order in {"z", "z-trans", "hilbert", "hilbert-trans"} + if order == "z": + code = z_order_encode(grid_coord, depth=depth) + elif order == 
"z-trans": + code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth) + elif order == "hilbert": + code = hilbert_encode(grid_coord, depth=depth) + elif order == "hilbert-trans": + code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth) + else: + raise NotImplementedError + if batch is not None: + batch = batch.long() + code = batch << depth * 3 | code + return code + + +@torch.inference_mode() +def decode(code, depth=16, order="z"): + assert order in {"z", "hilbert"} + batch = code >> depth * 3 + code = code & ((1 << depth * 3) - 1) + if order == "z": + grid_coord = z_order_decode(code, depth=depth) + elif order == "hilbert": + grid_coord = hilbert_decode(code, depth=depth) + else: + raise NotImplementedError + return grid_coord, batch + + +def z_order_encode(grid_coord: torch.Tensor, depth: int = 16): + x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long() + # we block the support to batch, maintain batched code in Point class + code = z_order_encode_(x, y, z, b=None, depth=depth) + return code + + +def z_order_decode(code: torch.Tensor, depth): + x, y, z = z_order_decode_(code, depth=depth) + grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3) + return grid_coord + + +def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16): + return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth) + + +def hilbert_decode(code: torch.Tensor, depth: int = 16): + return hilbert_decode_(code, num_dims=3, num_bits=depth) diff --git a/projects/PTv3/models/utils/serialization/hilbert.py b/projects/PTv3/models/utils/serialization/hilbert.py new file mode 100644 index 00000000..f8d62e05 --- /dev/null +++ b/projects/PTv3/models/utils/serialization/hilbert.py @@ -0,0 +1,276 @@ +""" +Hilbert Order +Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu +Please cite our work if the code is helpful to you. +""" + +import torch + + +def right_shift(binary, k=1, axis=-1): + """Right shift an array of binary values. + + Parameters: + ----------- + binary: An ndarray of binary values. + + k: The number of bits to shift. Default 1. + + axis: The axis along which to shift. Default -1. + + Returns: + -------- + Returns an ndarray with zero prepended and the ends truncated, along + whatever axis was specified.""" + + # If we're shifting the whole thing, just return zeros. + if binary.shape[axis] <= k: + return torch.zeros_like(binary) + + # Determine the padding pattern. + # padding = [(0,0)] * len(binary.shape) + # padding[axis] = (k,0) + + # Determine the slicing pattern to eliminate just the last one. + slicing = [slice(None)] * len(binary.shape) + slicing[axis] = slice(None, -k) + shifted = torch.nn.functional.pad(binary[tuple(slicing)], (k, 0), mode="constant", value=0) + + return shifted + + +def binary2gray(binary, axis=-1): + """Convert an array of binary values into Gray codes. + + This uses the classic X ^ (X >> 1) trick to compute the Gray code. + + Parameters: + ----------- + binary: An ndarray of binary values. + + axis: The axis along which to compute the gray code. Default=-1. + + Returns: + -------- + Returns an ndarray of Gray codes. + """ + shifted = right_shift(binary, axis=axis) + + # Do the X ^ (X >> 1) trick. + gray = torch.logical_xor(binary, shifted) + + return gray + + +def gray2binary(gray, axis=-1): + """Convert an array of Gray codes back into binary values. + + Parameters: + ----------- + gray: An ndarray of gray codes. 
+
+    axis: The axis along which to perform Gray decoding. Default=-1.
+
+    Returns:
+    --------
+    Returns an ndarray of binary values.
+    """
+
+    # Loop the log2(bits) number of times necessary, with shift and xor.
+    shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
+    while shift > 0:
+        gray = torch.logical_xor(gray, right_shift(gray, shift))
+        shift = torch.div(shift, 2, rounding_mode="floor")
+    return gray
+
+
+def encode(locs, num_dims, num_bits):
+    """Encode an array of locations in a hypercube into a Hilbert integer.
+
+    This is a vectorized-ish version of the Hilbert curve implementation by John
+    Skilling as described in:
+
+    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
+    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
+
+    Params:
+    -------
+    locs - An ndarray of locations in a hypercube of num_dims dimensions, in
+           which each dimension runs from 0 to 2**num_bits-1. The shape can
+           be arbitrary, as long as the last dimension has size num_dims.
+
+    num_dims - The dimensionality of the hypercube. Integer.
+
+    num_bits - The number of bits for each dimension. Integer.
+
+    Returns:
+    --------
+    The output is an ndarray of uint64 integers with the same shape as the
+    input, excluding the last dimension, which needs to be num_dims.
+    """
+
+    # Keep around the original shape for later.
+    orig_shape = locs.shape
+    bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
+    bitpack_mask_rev = bitpack_mask.flip(-1)
+
+    if orig_shape[-1] != num_dims:
+        raise ValueError(
+            """
+            The shape of locs was surprising in that the last dimension was of size
+            %d, but num_dims=%d. These need to be equal.
+            """
+            % (orig_shape[-1], num_dims)
+        )
+
+    if num_dims * num_bits > 63:
+        raise ValueError(
+            """
+            num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
+            into an int64. Are you sure you need that many points on your Hilbert
+            curve?
+            """
+            % (num_dims, num_bits, num_dims * num_bits)
+        )
+
+    # Treat the location integers as 64-bit unsigned and then split them up into
+    # a sequence of uint8s. Preserve the association by dimension.
+    locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)
+
+    # Now turn these into bits and truncate to num_bits.
+    gray = locs_uint8.unsqueeze(-1).bitwise_and(bitpack_mask_rev).ne(0).byte().flatten(-2, -1)[..., -num_bits:]
+
+    # Run the decoding process the other way.
+    # Iterate forwards through the bits.
+    for bit in range(0, num_bits):
+        # Iterate forwards through the dimensions.
+        for dim in range(0, num_dims):
+            # Identify which ones have this bit active.
+            mask = gray[:, dim, bit]
+
+            # Where this bit is on, invert the 0 dimension for lower bits.
+            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], mask[:, None])
+
+            # Where the bit is off, exchange the lower bits with the 0 dimension.
+            to_flip = torch.logical_and(
+                torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
+                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
+            )
+            gray[:, dim, bit + 1 :] = torch.logical_xor(gray[:, dim, bit + 1 :], to_flip)
+            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
+
+    # Now flatten out.
+    gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))
+
+    # Convert Gray back to binary.
+    hh_bin = gray2binary(gray)
+
+    # Pad back out to 64 bits.
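+    # (the Hilbert key occupies only the low num_dims * num_bits bits; the high
+    # bits are zero-padded so the pattern can be reinterpreted as an int64 below)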
+    extra_dims = 64 - num_bits * num_dims
+    padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)
+
+    # Convert binary values into uint8s.
+    hh_uint8 = (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask).sum(2).squeeze().type(torch.uint8)
+
+    # Convert uint8s into uint64s.
+    hh_uint64 = hh_uint8.view(torch.int64).squeeze()
+
+    return hh_uint64
+
+
+def decode(hilberts, num_dims, num_bits):
+    """Decode an array of Hilbert integers into locations in a hypercube.
+
+    This is a vectorized-ish version of the Hilbert curve implementation by John
+    Skilling as described in:
+
+    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
+    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
+
+    Params:
+    -------
+    hilberts - An ndarray of Hilbert integers. Must be an integer dtype and
+               cannot have fewer bits than num_dims * num_bits.
+
+    num_dims - The dimensionality of the hypercube. Integer.
+
+    num_bits - The number of bits for each dimension. Integer.
+
+    Returns:
+    --------
+    The output is an ndarray of unsigned integers with the same shape as hilberts
+    but with an additional dimension of size num_dims.
+    """
+
+    if num_dims * num_bits > 64:
+        raise ValueError(
+            """
+            num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
+            into a uint64. Are you sure you need that many points on your Hilbert
+            curve?
+            """
+            % (num_dims, num_bits, num_dims * num_bits)
+        )
+
+    # Handle the case where we got handed a naked integer.
+    hilberts = torch.atleast_1d(hilberts)
+
+    # Keep around the shape for later.
+    orig_shape = hilberts.shape
+    bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)
+    bitpack_mask_rev = bitpack_mask.flip(-1)
+
+    # Treat each of the hilberts as a sequence of eight uint8.
+    # This treats all of the inputs as uint64 and makes things uniform.
+    hh_uint8 = hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)
+
+    # Turn these lists of uints into lists of bits and then truncate to the size
+    # we actually need for using Skilling's procedure.
+    hh_bits = (
+        hh_uint8.unsqueeze(-1).bitwise_and(bitpack_mask_rev).ne(0).byte().flatten(-2, -1)[:, -num_dims * num_bits :]
+    )
+
+    # Take the sequence of bits and Gray-code it.
+    gray = binary2gray(hh_bits)
+
+    # There has got to be a better way to do this.
+    # I could index them differently, but the eventual packbits likes it this way.
+    gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)
+
+    # Iterate backwards through the bits.
+    for bit in range(num_bits - 1, -1, -1):
+        # Iterate backwards through the dimensions.
+        for dim in range(num_dims - 1, -1, -1):
+            # Identify which ones have this bit active.
+            mask = gray[:, dim, bit]
+
+            # Where this bit is on, invert the 0 dimension for lower bits.
+            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], mask[:, None])
+
+            # Where the bit is off, exchange the lower bits with the 0 dimension.
+            to_flip = torch.logical_and(
+                torch.logical_not(mask[:, None]),
+                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
+            )
+            gray[:, dim, bit + 1 :] = torch.logical_xor(gray[:, dim, bit + 1 :], to_flip)
+            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
+
+    # Pad back out to 64 bits.
+    extra_dims = 64 - num_bits
+    padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0)
+
+    # Now chop these up into blocks of 8.
+    locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))
+
+    # Take those blocks and turn them into uint8s.
+ # from IPython import embed; embed() + locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8) + + # Finally, treat these as uint64s. + flat_locs = locs_uint8.view(torch.int64) + + # Return them in the expected shape. + return flat_locs.reshape((*orig_shape, num_dims)) diff --git a/projects/PTv3/models/utils/serialization/z_order.py b/projects/PTv3/models/utils/serialization/z_order.py new file mode 100644 index 00000000..2f951b0e --- /dev/null +++ b/projects/PTv3/models/utils/serialization/z_order.py @@ -0,0 +1,122 @@ +# -------------------------------------------------------- +# Octree-based Sparse Convolutional Neural Networks +# Copyright (c) 2022 Peng-Shuai Wang +# Licensed under The MIT License [see LICENSE for details] +# Written by Peng-Shuai Wang +# -------------------------------------------------------- + +from typing import Optional, Union + +import torch + + +class KeyLUT: + def __init__(self): + r256 = torch.arange(256, dtype=torch.int64) + r512 = torch.arange(512, dtype=torch.int64) + zero = torch.zeros(256, dtype=torch.int64) + device = torch.device("cpu") + + self._encode = { + device: ( + self.xyz2key(r256, zero, zero, 8), + self.xyz2key(zero, r256, zero, 8), + self.xyz2key(zero, zero, r256, 8), + ) + } + self._decode = {device: self.key2xyz(r512, 9)} + + def encode_lut(self, device=torch.device("cpu")): + if device not in self._encode: + cpu = torch.device("cpu") + self._encode[device] = tuple(e.to(device) for e in self._encode[cpu]) + return self._encode[device] + + def decode_lut(self, device=torch.device("cpu")): + if device not in self._decode: + cpu = torch.device("cpu") + self._decode[device] = tuple(e.to(device) for e in self._decode[cpu]) + return self._decode[device] + + def xyz2key(self, x, y, z, depth): + key = torch.zeros_like(x) + for i in range(depth): + mask = 1 << i + key = key | ((x & mask) << (2 * i + 2)) | ((y & mask) << (2 * i + 1)) | ((z & mask) << (2 * i + 0)) + return key + + def key2xyz(self, key, depth): + x = torch.zeros_like(key) + y = torch.zeros_like(key) + z = torch.zeros_like(key) + for i in range(depth): + x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2)) + y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1)) + z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0)) + return x, y, z + + +_key_lut = KeyLUT() + + +def xyz2key( + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor, + b: Optional[Union[torch.Tensor, int]] = None, + depth: int = 16, +): + r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys + based on pre-computed look up tables. The speed of this function is much + faster than the method based on for-loop. + + Args: + x (torch.Tensor): The x coordinate. + y (torch.Tensor): The y coordinate. + z (torch.Tensor): The z coordinate. + b (torch.Tensor or int): The batch index of the coordinates, and should be + smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of + :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`. + depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). 
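+
+    Example (illustrative)::
+
+        >>> xyz2key(torch.tensor([1]), torch.tensor([0]), torch.tensor([0]), depth=1)
+        tensor([4])  # bits interleave as x, y, z -> 0b100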
+ """ + + EX, EY, EZ = _key_lut.encode_lut(x.device) + x, y, z = x.long(), y.long(), z.long() + + mask = 255 if depth > 8 else (1 << depth) - 1 + key = EX[x & mask] | EY[y & mask] | EZ[z & mask] + if depth > 8: + mask = (1 << (depth - 8)) - 1 + key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask] + key = key16 << 24 | key + + if b is not None: + b = b.long() + key = b << 48 | key + + return key + + +def key2xyz(key: torch.Tensor, depth: int = 16): + r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates + and the batch index based on pre-computed look up tables. + + Args: + key (torch.Tensor): The shuffled key. + depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). + """ + + DX, DY, DZ = _key_lut.decode_lut(key.device) + x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key) + + b = key >> 48 + key = key & ((1 << 48) - 1) + + n = (depth + 2) // 3 + for i in range(n): + k = key >> (i * 9) & 511 + x = x | (DX[k] << (i * 3)) + y = y | (DY[k] << (i * 3)) + z = z | (DZ[k] << (i * 3)) + + return x, y, z, b diff --git a/projects/PTv3/models/utils/structure.py b/projects/PTv3/models/utils/structure.py new file mode 100644 index 00000000..8d027cec --- /dev/null +++ b/projects/PTv3/models/utils/structure.py @@ -0,0 +1,139 @@ +import spconv.pytorch as spconv +import torch +from addict import Dict +from models.utils import batch2offset, offset2batch +from models.utils.serialization import decode, encode + + +def bit_length_tensor(x: torch.Tensor) -> torch.Tensor: + # Ensure x is a positive integer tensor + x = torch.clamp(x, min=1) + return torch.floor(torch.log2(x)).to(torch.int64) + 1 + + +class Point(Dict): + """ + Point Structure of Pointcept + + A Point (point cloud) in Pointcept is a dictionary that contains various properties of + a batched point cloud. 
The properties with the following names have a specific definition,
+    as follows:
+
+    - "coord": original coordinate of point cloud;
+    - "grid_coord": grid coordinate for specific grid size (related to GridSampling);
+    Point also supports the following optional attributes:
+    - "offset": if not present, initialized assuming a batch size of 1;
+    - "batch": if not present, initialized assuming a batch size of 1;
+    - "feat": feature of point cloud, default input of model;
+    - "grid_size": grid size of point cloud (related to GridSampling);
+    (related to Serialization)
+    - "serialized_depth": depth of serialization; 2 ** depth * grid_size describes the maximum point cloud range;
+    - "serialized_code": a list of serialization codes;
+    - "serialized_order": a list of serialization orders determined by code;
+    - "serialized_inverse": a list of inverse mappings determined by code;
+    (related to Sparsify: SpConv)
+    - "sparse_shape": sparse shape for the Sparse Conv Tensor;
+    - "sparse_conv_feat": SparseConvTensor initialized with information provided by Point;
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # If one of "offset" or "batch" does not exist, generate it from the existing one
+        # If neither of them exists, initialize with a batch size of 1
+        if "offset" not in self.keys() and "batch" not in self.keys():
+            self["offset"] = torch.tensor([self["coord"].size(0)], device=self["coord"].device, dtype=torch.int64)
+            self["batch"] = offset2batch(self.offset)
+        elif "batch" not in self.keys() and "offset" in self.keys():
+            self["batch"] = offset2batch(self.offset, self["grid_coord"])
+        elif "offset" not in self.keys() and "batch" in self.keys():
+            self["offset"] = batch2offset(self.batch)
+
+    def serialization(self, order="z", depth=None, shuffle_orders=False):
+        """
+        Point Cloud Serialization
+
+        relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
+        """
+        assert "batch" in self.keys()
+        if "grid_coord" not in self.keys():
+            # if you don't want to operate GridSampling in data augmentation,
+            # please add the following augmentation into your pipeline:
+            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
+            # (adjust `grid_size` to what you want)
+            assert {"grid_size", "coord"}.issubset(self.keys())
+            self["grid_coord"] = torch.div(
+                self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
+            ).int()
+
+        if depth is None:
+            # Adaptively measure the depth of the serialization cube (length = 2 ^ depth)
+            # depth = int(self.grid_coord.max()).bit_length()
+            depth = bit_length_tensor(self.grid_coord.max())
+
+        self["serialized_depth"] = depth
+        # Maximum bit length for serialization code is 63 (int64)
+        # assert depth * 3 + len(self.offset).bit_length() <= 63
+        assert depth * 3 + bit_length_tensor(self.offset) <= 63
+        # Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position.
+        # Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3
+        # cube with a grid size of 0.01 meter. We consider this enough for the current stage.
+        # We can unlock the limitation by optimizing the z-order encoding function if necessary.
+        assert depth <= 16
+
+        # The serialization codes are arranged as the following structure:
+        # [Order1 ([n]),
+        #  Order2 ([n]),
+        #   ...
+        #  OrderN ([n])] (k, n)
+        code = [encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order]
+        code = torch.stack(code)
+        order = torch.argsort(code)
+        inverse = torch.zeros_like(order).scatter_(
+            dim=1,
+            index=order,
+            src=torch.arange(0, code.shape[1], device=order.device).repeat(code.shape[0], 1),
+        )
+
+        if shuffle_orders:
+            perm = torch.randperm(code.shape[0])
+            code = code[perm]
+            order = order[perm]
+            inverse = inverse[perm]
+
+        self["serialized_code"] = code
+        self["serialized_order"] = order
+        self["serialized_inverse"] = inverse
+
+    def sparsify(self, pad=96):
+        """
+        Point Cloud Sparsification
+
+        Point cloud is sparse; here we use "sparsify" to specifically refer to
+        preparing "spconv.SparseConvTensor" for SpConv.
+
+        relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
+
+        pad: padding applied to the sparse shape.
+        """
+        assert {"feat", "batch"}.issubset(self.keys())
+        if "grid_coord" not in self.keys():
+            # if you don't want to operate GridSampling in data augmentation,
+            # please add the following augmentation into your pipeline:
+            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
+            # (adjust `grid_size` to what you want)
+            assert {"grid_size", "coord"}.issubset(self.keys())
+            self["grid_coord"] = torch.div(
+                self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
+            ).int()
+        if "sparse_shape" in self.keys():
+            sparse_shape = self.sparse_shape
+        else:
+            sparse_shape = torch.add(torch.max(self.grid_coord, dim=0).values, pad)  # .tolist()
+        sparse_conv_feat = spconv.SparseConvTensor(
+            features=self.feat,
+            indices=torch.cat([self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1).contiguous(),
+            spatial_shape=sparse_shape,
+            batch_size=self.batch[-1].tolist() + 1,
+        )
+        self["sparse_shape"] = sparse_shape
+        self["sparse_conv_feat"] = sparse_conv_feat
diff --git a/projects/PTv3/scripts/test.sh b/projects/PTv3/scripts/test.sh
new file mode 100644
index 00000000..a104f98e
--- /dev/null
+++ b/projects/PTv3/scripts/test.sh
@@ -0,0 +1,74 @@
+#!/bin/sh
+
+cd $(dirname $(dirname "$0")) || exit
+PYTHON=python
+
+TEST_CODE=test.py
+
+DATASET=scannet
+CONFIG="None"
+EXP_NAME=debug
+WEIGHT=model_best
+GPU=None
+
+while getopts "p:d:c:n:w:g:" opt; do
+  case $opt in
+    p)
+      PYTHON=$OPTARG
+      ;;
+    d)
+      DATASET=$OPTARG
+      ;;
+    c)
+      CONFIG=$OPTARG
+      ;;
+    n)
+      EXP_NAME=$OPTARG
+      ;;
+    w)
+      WEIGHT=$OPTARG
+      ;;
+    g)
+      GPU=$OPTARG
+      ;;
+    \?)
+ echo "Invalid option: -$OPTARG" + ;; + esac +done + +if [ "${NUM_GPU}" = 'None' ] +then + NUM_GPU=`$PYTHON -c 'import torch; print(torch.cuda.device_count())'` +fi + +echo "Experiment name: $EXP_NAME" +echo "Python interpreter dir: $PYTHON" +echo "Dataset: $DATASET" +echo "GPU Num: $GPU" + +EXP_DIR=exp/${DATASET}/${EXP_NAME} +MODEL_DIR=${EXP_DIR}/model +CODE_DIR=${EXP_DIR}/code +CONFIG_DIR=${EXP_DIR}/config.py + +if [ "${CONFIG}" = "None" ] +then + CONFIG_DIR=${EXP_DIR}/config.py +else + CONFIG_DIR=configs/${DATASET}/${CONFIG}.py +fi + +echo "Loading config in:" $CONFIG_DIR +#export PYTHONPATH=./$CODE_DIR +export PYTHONPATH=./ +echo "Running code in: $CODE_DIR" + + +echo " =========> RUN TASK <=========" + +#$PYTHON -u "$CODE_DIR"/tools/$TEST_CODE \ +$PYTHON -u tools/$TEST_CODE \ + --config-file "$CONFIG_DIR" \ + --num-gpus "$GPU" \ + --options save_path="$EXP_DIR" weight="${MODEL_DIR}"/"${WEIGHT}".pth diff --git a/projects/PTv3/scripts/train.sh b/projects/PTv3/scripts/train.sh new file mode 100644 index 00000000..6cbeeef9 --- /dev/null +++ b/projects/PTv3/scripts/train.sh @@ -0,0 +1,92 @@ +#!/bin/sh + +cd $(dirname $(dirname "$0")) || exit +ROOT_DIR=$(pwd) +PYTHON=python + +TRAIN_CODE=train.py + +DATASET=scannet +CONFIG="None" +EXP_NAME=debug +WEIGHT="None" +RESUME=false +GPU=None + + +while getopts "p:d:c:n:w:g:r:" opt; do + case $opt in + p) + PYTHON=$OPTARG + ;; + d) + DATASET=$OPTARG + ;; + c) + CONFIG=$OPTARG + ;; + n) + EXP_NAME=$OPTARG + ;; + w) + WEIGHT=$OPTARG + ;; + r) + RESUME=$OPTARG + ;; + g) + GPU=$OPTARG + ;; + \?) + echo "Invalid option: -$OPTARG" + ;; + esac +done + +if [ "${NUM_GPU}" = 'None' ] +then + NUM_GPU=`$PYTHON -c 'import torch; print(torch.cuda.device_count())'` +fi + +echo "Experiment name: $EXP_NAME" +echo "Python interpreter dir: $PYTHON" +echo "Dataset: $DATASET" +echo "Config: $CONFIG" +echo "GPU Num: $GPU" + +EXP_DIR=exp/${DATASET}/${EXP_NAME} +MODEL_DIR=${EXP_DIR}/model +CODE_DIR=${EXP_DIR}/code +CONFIG_DIR=configs/${DATASET}/${CONFIG}.py + + +echo " =========> CREATE EXP DIR <=========" +echo "Experiment dir: $ROOT_DIR/$EXP_DIR" +if ${RESUME} +then + CONFIG_DIR=${EXP_DIR}/config.py + WEIGHT=$MODEL_DIR/model_last.pth +else + mkdir -p "$MODEL_DIR" "$CODE_DIR" + cp -r scripts tools pointcept "$CODE_DIR" +fi + +echo "Loading config in:" $CONFIG_DIR +export PYTHONPATH=./$CODE_DIR +echo "Running code in: $CODE_DIR" + + +echo " =========> RUN TASK <=========" + +if [ "${WEIGHT}" = "None" ] +then + $PYTHON "$CODE_DIR"/tools/$TRAIN_CODE \ + --config-file "$CONFIG_DIR" \ + --num-gpus "$GPU" \ + --options save_path="$EXP_DIR" +else + $PYTHON "$CODE_DIR"/tools/$TRAIN_CODE \ + --config-file "$CONFIG_DIR" \ + --num-gpus "$GPU" \ + --options save_path="$EXP_DIR" resume="$RESUME" weight="$WEIGHT" +fi diff --git a/projects/PTv3/tools/export.py b/projects/PTv3/tools/export.py new file mode 100644 index 00000000..a96ded2a --- /dev/null +++ b/projects/PTv3/tools/export.py @@ -0,0 +1,153 @@ +import numpy as np +import SparseConvolution # NOTE(knzo25): do not remove this import, it is needed for onnx export +import spconv.pytorch as spconv +import torch +from engines.defaults import ( + default_argument_parser, + default_config_parser, + default_setup, +) +from engines.train import TRAINERS +from models.scatter.functional import argsort +from models.utils.structure import Point, bit_length_tensor +from torch.nn import functional as F + + +class WrappedModel(torch.nn.Module): + + def __init__(self, model, cfg): + super(WrappedModel, self).__init__() + self.cfg = cfg + 
self.model = model.cuda() + self.model.backbone.forward = self.model.backbone.export_forward + + point_cloud_range = torch.tensor(cfg.point_cloud_range, dtype=torch.float32).cuda() + voxel_size = cfg.grid_size + voxel_size = torch.tensor([voxel_size, voxel_size, voxel_size], dtype=torch.float32).cuda() + + self.sparse_shape = (point_cloud_range[3:] - point_cloud_range[:3]) / voxel_size + self.sparse_shape = torch.round(self.sparse_shape).long().cuda() + + def forward( + self, + grid_coord, + feat, + serialized_depth, + serialized_code, + ): + + shape = torch._shape_as_tensor(grid_coord).to(grid_coord.device) + + serialized_order = torch.stack([argsort(code) for code in serialized_code], dim=0) + serialized_inverse = torch.zeros_like(serialized_order).scatter_( + dim=1, + index=serialized_order, + src=torch.arange(0, serialized_code.shape[1], device=serialized_order.device).repeat( + serialized_code.shape[0], 1 + ), + ) + + input_dict = { + "coord": feat[:, :3], + "grid_coord": grid_coord, + "offset": shape[:1], + "feat": feat, + "serialized_depth": serialized_depth, + "serialized_code": serialized_code, + "serialized_order": serialized_order, + "serialized_inverse": serialized_inverse, + "sparse_shape": self.sparse_shape, + } + + output = self.model(input_dict) + + pred_logits = output["seg_logits"] # (n, k) + pred_probs = F.softmax(pred_logits, -1) + pred_label = pred_probs.argmax(-1) + + return pred_label, pred_probs + + +def main(): + args = default_argument_parser().parse_args() + cfg = default_config_parser(args.config_file, args.options) + + cfg = default_setup(cfg) + cfg.num_worker = 1 + cfg.num_worker_per_gpu = 1 + + # NOTE(knzo25): hacks to allow onnx export + cfg.model.backbone.shuffle_orders = False + cfg.model.backbone.order = ["z", "z-trans"] + cfg.model.backbone.export_mode = True + + runner = TRAINERS.build(dict(type=cfg.train.type, cfg=cfg)) + + runner.before_train() + + model = WrappedModel(runner.model, cfg) + model.eval() + + runner.val_loader.prefetch_factor = 1 + data_dict = next(iter(runner.val_loader)) + + input_dict = data_dict + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].cuda(non_blocking=True) + + with torch.no_grad(): + + depth = bit_length_tensor( + torch.tensor([(max(cfg.point_cloud_range) - min(cfg.point_cloud_range)) / cfg.grid_size]) + ).cuda() + point = Point(input_dict) + point.serialization( + order=model.model.backbone.order, shuffle_orders=model.model.backbone.shuffle_orders, depth=depth + ) + + input_dict["serialized_depth"] = point["serialized_depth"] + input_dict["serialized_code"] = point["serialized_code"] + input_dict.pop("segment") + input_dict.pop("offset") + input_dict.pop("coord") + + pred_labels, pred_probs = model(**input_dict) + + np.savez_compressed("ptv3_sample.npz", pred=pred_labels.cpu().numpy(), feat=input_dict["feat"].cpu().numpy()) + + export_params = (True,) + keep_initializers_as_inputs = False + opset_version = 17 + input_names = ["grid_coord", "feat", "serialized_depth", "serialized_code"] + output_names = ["pred_labels", "pred_probs"] + dynamic_axes = { + "grid_coord": { + 0: "voxels_num", + }, + "feat": { + 0: "voxels_num", + }, + "serialized_code": { + 1: "voxels_num", + }, + } + torch.onnx.export( + model, + input_dict, + "ptv3.onnx", + export_params=export_params, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + verbose=False, + 
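+        # NOTE: constant folding stays disabled below, presumably so the custom
+        # spconv / scatter ops are left untouched in the exported graph.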
do_constant_folding=False, + ) + + print("Exported to ONNX format successfully.") + + +if __name__ == "__main__": + main() diff --git a/projects/PTv3/tools/test.py b/projects/PTv3/tools/test.py new file mode 100644 index 00000000..ca23eba9 --- /dev/null +++ b/projects/PTv3/tools/test.py @@ -0,0 +1,38 @@ +""" +Main Testing Script + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +from engines.defaults import ( + default_argument_parser, + default_config_parser, + default_setup, +) +from engines.launch import launch +from engines.test import TESTERS + + +def main_worker(cfg): + cfg = default_setup(cfg) + tester = TESTERS.build(dict(type=cfg.test.type, cfg=cfg)) + tester.test() + + +def main(): + args = default_argument_parser().parse_args() + cfg = default_config_parser(args.config_file, args.options) + + launch( + main_worker, + num_gpus_per_machine=args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + cfg=(cfg,), + ) + + +if __name__ == "__main__": + main() diff --git a/projects/PTv3/tools/train.py b/projects/PTv3/tools/train.py new file mode 100644 index 00000000..8305ad97 --- /dev/null +++ b/projects/PTv3/tools/train.py @@ -0,0 +1,38 @@ +""" +Main Training Script + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +from engines.defaults import ( + default_argument_parser, + default_config_parser, + default_setup, +) +from engines.launch import launch +from engines.train import TRAINERS + + +def main_worker(cfg): + cfg = default_setup(cfg) + trainer = TRAINERS.build(dict(type=cfg.train.type, cfg=cfg)) + trainer.train() + + +def main(): + args = default_argument_parser().parse_args() + cfg = default_config_parser(args.config_file, args.options) + + launch( + main_worker, + num_gpus_per_machine=args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + cfg=(cfg,), + ) + + +if __name__ == "__main__": + main() diff --git a/projects/PTv3/utils/__init__.py b/projects/PTv3/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/projects/PTv3/utils/cache.py b/projects/PTv3/utils/cache.py new file mode 100644 index 00000000..d199f363 --- /dev/null +++ b/projects/PTv3/utils/cache.py @@ -0,0 +1,57 @@ +""" +Data Cache Utils + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import os + +import SharedArray + +try: + from multiprocessing.shared_memory import ShareableList +except ImportError: + import warnings + + warnings.warn("Please update python version >= 3.8 to enable shared_memory") +import numpy as np + + +def shared_array(name, var=None): + if var is not None: + # check exist + if os.path.exists(f"/dev/shm/{name}"): + return SharedArray.attach(f"shm://{name}") + # create shared_array + data = SharedArray.create(f"shm://{name}", var.shape, dtype=var.dtype) + data[...] = var[...] + data.flags.writeable = False + else: + data = SharedArray.attach(f"shm://{name}").copy() + return data + + +def shared_dict(name, var=None): + name = str(name) + assert "." not in name # '.' 
is used as sep flag + data = {} + if var is not None: + assert isinstance(var, dict) + keys = var.keys() + # current version only cache np.array + keys_valid = [] + for key in keys: + if isinstance(var[key], np.ndarray): + keys_valid.append(key) + keys = keys_valid + + ShareableList(sequence=keys, name=name + ".keys") + for key in keys: + if isinstance(var[key], np.ndarray): + data[key] = shared_array(name=f"{name}.{key}", var=var[key]) + else: + keys = list(ShareableList(name=name + ".keys")) + for key in keys: + data[key] = shared_array(name=f"{name}.{key}") + return data diff --git a/projects/PTv3/utils/comm.py b/projects/PTv3/utils/comm.py new file mode 100644 index 00000000..0f7a8198 --- /dev/null +++ b/projects/PTv3/utils/comm.py @@ -0,0 +1,197 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +""" +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. +Modified from detectron2(https://github.com/facebookresearch/detectron2) + +Copyright (c) Xiaoyang Wu (xiaoyang.wu@connect.hku.hk). All Rights Reserved. +Please cite our work if you use any part of the code. +""" + +import functools + +import numpy as np +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None +""" +A torch process group which only includes processes that on the same machine as the current process. +This variable is set when processes are spawned by `launch()` in "engine/launch.py". +""" + + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert ( + _LOCAL_PROCESS_GROUP is not None + ), "Local process group is not created! Please use launch() to spawn processes!" + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process() -> bool: + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + if dist.get_backend() == dist.Backend.NCCL: + # This argument is needed to avoid warnings. + # It's valid only for NCCL backend. + dist.barrier(device_ids=[torch.cuda.current_device()]) + else: + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def all_gather(data, group=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + Args: + data: any picklable object + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. 
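+            (The gloo/CPU group keeps the gathered pickles out of GPU memory.)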
+ Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() # use CPU group by default, to reduce GPU RAM usage. + world_size = dist.get_world_size(group) + if world_size == 1: + return [data] + + output = [None for _ in range(world_size)] + dist.all_gather_object(output, data, group=group) + return output + + +def gather(data, dst=0, group=None): + """ + Run gather on arbitrary picklable data (not necessarily tensors). + Args: + data: any picklable object + dst (int): destination rank + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + Returns: + list[data]: on dst, a list of data gathered from each rank. Otherwise, + an empty list. + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + world_size = dist.get_world_size(group=group) + if world_size == 1: + return [data] + rank = dist.get_rank(group=group) + + if rank == dst: + output = [None for _ in range(world_size)] + dist.gather_object(data, output, dst=dst, group=group) + return output + else: + dist.gather_object(data, None, dst=dst, group=group) + return [] + + +def shared_random_seed(): + """ + Returns: + int: a random number that is the same across all workers. + If workers need a shared RNG, they can use this shared seed to + create one. + All workers must call this function, otherwise it will deadlock. + """ + ints = np.random.randint(2**31) + all_ints = all_gather(ints) + return all_ints[0] + + +def reduce_dict(input_dict, average=True): + """ + Reduce the values in the dictionary from all processes so that process with rank + 0 has the reduced results. + Args: + input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. + average (bool): whether to do average or sum + Returns: + a dict with the same keys as input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/projects/PTv3/utils/config.py b/projects/PTv3/utils/config.py new file mode 100644 index 00000000..780fe303 --- /dev/null +++ b/projects/PTv3/utils/config.py @@ -0,0 +1,656 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
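+# Vendored from mmcv's Config utilities; kept local so the PTv3 tools do not
+# require mmcv at import time (dump() of non-Python configs still imports it
+# lazily).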
+import ast +import copy +import os +import os.path as osp +import platform +import shutil +import sys +import tempfile +import uuid +import warnings +from argparse import Action, ArgumentParser +from collections import abc +from importlib import import_module + +from addict import Dict +from yapf.yapflib.yapf_api import FormatCode + +from .misc import import_modules_from_strings +from .path import check_file_exist + +if platform.system() == "Windows": + import regex as re +else: + import re + +BASE_KEY = "_base_" +DELETE_KEY = "_delete_" +DEPRECATION_KEY = "_deprecation_" +RESERVED_KEYS = ["filename", "text", "pretty_text"] + + +class ConfigDict(Dict): + def __missing__(self, name): + raise KeyError(name) + + def __getattr__(self, name): + try: + value = super(ConfigDict, self).__getattr__(name) + except KeyError: + ex = AttributeError(f"'{self.__class__.__name__}' object has no " f"attribute '{name}'") + except Exception as e: + ex = e + else: + return value + raise ex + + +def add_args(parser, cfg, prefix=""): + for k, v in cfg.items(): + if isinstance(v, str): + parser.add_argument("--" + prefix + k) + elif isinstance(v, int): + parser.add_argument("--" + prefix + k, type=int) + elif isinstance(v, float): + parser.add_argument("--" + prefix + k, type=float) + elif isinstance(v, bool): + parser.add_argument("--" + prefix + k, action="store_true") + elif isinstance(v, dict): + add_args(parser, v, prefix + k + ".") + elif isinstance(v, abc.Iterable): + parser.add_argument("--" + prefix + k, type=type(v[0]), nargs="+") + else: + print(f"cannot parse key {prefix + k} of type {type(v)}") + return parser + + +class Config: + """A facility for config and config files. + + It supports common file formats as configs: python/json/yaml. The interface + is the same as a dict object and also allows access config values as + attributes. 
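+    Config files may inherit from others through the ``_base_`` key; inherited
+    values can be overridden in the child, or dropped with ``_delete_`` (see
+    ``_merge_a_into_b``).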
+ + Example: + >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + >>> cfg.a + 1 + >>> cfg.b + {'b1': [0, 1]} + >>> cfg.b.b1 + [0, 1] + >>> cfg = Config.fromfile('tests/data/config/a.py') + >>> cfg.filename + "/home/kchen/projects/mmcv/tests/data/config/a.py" + >>> cfg.item4 + 'test' + >>> cfg + "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " + "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" + """ + + @staticmethod + def _validate_py_syntax(filename): + with open(filename, "r", encoding="utf-8") as f: + # Setting encoding explicitly to resolve coding issue on windows + content = f.read() + try: + ast.parse(content) + except SyntaxError as e: + raise SyntaxError("There are syntax errors in config " f"file {filename}: {e}") + + @staticmethod + def _substitute_predefined_vars(filename, temp_config_name): + file_dirname = osp.dirname(filename) + file_basename = osp.basename(filename) + file_basename_no_extension = osp.splitext(file_basename)[0] + file_extname = osp.splitext(filename)[1] + support_templates = dict( + fileDirname=file_dirname, + fileBasename=file_basename, + fileBasenameNoExtension=file_basename_no_extension, + fileExtname=file_extname, + ) + with open(filename, "r", encoding="utf-8") as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + for key, value in support_templates.items(): + regexp = r"\{\{\s*" + str(key) + r"\s*\}\}" + value = value.replace("\\", "/") + config_file = re.sub(regexp, value, config_file) + with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file: + tmp_config_file.write(config_file) + + @staticmethod + def _pre_substitute_base_vars(filename, temp_config_name): + """Substitute base variable placehoders to string, so that parsing + would work.""" + with open(filename, "r", encoding="utf-8") as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + base_var_dict = {} + regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}" + base_vars = set(re.findall(regexp, config_file)) + for base_var in base_vars: + randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}" + base_var_dict[randstr] = base_var + regexp = r"\{\{\s*" + BASE_KEY + r"\." 
+ base_var + r"\s*\}\}" + config_file = re.sub(regexp, f'"{randstr}"', config_file) + with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file: + tmp_config_file.write(config_file) + return base_var_dict + + @staticmethod + def _substitute_base_vars(cfg, base_var_dict, base_cfg): + """Substitute variable strings to their actual values.""" + cfg = copy.deepcopy(cfg) + + if isinstance(cfg, dict): + for k, v in cfg.items(): + if isinstance(v, str) and v in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[v].split("."): + new_v = new_v[new_k] + cfg[k] = new_v + elif isinstance(v, (list, tuple, dict)): + cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg) + elif isinstance(cfg, tuple): + cfg = tuple(Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg) + elif isinstance(cfg, list): + cfg = [Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg] + elif isinstance(cfg, str) and cfg in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[cfg].split("."): + new_v = new_v[new_k] + cfg = new_v + + return cfg + + @staticmethod + def _file2dict(filename, use_predefined_variables=True): + filename = osp.abspath(osp.expanduser(filename)) + check_file_exist(filename) + fileExtname = osp.splitext(filename)[1] + if fileExtname not in [".py", ".json", ".yaml", ".yml"]: + raise IOError("Only py/yml/yaml/json type are supported now!") + + with tempfile.TemporaryDirectory() as temp_config_dir: + temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=fileExtname) + if platform.system() == "Windows": + temp_config_file.close() + temp_config_name = osp.basename(temp_config_file.name) + # Substitute predefined variables + if use_predefined_variables: + Config._substitute_predefined_vars(filename, temp_config_file.name) + else: + shutil.copyfile(filename, temp_config_file.name) + # Substitute base variables from placeholders to strings + base_var_dict = Config._pre_substitute_base_vars(temp_config_file.name, temp_config_file.name) + + if filename.endswith(".py"): + temp_module_name = osp.splitext(temp_config_name)[0] + sys.path.insert(0, temp_config_dir) + Config._validate_py_syntax(filename) + mod = import_module(temp_module_name) + sys.path.pop(0) + cfg_dict = {name: value for name, value in mod.__dict__.items() if not name.startswith("__")} + # delete imported module + del sys.modules[temp_module_name] + elif filename.endswith((".yml", ".yaml", ".json")): + raise NotImplementedError + # close temp file + temp_config_file.close() + + # check deprecation information + if DEPRECATION_KEY in cfg_dict: + deprecation_info = cfg_dict.pop(DEPRECATION_KEY) + warning_msg = f"The config file {filename} will be deprecated " "in the future." + if "expected" in deprecation_info: + warning_msg += f' Please use {deprecation_info["expected"]} ' "instead." 
+ if "reference" in deprecation_info: + warning_msg += " More information can be found at " f'{deprecation_info["reference"]}' + warnings.warn(warning_msg) + + cfg_text = filename + "\n" + with open(filename, "r", encoding="utf-8") as f: + # Setting encoding explicitly to resolve coding issue on windows + cfg_text += f.read() + + if BASE_KEY in cfg_dict: + cfg_dir = osp.dirname(filename) + base_filename = cfg_dict.pop(BASE_KEY) + base_filename = base_filename if isinstance(base_filename, list) else [base_filename] + + cfg_dict_list = list() + cfg_text_list = list() + for f in base_filename: + _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) + cfg_dict_list.append(_cfg_dict) + cfg_text_list.append(_cfg_text) + + base_cfg_dict = dict() + for c in cfg_dict_list: + duplicate_keys = base_cfg_dict.keys() & c.keys() + if len(duplicate_keys) > 0: + raise KeyError("Duplicate key is not allowed among bases. " f"Duplicate keys: {duplicate_keys}") + base_cfg_dict.update(c) + + # Substitute base variables from strings to their actual values + cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, base_cfg_dict) + + base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) + cfg_dict = base_cfg_dict + + # merge cfg_text + cfg_text_list.append(cfg_text) + cfg_text = "\n".join(cfg_text_list) + + return cfg_dict, cfg_text + + @staticmethod + def _merge_a_into_b(a, b, allow_list_keys=False): + """merge dict ``a`` into dict ``b`` (non-inplace). + + Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid + in-place modifications. + + Args: + a (dict): The source dict to be merged into ``b``. + b (dict): The origin dict to be fetch keys from ``a``. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in source ``a`` and will replace the element of the + corresponding index in b if b is a list. Default: False. + + Returns: + dict: The modified dict of ``b`` using ``a``. + + Examples: + # Normally merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # Delete b first and merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # b is a list + >>> Config._merge_a_into_b( + ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) + [{'a': 2}, {'b': 2}] + """ + b = b.copy() + for k, v in a.items(): + if allow_list_keys and k.isdigit() and isinstance(b, list): + k = int(k) + if len(b) <= k: + raise KeyError(f"Index {k} exceeds the length of list {b}") + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + elif isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False): + allowed_types = (dict, list) if allow_list_keys else dict + if not isinstance(b[k], allowed_types): + raise TypeError( + f"{k}={v} in child config cannot inherit from base " + f"because {k} is a dict in the child config but is of " + f"type {type(b[k])} in base config. 
You may set " + f"`{DELETE_KEY}=True` to ignore the base config" + ) + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + else: + b[k] = v + return b + + @staticmethod + def fromfile(filename, use_predefined_variables=True, import_custom_modules=True): + cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables) + if import_custom_modules and cfg_dict.get("custom_imports", None): + import_modules_from_strings(**cfg_dict["custom_imports"]) + return Config(cfg_dict, cfg_text=cfg_text, filename=filename) + + @staticmethod + def fromstring(cfg_str, file_format): + """Generate config from config str. + + Args: + cfg_str (str): Config str. + file_format (str): Config file format corresponding to the + config str. Only py/yml/yaml/json type are supported now! + + Returns: + obj:`Config`: Config obj. + """ + if file_format not in [".py", ".json", ".yaml", ".yml"]: + raise IOError("Only py/yml/yaml/json type are supported now!") + if file_format != ".py" and "dict(" in cfg_str: + # check if users specify a wrong suffix for python + warnings.warn('Please check "file_format", the file format may be .py') + with tempfile.NamedTemporaryFile("w", encoding="utf-8", suffix=file_format, delete=False) as temp_file: + temp_file.write(cfg_str) + # on windows, previous implementation cause error + # see PR 1077 for details + cfg = Config.fromfile(temp_file.name) + os.remove(temp_file.name) + return cfg + + @staticmethod + def auto_argparser(description=None): + """Generate argparser from config file automatically (experimental)""" + partial_parser = ArgumentParser(description=description) + partial_parser.add_argument("config", help="config file path") + cfg_file = partial_parser.parse_known_args()[0].config + cfg = Config.fromfile(cfg_file) + parser = ArgumentParser(description=description) + parser.add_argument("config", help="config file path") + add_args(parser, cfg) + return parser, cfg + + def __init__(self, cfg_dict=None, cfg_text=None, filename=None): + if cfg_dict is None: + cfg_dict = dict() + elif not isinstance(cfg_dict, dict): + raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}") + for key in cfg_dict: + if key in RESERVED_KEYS: + raise KeyError(f"{key} is reserved for config file") + + super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict)) + super(Config, self).__setattr__("_filename", filename) + if cfg_text: + text = cfg_text + elif filename: + with open(filename, "r") as f: + text = f.read() + else: + text = "" + super(Config, self).__setattr__("_text", text) + + @property + def filename(self): + return self._filename + + @property + def text(self): + return self._text + + @property + def pretty_text(self): + indent = 4 + + def _indent(s_, num_spaces): + s = s_.split("\n") + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + def _format_basic_types(k, v, use_mapping=False): + if isinstance(v, str): + v_str = f"'{v}'" + else: + v_str = str(v) + + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f"{k_str}: {v_str}" + else: + attr_str = f"{str(k)}={v_str}" + attr_str = _indent(attr_str, indent) + + return attr_str + + def _format_list(k, v, use_mapping=False): + # check if all items in the list are dict + if all(isinstance(_, dict) for _ in v): + v_str = "[\n" + v_str += "\n".join(f"dict({_indent(_format_dict(v_), indent)})," for v_ in v).rstrip(",") + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) 
else str(k) + attr_str = f"{k_str}: {v_str}" + else: + attr_str = f"{str(k)}={v_str}" + attr_str = _indent(attr_str, indent) + "]" + else: + attr_str = _format_basic_types(k, v, use_mapping) + return attr_str + + def _contain_invalid_identifier(dict_str): + contain_invalid_identifier = False + for key_name in dict_str: + contain_invalid_identifier |= not str(key_name).isidentifier() + return contain_invalid_identifier + + def _format_dict(input_dict, outest_level=False): + r = "" + s = [] + + use_mapping = _contain_invalid_identifier(input_dict) + if use_mapping: + r += "{" + for idx, (k, v) in enumerate(input_dict.items()): + is_last = idx >= len(input_dict) - 1 + end = "" if outest_level or is_last else "," + if isinstance(v, dict): + v_str = "\n" + _format_dict(v) + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f"{k_str}: dict({v_str}" + else: + attr_str = f"{str(k)}=dict({v_str}" + attr_str = _indent(attr_str, indent) + ")" + end + elif isinstance(v, list): + attr_str = _format_list(k, v, use_mapping) + end + else: + attr_str = _format_basic_types(k, v, use_mapping) + end + + s.append(attr_str) + r += "\n".join(s) + if use_mapping: + r += "}" + return r + + cfg_dict = self._cfg_dict.to_dict() + text = _format_dict(cfg_dict, outest_level=True) + # copied from setup.cfg + yapf_style = dict( + based_on_style="pep8", + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True, + ) + text, _ = FormatCode(text, style_config=yapf_style, verify=True) + + return text + + def __repr__(self): + return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}" + + def __len__(self): + return len(self._cfg_dict) + + def __getattr__(self, name): + return getattr(self._cfg_dict, name) + + def __getitem__(self, name): + return self._cfg_dict.__getitem__(name) + + def __setattr__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setattr__(name, value) + + def __setitem__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setitem__(name, value) + + def __iter__(self): + return iter(self._cfg_dict) + + def __getstate__(self): + return (self._cfg_dict, self._filename, self._text) + + def __setstate__(self, state): + _cfg_dict, _filename, _text = state + super(Config, self).__setattr__("_cfg_dict", _cfg_dict) + super(Config, self).__setattr__("_filename", _filename) + super(Config, self).__setattr__("_text", _text) + + def dump(self, file=None): + cfg_dict = super(Config, self).__getattribute__("_cfg_dict").to_dict() + if self.filename.endswith(".py"): + if file is None: + return self.pretty_text + else: + with open(file, "w", encoding="utf-8") as f: + f.write(self.pretty_text) + else: + import mmcv + + if file is None: + file_format = self.filename.split(".")[-1] + return mmcv.dump(cfg_dict, file_format=file_format) + else: + mmcv.dump(cfg_dict, file) + + def merge_from_dict(self, options, allow_list_keys=True): + """Merge list into cfg_dict. + + Merge the dict parsed by MultipleKVAction into this cfg. + + Examples: + >>> options = {'models.backbone.depth': 50, + ... 'models.backbone.with_cp':True} + >>> cfg = Config(dict(models=dict(backbone=dict(type='ResNet')))) + >>> cfg.merge_from_dict(options) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict( + ... models=dict(backbone=dict(depth=50, with_cp=True))) + + # Merge list element + >>> cfg = Config(dict(pipeline=[ + ... 
dict(type='LoadImage'), dict(type='LoadAnnotations')])) + >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')}) + >>> cfg.merge_from_dict(options, allow_list_keys=True) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict(pipeline=[ + ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')]) + + Args: + options (dict): dict of configs to merge from. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in ``options`` and will replace the element of the + corresponding index in the config if the config is a list. + Default: True. + """ + option_cfg_dict = {} + for full_key, v in options.items(): + d = option_cfg_dict + key_list = full_key.split(".") + for subkey in key_list[:-1]: + d.setdefault(subkey, ConfigDict()) + d = d[subkey] + subkey = key_list[-1] + d[subkey] = v + + cfg_dict = super(Config, self).__getattribute__("_cfg_dict") + super(Config, self).__setattr__( + "_cfg_dict", + Config._merge_a_into_b(option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys), + ) + + +class DictAction(Action): + """ + argparse action to split an argument into KEY=VALUE form + on the first = and append to a dictionary. List options can + be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit + brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build + list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' + """ + + @staticmethod + def _parse_int_float_bool(val): + try: + return int(val) + except ValueError: + pass + try: + return float(val) + except ValueError: + pass + if val.lower() in ["true", "false"]: + return True if val.lower() == "true" else False + return val + + @staticmethod + def _parse_iterable(val): + """Parse iterable values in the string. + + All elements inside '()' or '[]' are treated as iterable values. + + Args: + val (str): Value string. + + Returns: + list | tuple: The expanded list or tuple from the string. + + Examples: + >>> DictAction._parse_iterable('1,2,3') + [1, 2, 3] + >>> DictAction._parse_iterable('[a, b, c]') + ['a', 'b', 'c'] + >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]') + [(1, 2, 3), ['a', 'b'], 'c'] + """ + + def find_next_comma(string): + """Find the position of next comma in the string. + + If no ',' is found in the string, return the string length. All + chars inside '()' and '[]' are treated as one element and thus ',' + inside these brackets are ignored. + """ + assert (string.count("(") == string.count(")")) and ( + string.count("[") == string.count("]") + ), f"Imbalanced brackets exist in {string}" + end = len(string) + for idx, char in enumerate(string): + pre = string[:idx] + # The string before this ',' is balanced + if (char == ",") and (pre.count("(") == pre.count(")")) and (pre.count("[") == pre.count("]")): + end = idx + break + return end + + # Strip ' and " characters and replace whitespace. 
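+        # Parsing then recurses: brackets delimit nested iterables, and
+        # find_next_comma() skips commas inside balanced brackets.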
+ val = val.strip("'\"").replace(" ", "") + is_tuple = False + if val.startswith("(") and val.endswith(")"): + is_tuple = True + val = val[1:-1] + elif val.startswith("[") and val.endswith("]"): + val = val[1:-1] + elif "," not in val: + # val is a single value + return DictAction._parse_int_float_bool(val) + + values = [] + while len(val) > 0: + comma_idx = find_next_comma(val) + element = DictAction._parse_iterable(val[:comma_idx]) + values.append(element) + val = val[comma_idx + 1 :] + if is_tuple: + values = tuple(values) + return values + + def __call__(self, parser, namespace, values, option_string=None): + options = {} + for kv in values: + key, val = kv.split("=", maxsplit=1) + options[key] = self._parse_iterable(val) + setattr(namespace, self.dest, options) diff --git a/projects/PTv3/utils/env.py b/projects/PTv3/utils/env.py new file mode 100644 index 00000000..5ebb6924 --- /dev/null +++ b/projects/PTv3/utils/env.py @@ -0,0 +1,32 @@ +""" +Environment Utils + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import os +import random +from datetime import datetime + +import numpy as np +import torch +import torch.backends.cudnn as cudnn + + +def get_random_seed(): + seed = os.getpid() + int(datetime.now().strftime("%S%f")) + int.from_bytes(os.urandom(2), "big") + return seed + + +def set_seed(seed=None): + if seed is None: + seed = get_random_seed() + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + cudnn.benchmark = False + cudnn.deterministic = True + os.environ["PYTHONHASHSEED"] = str(seed) diff --git a/projects/PTv3/utils/events.py b/projects/PTv3/utils/events.py new file mode 100644 index 00000000..91e21d56 --- /dev/null +++ b/projects/PTv3/utils/events.py @@ -0,0 +1,284 @@ +""" +Events Utils + +Modified from Detectron2 + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import logging +import sys +import traceback +from collections import defaultdict +from contextlib import contextmanager + +__all__ = [ + "EventStorage", + "ExceptionWriter", +] + +_CURRENT_STORAGE_STACK = [] + + +class EventWriter: + """ + Base class for writers that obtain events from :class:`EventStorage` and process them. + """ + + def write(self): + raise NotImplementedError + + def close(self): + pass + + +class EventStorage: + """ + The user-facing class that provides metric storage functionalities. + In the future we may add support for storing / logging other types of data if needed. + """ + + def __init__(self, start_iter=0): + """ + Args: + start_iter (int): the iteration number to start with + """ + self._history = defaultdict(AverageMeter) + self._smoothing_hints = {} + self._latest_scalars = {} + self._iter = start_iter + self._current_prefix = "" + self._vis_data = [] + self._histograms = [] + + # def put_image(self, img_name, img_tensor): + # """ + # Add an `img_tensor` associated with `img_name`, to be shown on + # tensorboard. + # Args: + # img_name (str): The name of the image to put into tensorboard. + # img_tensor (torch.Tensor or numpy.array): An `uint8` or `float` + # Tensor of shape `[channel, height, width]` where `channel` is + # 3. The image format should be RGB. The elements in img_tensor + # can either have values in [0, 1] (float32) or [0, 255] (uint8). + # The `img_tensor` will be visualized in tensorboard. 
+ # """ + # self._vis_data.append((img_name, img_tensor, self._iter)) + + def put_scalar(self, name, value, n=1, smoothing_hint=False): + """ + Add a scalar `value` to the `HistoryBuffer` associated with `name`. + Args: + smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be + smoothed when logged. The hint will be accessible through + :meth:`EventStorage.smoothing_hints`. A writer may ignore the hint + and apply custom smoothing rule. + It defaults to True because most scalars we save need to be smoothed to + provide any useful signal. + """ + name = self._current_prefix + name + history = self._history[name] + history.update(value, n) + self._latest_scalars[name] = (value, self._iter) + + existing_hint = self._smoothing_hints.get(name) + if existing_hint is not None: + assert existing_hint == smoothing_hint, "Scalar {} was put with a different smoothing_hint!".format(name) + else: + self._smoothing_hints[name] = smoothing_hint + + # def put_scalars(self, *, smoothing_hint=True, **kwargs): + # """ + # Put multiple scalars from keyword arguments. + # Examples: + # storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True) + # """ + # for k, v in kwargs.items(): + # self.put_scalar(k, v, smoothing_hint=smoothing_hint) + # + # def put_histogram(self, hist_name, hist_tensor, bins=1000): + # """ + # Create a histogram from a tensor. + # Args: + # hist_name (str): The name of the histogram to put into tensorboard. + # hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted + # into a histogram. + # bins (int): Number of histogram bins. + # """ + # ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item() + # + # # Create a histogram with PyTorch + # hist_counts = torch.histc(hist_tensor, bins=bins) + # hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32) + # + # # Parameter for the add_histogram_raw function of SummaryWriter + # hist_params = dict( + # tag=hist_name, + # min=ht_min, + # max=ht_max, + # num=len(hist_tensor), + # sum=float(hist_tensor.sum()), + # sum_squares=float(torch.sum(hist_tensor**2)), + # bucket_limits=hist_edges[1:].tolist(), + # bucket_counts=hist_counts.tolist(), + # global_step=self._iter, + # ) + # self._histograms.append(hist_params) + + def history(self, name): + """ + Returns: + AverageMeter: the history for name + """ + ret = self._history.get(name, None) + if ret is None: + raise KeyError("No history metric available for {}!".format(name)) + return ret + + def histories(self): + """ + Returns: + dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars + """ + return self._history + + def latest(self): + """ + Returns: + dict[str -> (float, int)]: mapping from the name of each scalar to the most + recent value and the iteration number its added. + """ + return self._latest_scalars + + def latest_with_smoothing_hint(self, window_size=20): + """ + Similar to :meth:`latest`, but the returned values + are either the un-smoothed original latest value, + or a median of the given window_size, + depend on whether the smoothing_hint is True. + This provides a default behavior that other writers can use. + """ + result = {} + for k, (v, itr) in self._latest_scalars.items(): + result[k] = ( + self._history[k].median(window_size) if self._smoothing_hints[k] else v, + itr, + ) + return result + + def smoothing_hints(self): + """ + Returns: + dict[name -> bool]: the user-provided hint on whether the scalar + is noisy and needs smoothing. 
+ """ + return self._smoothing_hints + + def step(self): + """ + User should either: (1) Call this function to increment storage.iter when needed. Or + (2) Set `storage.iter` to the correct iteration number before each iteration. + The storage will then be able to associate the new data with an iteration number. + """ + self._iter += 1 + + @property + def iter(self): + """ + Returns: + int: The current iteration number. When used together with a trainer, + this is ensured to be the same as trainer.iter. + """ + return self._iter + + @iter.setter + def iter(self, val): + self._iter = int(val) + + @property + def iteration(self): + # for backward compatibility + return self._iter + + def __enter__(self): + _CURRENT_STORAGE_STACK.append(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert _CURRENT_STORAGE_STACK[-1] == self + _CURRENT_STORAGE_STACK.pop() + + @contextmanager + def name_scope(self, name): + """ + Yields: + A context within which all the events added to this storage + will be prefixed by the name scope. + """ + old_prefix = self._current_prefix + self._current_prefix = name.rstrip("/") + "/" + yield + self._current_prefix = old_prefix + + def clear_images(self): + """ + Delete all the stored images for visualization. This should be called + after images are written to tensorboard. + """ + self._vis_data = [] + + def clear_histograms(self): + """ + Delete all the stored histograms for visualization. + This should be called after histograms are written to tensorboard. + """ + self._histograms = [] + + def reset_history(self, name): + ret = self._history.get(name, None) + if ret is None: + raise KeyError("No history metric available for {}!".format(name)) + ret.reset() + + def reset_histories(self): + for name in self._history.keys(): + self._history[name].reset() + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self): + self.val = 0 + self.avg = 0 + self.total = 0 + self.count = 0 + + def reset(self): + self.val = 0 + self.avg = 0 + self.total = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.total += val * n + self.count += n + self.avg = self.total / self.count + + +class ExceptionWriter: + + def __init__(self): + self.logger = logging.getLogger(__name__) + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type: + tb = traceback.format_exception(exc_type, exc_val, exc_tb) + formatted_tb_str = "".join(tb) + self.logger.error(formatted_tb_str) + sys.exit(1) # This prevents double logging the error to the console diff --git a/projects/PTv3/utils/logger.py b/projects/PTv3/utils/logger.py new file mode 100644 index 00000000..2d16d30d --- /dev/null +++ b/projects/PTv3/utils/logger.py @@ -0,0 +1,169 @@ +""" +Logger Utils + +Modified from mmcv + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import logging + +import torch +import torch.distributed as dist +from termcolor import colored + +logger_initialized = {} +root_status = 0 + + +class _ColorfulFormatter(logging.Formatter): + def __init__(self, *args, **kwargs): + self._root_name = kwargs.pop("root_name") + "." 
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs) + + def formatMessage(self, record): + log = super(_ColorfulFormatter, self).formatMessage(record) + if record.levelno == logging.WARNING: + prefix = colored("WARNING", "red", attrs=["blink"]) + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + prefix = colored("ERROR", "red", attrs=["blink", "underline"]) + else: + return log + return prefix + " " + log + + +def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="a", color=False): + """Initialize and get a logger by name. + + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. If `log_file` is specified and the process rank is 0, a FileHandler + will also be added. + + Args: + name (str): Logger name. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. Note that only the process of + rank 0 is affected, and other processes will set the level to + "Error" thus be silent most of the time. + file_mode (str): The file mode used in opening log file. + Defaults to 'a'. + color (bool): Colorful log output. Defaults to True + + Returns: + logging.Logger: The expected logger. + """ + logger = logging.getLogger(name) + + if name in logger_initialized: + return logger + # handle hierarchical names + # e.g., logger "a" is initialized, then logger "a.b" will skip the + # initialization since it is a child of "a". + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + logger.propagate = False + + stream_handler = logging.StreamHandler() + handlers = [stream_handler] + + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + # only rank 0 will add a FileHandler + if rank == 0 and log_file is not None: + # Here, the default behaviour of the official logger is 'a'. Thus, we + # provide an interface to change the file mode to the default + # behaviour. + file_handler = logging.FileHandler(log_file, file_mode) + handlers.append(file_handler) + + plain_formatter = logging.Formatter( + "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" + ) + if color: + formatter = _ColorfulFormatter( + colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", + datefmt="%m/%d %H:%M:%S", + root_name=name, + ) + else: + formatter = plain_formatter + for handler in handlers: + handler.setFormatter(formatter) + handler.setLevel(log_level) + logger.addHandler(handler) + + if rank == 0: + logger.setLevel(log_level) + else: + logger.setLevel(logging.ERROR) + + logger_initialized[name] = True + + return logger + + +def print_log(msg, logger=None, level=logging.INFO): + """Print a log message. + + Args: + msg (str): The message to be logged. + logger (logging.Logger | str | None): The logger to be used. + Some special loggers are: + - "silent": no message will be printed. + - other str: the logger obtained with `get_root_logger(logger)`. + - None: The `print()` method will be used to print log messages. + level (int): Logging level. Only available when `logger` is a Logger + object or "root". 
+ """ + if logger is None: + print(msg) + elif isinstance(logger, logging.Logger): + logger.log(level, msg) + elif logger == "silent": + pass + elif isinstance(logger, str): + _logger = get_logger(logger) + _logger.log(level, msg) + else: + raise TypeError( + "logger should be either a logging.Logger object, str, " f'"silent" or None, but got {type(logger)}' + ) + + +def get_root_logger(log_file=None, log_level=logging.INFO, file_mode="a"): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. The name of the root logger is the top-level package name. + + Args: + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + file_mode (str): File Mode of logger. (w or a) + + Returns: + logging.Logger: The root logger. + """ + logger = get_logger(name="pointcept", log_file=log_file, log_level=log_level, file_mode=file_mode) + return logger + + +def _log_api_usage(identifier: str): + """ + Internal function used to log the usage of different detectron2 components + inside facebook's infra. + """ + torch._C._log_api_usage_once("pointcept." + identifier) diff --git a/projects/PTv3/utils/misc.py b/projects/PTv3/utils/misc.py new file mode 100644 index 00000000..873d6428 --- /dev/null +++ b/projects/PTv3/utils/misc.py @@ -0,0 +1,165 @@ +""" +Misc + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import os +import warnings +from collections import abc +from importlib import import_module + +import numpy as np +import torch + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def intersection_and_union(output, target, K, ignore_index=-1): + # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. + assert output.ndim in [1, 2, 3] + assert output.shape == target.shape + output = output.reshape(output.size).copy() + target = target.reshape(target.size) + output[np.where(target == ignore_index)[0]] = ignore_index + intersection = output[np.where(output == target)[0]] + area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1)) + area_output, _ = np.histogram(output, bins=np.arange(K + 1)) + area_target, _ = np.histogram(target, bins=np.arange(K + 1)) + area_union = area_output + area_target - area_intersection + return area_intersection, area_union, area_target + + +def intersection_and_union_gpu(output, target, k, ignore_index=-1): + # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 
+ assert output.dim() in [1, 2, 3] + assert output.shape == target.shape + output = output.view(-1) + target = target.view(-1) + output[target == ignore_index] = ignore_index + intersection = output[output == target] + area_intersection = torch.histc(intersection, bins=k, min=0, max=k - 1) + area_output = torch.histc(output, bins=k, min=0, max=k - 1) + area_target = torch.histc(target, bins=k, min=0, max=k - 1) + area_union = area_output + area_target - area_intersection + return area_intersection, area_union, area_target + + +def make_dirs(dir_name): + if not os.path.exists(dir_name): + os.makedirs(dir_name, exist_ok=True) + + +def find_free_port(): + import socket + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. + return port + + +def is_seq_of(seq, expected_type, seq_type=None): + """Check whether it is a sequence of some type. + + Args: + seq (Sequence): The sequence to be checked. + expected_type (type): Expected type of sequence items. + seq_type (type, optional): Expected sequence type. + + Returns: + bool: Whether the sequence is valid. + """ + if seq_type is None: + exp_seq_type = abc.Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + + +def is_str(x): + """Whether the input is an string instance. + + Note: This method is deprecated since python 2 is no longer supported. + """ + return isinstance(x, str) + + +def import_modules_from_strings(imports, allow_failed_imports=False): + """Import modules from the given list of strings. + + Args: + imports (list | str | None): The given module names to be imported. + allow_failed_imports (bool): If True, the failed imports will return + None. Otherwise, an ImportError is raise. Default: False. + + Returns: + list[module] | module | None: The imported modules. + + Examples: + >>> osp, sys = import_modules_from_strings( + ... ['os.path', 'sys']) + >>> import os.path as osp_ + >>> import sys as sys_ + >>> assert osp == osp_ + >>> assert sys == sys_ + """ + if not imports: + return + single_import = False + if isinstance(imports, str): + single_import = True + imports = [imports] + if not isinstance(imports, list): + raise TypeError(f"custom_imports must be a list but got type {type(imports)}") + imported = [] + for imp in imports: + if not isinstance(imp, str): + raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.") + try: + imported_tmp = import_module(imp) + except ImportError: + if allow_failed_imports: + warnings.warn(f"{imp} failed to import and is ignored.", UserWarning) + imported_tmp = None + else: + raise ImportError + imported.append(imported_tmp) + if single_import: + imported = imported[0] + return imported + + +class DummyClass: + def __init__(self): + pass diff --git a/projects/PTv3/utils/optimizer.py b/projects/PTv3/utils/optimizer.py new file mode 100644 index 00000000..3562f3b8 --- /dev/null +++ b/projects/PTv3/utils/optimizer.py @@ -0,0 +1,55 @@ +""" +Optimizer + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. 
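+
+build_optimizer() groups parameters by keyword match on their names: for
+example, param_dicts=[dict(keyword="block", lr=1e-4)] puts every parameter
+whose name contains "block" into its own group with lr=1e-4, while the rest
+fall back to cfg.lr (values here are illustrative, not defaults).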
+""" + +import torch +from utils.logger import get_root_logger +from utils.registry import Registry + +OPTIMIZERS = Registry("optimizers") + + +OPTIMIZERS.register_module(module=torch.optim.SGD, name="SGD") +OPTIMIZERS.register_module(module=torch.optim.Adam, name="Adam") +OPTIMIZERS.register_module(module=torch.optim.AdamW, name="AdamW") + + +def build_optimizer(cfg, model, param_dicts=None): + if param_dicts is None: + cfg.params = model.parameters() + else: + cfg.params = [dict(names=[], params=[], lr=cfg.lr)] + for i in range(len(param_dicts)): + param_group = dict(names=[], params=[]) + if "lr" in param_dicts[i].keys(): + param_group["lr"] = param_dicts[i].lr + if "momentum" in param_dicts[i].keys(): + param_group["momentum"] = param_dicts[i].momentum + if "weight_decay" in param_dicts[i].keys(): + param_group["weight_decay"] = param_dicts[i].weight_decay + cfg.params.append(param_group) + + for n, p in model.named_parameters(): + flag = False + for i in range(len(param_dicts)): + if param_dicts[i].keyword in n: + cfg.params[i + 1]["names"].append(n) + cfg.params[i + 1]["params"].append(p) + flag = True + break + if not flag: + cfg.params[0]["names"].append(n) + cfg.params[0]["params"].append(p) + + logger = get_root_logger() + for i in range(len(cfg.params)): + param_names = cfg.params[i].pop("names") + message = "" + for key in cfg.params[i].keys(): + if key != "params": + message += f" {key}: {cfg.params[i][key]};" + logger.info(f"Params Group {i+1} -{message} Params: {param_names}.") + return OPTIMIZERS.build(cfg=cfg) diff --git a/projects/PTv3/utils/path.py b/projects/PTv3/utils/path.py new file mode 100644 index 00000000..2574510d --- /dev/null +++ b/projects/PTv3/utils/path.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +from pathlib import Path + +from .misc import is_str + + +def is_filepath(x): + return is_str(x) or isinstance(x, Path) + + +def fopen(filepath, *args, **kwargs): + if is_str(filepath): + return open(filepath, *args, **kwargs) + elif isinstance(filepath, Path): + return filepath.open(*args, **kwargs) + raise ValueError("`filepath` should be a string or a Path") + + +def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): + if not osp.isfile(filename): + raise FileNotFoundError(msg_tmpl.format(filename)) + + +def mkdir_or_exist(dir_name, mode=0o777): + if dir_name == "": + return + dir_name = osp.expanduser(dir_name) + os.makedirs(dir_name, mode=mode, exist_ok=True) + + +def symlink(src, dst, overwrite=True, **kwargs): + if os.path.lexists(dst) and overwrite: + os.remove(dst) + os.symlink(src, dst, **kwargs) + + +def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True): + """Scan a directory to find the interested files. + + Args: + dir_path (str | obj:`Path`): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + case_sensitive (bool, optional) : If set to False, ignore the case of + suffix. Default: True. + + Returns: + A generator for all the interested files with relative paths. 
+ """ + if isinstance(dir_path, (str, Path)): + dir_path = str(dir_path) + else: + raise TypeError('"dir_path" must be a string or Path object') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + if suffix is not None and not case_sensitive: + suffix = suffix.lower() if isinstance(suffix, str) else tuple(item.lower() for item in suffix) + + root = dir_path + + def _scandir(dir_path, suffix, recursive, case_sensitive): + for entry in os.scandir(dir_path): + if not entry.name.startswith(".") and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + _rel_path = rel_path if case_sensitive else rel_path.lower() + if suffix is None or _rel_path.endswith(suffix): + yield rel_path + elif recursive and os.path.isdir(entry.path): + # scan recursively if entry.path is a directory + yield from _scandir(entry.path, suffix, recursive, case_sensitive) + + return _scandir(dir_path, suffix, recursive, case_sensitive) + + +def find_vcs_root(path, markers=(".git",)): + """Finds the root directory (including itself) of specified markers. + + Args: + path (str): Path of directory or file. + markers (list[str], optional): List of file or directory names. + + Returns: + The directory contained one of the markers or None if not found. + """ + if osp.isfile(path): + path = osp.dirname(path) + + prev, cur = None, osp.abspath(osp.expanduser(path)) + while cur != prev: + if any(osp.exists(osp.join(cur, marker)) for marker in markers): + return cur + prev, cur = cur, osp.split(cur)[0] + return None diff --git a/projects/PTv3/utils/registry.py b/projects/PTv3/utils/registry.py new file mode 100644 index 00000000..a164119c --- /dev/null +++ b/projects/PTv3/utils/registry.py @@ -0,0 +1,303 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import warnings +from functools import partial + +from .misc import is_seq_of + + +def build_from_cfg(cfg, registry, default_args=None): + """Build a module from configs dict. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + registry (:obj:`Registry`): The registry to search the type from. + default_args (dict, optional): Default initialization arguments. + + Returns: + object: The constructed object. + """ + if not isinstance(cfg, dict): + raise TypeError(f"cfg must be a dict, but got {type(cfg)}") + if "type" not in cfg: + if default_args is None or "type" not in default_args: + raise KeyError('`cfg` or `default_args` must contain the key "type", ' f"but got {cfg}\n{default_args}") + if not isinstance(registry, Registry): + raise TypeError("registry must be an mmcv.Registry object, " f"but got {type(registry)}") + if not (isinstance(default_args, dict) or default_args is None): + raise TypeError("default_args must be a dict or None, " f"but got {type(default_args)}") + + args = cfg.copy() + + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + + obj_type = args.pop("type") + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type) + if obj_cls is None: + raise KeyError(f"{obj_type} is not in the {registry.name} registry") + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}") + try: + return obj_cls(**args) + except Exception as e: + # Normal TypeError does not print class name. + raise type(e)(f"{obj_cls.__name__}: {e}") + + +class Registry: + """A registry to map strings to classes. 
+ + Registered object could be built from registry. + Example: + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> resnet = MODELS.build(dict(type='ResNet')) + + Please refer to + https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for + advanced usage. + + Args: + name (str): Registry name. + build_func(func, optional): Build function to construct instance from + Registry, func:`build_from_cfg` is used if neither ``parent`` or + ``build_func`` is specified. If ``parent`` is specified and + ``build_func`` is not given, ``build_func`` will be inherited + from ``parent``. Default: None. + parent (Registry, optional): Parent registry. The class registered in + children registry could be built from parent. Default: None. + scope (str, optional): The scope of registry. It is the key to search + for children registry. If not specified, scope will be the name of + the package where class is defined, e.g. mmdet, mmcls, mmseg. + Default: None. + """ + + def __init__(self, name, build_func=None, parent=None, scope=None): + self._name = name + self._module_dict = dict() + self._children = dict() + self._scope = self.infer_scope() if scope is None else scope + + # self.build_func will be set with the following priority: + # 1. build_func + # 2. parent.build_func + # 3. build_from_cfg + if build_func is None: + if parent is not None: + self.build_func = parent.build_func + else: + self.build_func = build_from_cfg + else: + self.build_func = build_func + if parent is not None: + assert isinstance(parent, Registry) + parent._add_children(self) + self.parent = parent + else: + self.parent = None + + def __len__(self): + return len(self._module_dict) + + def __contains__(self, key): + return self.get(key) is not None + + def __repr__(self): + format_str = self.__class__.__name__ + f"(name={self._name}, " f"items={self._module_dict})" + return format_str + + @staticmethod + def infer_scope(): + """Infer the scope of registry. + + The name of the package where registry is defined will be returned. + + Example: + # in mmdet/models/backbone/resnet.py + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + The scope of ``ResNet`` will be ``mmdet``. + + + Returns: + scope (str): The inferred scope name. + """ + # inspect.stack() trace where this function is called, the index-2 + # indicates the frame where `infer_scope()` is called + filename = inspect.getmodule(inspect.stack()[2][0]).__name__ + split_filename = filename.split(".") + return split_filename[0] + + @staticmethod + def split_scope_key(key): + """Split scope and key. + + The first scope will be split from key. + + Examples: + >>> Registry.split_scope_key('mmdet.ResNet') + 'mmdet', 'ResNet' + >>> Registry.split_scope_key('ResNet') + None, 'ResNet' + + Return: + scope (str, None): The first scope. + key (str): The remaining key. + """ + split_index = key.find(".") + if split_index != -1: + return key[:split_index], key[split_index + 1 :] + else: + return None, key + + @property + def name(self): + return self._name + + @property + def scope(self): + return self._scope + + @property + def module_dict(self): + return self._module_dict + + @property + def children(self): + return self._children + + def get(self, key): + """Get the registry record. + + Args: + key (str): The class name in string format. + + Returns: + class: The corresponding class. 
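+
+        Example:
+            >>> MODELS = Registry('models')
+            >>> @MODELS.register_module()
+            >>> class ResNet:
+            >>>     pass
+            >>> MODELS.get('ResNet') is ResNet
+            True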
+ """ + scope, real_key = self.split_scope_key(key) + if scope is None or scope == self._scope: + # get from self + if real_key in self._module_dict: + return self._module_dict[real_key] + else: + # get from self._children + if scope in self._children: + return self._children[scope].get(real_key) + else: + # goto root + parent = self.parent + while parent.parent is not None: + parent = parent.parent + return parent.get(key) + + def build(self, *args, **kwargs): + return self.build_func(*args, **kwargs, registry=self) + + def _add_children(self, registry): + """Add children for a registry. + + The ``registry`` will be added as children based on its scope. + The parent registry could build objects from children registry. + + Example: + >>> models = Registry('models') + >>> mmdet_models = Registry('models', parent=models) + >>> @mmdet_models.register_module() + >>> class ResNet: + >>> pass + >>> resnet = models.build(dict(type='mmdet.ResNet')) + """ + + assert isinstance(registry, Registry) + assert registry.scope is not None + assert registry.scope not in self.children, f"scope {registry.scope} exists in {self.name} registry" + self.children[registry.scope] = registry + + def _register_module(self, module_class, module_name=None, force=False): + if not inspect.isclass(module_class): + raise TypeError("module must be a class, " f"but got {type(module_class)}") + + if module_name is None: + module_name = module_class.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in self._module_dict: + raise KeyError(f"{name} is already registered " f"in {self.name}") + self._module_dict[name] = module_class + + def deprecated_register_module(self, cls=None, force=False): + warnings.warn( + "The old API of register_module(module, force=False) " + "is deprecated and will be removed, please use the new API " + "register_module(name=None, force=False, module=None) instead." + ) + if cls is None: + return partial(self.deprecated_register_module, force=force) + self._register_module(cls, force=force) + return cls + + def register_module(self, name=None, force=False, module=None): + """Register a module. + + A record will be added to `self._module_dict`, whose key is the class + name or the specified name, and value is the class itself. + It can be used as a decorator or a normal function. + + Example: + >>> backbones = Registry('backbone') + >>> @backbones.register_module() + >>> class ResNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> @backbones.register_module(name='mnet') + >>> class MobileNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> class ResNet: + >>> pass + >>> backbones.register_module(ResNet) + + Args: + name (str | None): The module name to be registered. If not + specified, the class name will be used. + force (bool, optional): Whether to override an existing class with + the same name. Default: False. + module (type): Module class to be registered. + """ + if not isinstance(force, bool): + raise TypeError(f"force must be a boolean, but got {type(force)}") + # NOTE: This is a walkaround to be compatible with the old api, + # while it may introduce unexpected bugs. 
+ if isinstance(name, type): + return self.deprecated_register_module(name, force=force) + + # raise the error ahead of time + if not (name is None or isinstance(name, str) or is_seq_of(name, str)): + raise TypeError( + "name must be either of None, an instance of str or a sequence" f" of str, but got {type(name)}" + ) + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + self._register_module(module_class=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(cls): + self._register_module(module_class=cls, module_name=name, force=force) + return cls + + return _register diff --git a/projects/PTv3/utils/scheduler.py b/projects/PTv3/utils/scheduler.py new file mode 100644 index 00000000..8899524e --- /dev/null +++ b/projects/PTv3/utils/scheduler.py @@ -0,0 +1,144 @@ +""" +Scheduler + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +import torch.optim.lr_scheduler as lr_scheduler + +from .registry import Registry + +SCHEDULERS = Registry("schedulers") + + +@SCHEDULERS.register_module() +class MultiStepLR(lr_scheduler.MultiStepLR): + def __init__( + self, + optimizer, + milestones, + total_steps, + gamma=0.1, + last_epoch=-1, + verbose=False, + ): + super().__init__( + optimizer=optimizer, + milestones=[rate * total_steps for rate in milestones], + gamma=gamma, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class MultiStepWithWarmupLR(lr_scheduler.LambdaLR): + def __init__( + self, + optimizer, + milestones, + total_steps, + gamma=0.1, + warmup_rate=0.05, + warmup_scale=1e-6, + last_epoch=-1, + verbose=False, + ): + milestones = [rate * total_steps for rate in milestones] + + def multi_step_with_warmup(s): + factor = 1.0 + for i in range(len(milestones)): + if s < milestones[i]: + break + factor *= gamma + + if s <= warmup_rate * total_steps: + warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * (1 - warmup_scale) + else: + warmup_coefficient = 1.0 + return warmup_coefficient * factor + + super().__init__( + optimizer=optimizer, + lr_lambda=multi_step_with_warmup, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class PolyLR(lr_scheduler.LambdaLR): + def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False): + super().__init__( + optimizer=optimizer, + lr_lambda=lambda s: (1 - s / (total_steps + 1)) ** power, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class ExpLR(lr_scheduler.LambdaLR): + def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False): + super().__init__( + optimizer=optimizer, + lr_lambda=lambda s: gamma ** (s / total_steps), + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR): + def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False): + super().__init__( + optimizer=optimizer, + T_max=total_steps, + eta_min=eta_min, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class OneCycleLR(lr_scheduler.OneCycleLR): + r""" + torch.optim.lr_scheduler.OneCycleLR, Block total_steps + """ + + def __init__( + self, + optimizer, + max_lr, + total_steps=None, + pct_start=0.3, + anneal_strategy="cos", + cycle_momentum=True, + base_momentum=0.85, + max_momentum=0.95, + div_factor=25.0, + final_div_factor=1e4, + 
three_phase=False, + last_epoch=-1, + ): + super().__init__( + optimizer=optimizer, + max_lr=max_lr, + total_steps=total_steps, + pct_start=pct_start, + anneal_strategy=anneal_strategy, + cycle_momentum=cycle_momentum, + base_momentum=base_momentum, + max_momentum=max_momentum, + div_factor=div_factor, + final_div_factor=final_div_factor, + three_phase=three_phase, + last_epoch=last_epoch, + ) + + +def build_scheduler(cfg, optimizer): + cfg.optimizer = optimizer + return SCHEDULERS.build(cfg=cfg) diff --git a/projects/PTv3/utils/timer.py b/projects/PTv3/utils/timer.py new file mode 100644 index 00000000..3de4a16e --- /dev/null +++ b/projects/PTv3/utils/timer.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# -*- coding: utf-8 -*- + +from time import perf_counter +from typing import Optional + + +class Timer: + """ + A timer which computes the time elapsed since the start/reset of the timer. + """ + + def __init__(self) -> None: + self.reset() + + def reset(self) -> None: + """ + Reset the timer. + """ + self._start = perf_counter() + self._paused: Optional[float] = None + self._total_paused = 0 + self._count_start = 1 + + def pause(self) -> None: + """ + Pause the timer. + """ + if self._paused is not None: + raise ValueError("Trying to pause a Timer that is already paused!") + self._paused = perf_counter() + + def is_paused(self) -> bool: + """ + Returns: + bool: whether the timer is currently paused + """ + return self._paused is not None + + def resume(self) -> None: + """ + Resume the timer. + """ + if self._paused is None: + raise ValueError("Trying to resume a Timer that is not paused!") + # pyre-fixme[58]: `-` is not supported for operand types `float` and + # `Optional[float]`. + self._total_paused += perf_counter() - self._paused + self._paused = None + self._count_start += 1 + + def seconds(self) -> float: + """ + Returns: + (float): the total number of seconds since the start/reset of the + timer, excluding the time when the timer is paused. + """ + if self._paused is not None: + end_time: float = self._paused # type: ignore + else: + end_time = perf_counter() + return end_time - self._start - self._total_paused + + def avg_seconds(self) -> float: + """ + Returns: + (float): the average number of seconds between every start/reset and + pause. + """ + return self.seconds() / self._count_start diff --git a/projects/PTv3/utils/visualization.py b/projects/PTv3/utils/visualization.py new file mode 100644 index 00000000..cde44e07 --- /dev/null +++ b/projects/PTv3/utils/visualization.py @@ -0,0 +1,84 @@ +""" +Visualization Utils + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. 
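+
+Usage sketch (paths and shapes are illustrative): ``coord`` is an (N, 3) array or
+tensor of point coordinates, ``bboxes_corners`` an (M, 8, 3) array of box corners:
+
+    save_point_cloud(coord, file_path="exp/vis/pc.ply")
+    save_bounding_boxes(bboxes_corners, file_path="exp/vis/bbox.ply")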
+""" + +import os + +import numpy as np +import open3d as o3d +import torch + + +def to_numpy(x): + if isinstance(x, torch.Tensor): + x = x.clone().detach().cpu().numpy() + assert isinstance(x, np.ndarray) + return x + + +def save_point_cloud(coord, color=None, file_path="pc.ply", logger=None): + os.makedirs(os.path.dirname(file_path), exist_ok=True) + coord = to_numpy(coord) + if color is not None: + color = to_numpy(color) + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(coord) + pcd.colors = o3d.utility.Vector3dVector(np.ones_like(coord) if color is None else color) + o3d.io.write_point_cloud(file_path, pcd) + if logger is not None: + logger.info(f"Save Point Cloud to: {file_path}") + + +def save_bounding_boxes(bboxes_corners, color=(1.0, 0.0, 0.0), file_path="bbox.ply", logger=None): + bboxes_corners = to_numpy(bboxes_corners) + # point list + points = bboxes_corners.reshape(-1, 3) + # line list + box_lines = np.array( + [ + [0, 1], + [1, 2], + [2, 3], + [3, 0], + [4, 5], + [5, 6], + [6, 7], + [7, 0], + [0, 4], + [1, 5], + [2, 6], + [3, 7], + ] + ) + lines = [] + for i, _ in enumerate(bboxes_corners): + lines.append(box_lines + i * 8) + lines = np.concatenate(lines) + # color list + color = np.array([color for _ in range(len(lines))]) + # generate line set + line_set = o3d.geometry.LineSet() + line_set.points = o3d.utility.Vector3dVector(points) + line_set.lines = o3d.utility.Vector2iVector(lines) + line_set.colors = o3d.utility.Vector3dVector(color) + o3d.io.write_line_set(file_path, line_set) + + if logger is not None: + logger.info(f"Save Boxes to: {file_path}") + + +def save_lines(points, lines, color=(1.0, 0.0, 0.0), file_path="lines.ply", logger=None): + points = to_numpy(points) + lines = to_numpy(lines) + colors = np.array([color for _ in range(len(lines))]) + line_set = o3d.geometry.LineSet() + line_set.points = o3d.utility.Vector3dVector(points) + line_set.lines = o3d.utility.Vector2iVector(lines) + line_set.colors = o3d.utility.Vector3dVector(colors) + o3d.io.write_line_set(file_path, line_set) + + if logger is not None: + logger.info(f"Save Lines to: {file_path}") diff --git a/projects/SparseConvolution/overwrite_spconv.py b/projects/SparseConvolution/overwrite_spconv.py index 3e58f246..cb519a20 100644 --- a/projects/SparseConvolution/overwrite_spconv.py +++ b/projects/SparseConvolution/overwrite_spconv.py @@ -2,7 +2,7 @@ # Partially copied from https://github.com/open-mmlab/mmdetection3d/blob/v1.4.0/mmdet3d/models/layers/spconv/overwrite_spconv/write_spconv2.py # NOTE(knzo25): needed to overwrite our custom deployment oriented operations -from mmdet3d.models.layers.spconv.overwrite_spconv.write_spconv2 import _load_from_state_dict +# from mmdet3d.models.layers.spconv.overwrite_spconv.write_spconv2 import _load_from_state_dict from mmengine.registry import MODELS @@ -40,5 +40,5 @@ def register_spconv2() -> bool: MODELS._register_module(SubMConv3d, "SubMConv3d", force=True) MODELS._register_module(SubMConv4d, "SubMConv4d", force=True) SparseModule._version = 2 - SparseModule._load_from_state_dict = _load_from_state_dict + # SparseModule._load_from_state_dict = _load_from_state_dict #NOTE(knzo25): this is the original function return True diff --git a/projects/SparseConvolution/sparse_conv.py b/projects/SparseConvolution/sparse_conv.py index cc23f94c..d4ed1d38 100644 --- a/projects/SparseConvolution/sparse_conv.py +++ b/projects/SparseConvolution/sparse_conv.py @@ -20,9 +20,10 @@ from cumm import tensorview as tv from spconv.core import ConvAlgo 
from spconv.debug_utils import spconv_save_debug_data +from spconv.pytorch import functional as Fsp from spconv.pytorch import ops from spconv.pytorch.conv import SparseConvolution as SparseConvolutionBase -from spconv.pytorch.core import ImplicitGemmIndiceData, SparseConvTensor +from spconv.pytorch.core import ImplicitGemmIndiceData, IndiceData, SparseConvTensor from spconv.utils import nullcontext from torch.nn import functional as F @@ -116,7 +117,109 @@ def _conv_forward( profile_ctx = input._timer.namespace(sparse_unique_name) with profile_ctx: if algo == ConvAlgo.Native: - raise NotImplementedError + + datas = input.find_indice_pair(self.indice_key) + if datas is not None: + assert isinstance(datas, IndiceData) + if self.inverse: + assert datas is not None and self.indice_key is not None + assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops." + + outids = datas.indices + indice_pairs = datas.indice_pairs + indice_pair_num = datas.indice_pair_num + out_spatial_shape = datas.spatial_shape + self._check_inverse_reuse_valid(input, spatial_shape, datas) + else: + if self.indice_key is not None and datas is not None: + outids = datas.out_indices + indice_pairs = datas.indice_pairs + indice_pair_num = datas.indice_pair_num + assert self.subm, "only support reuse subm indices" + self._check_subm_reuse_valid(input, spatial_shape, datas) + else: + if input.benchmark: + torch.cuda.synchronize() + t = time.time() + try: + """outids, indice_pairs, indice_pair_num = ops.get_indice_pairs( + indices, batch_size, spatial_shape, algo, + self.kernel_size, self.stride, self.padding, + self.dilation, self.output_padding, self.subm, + self.transposed)""" + + outids, indice_pairs, indice_pair_num, num_act_out = Fsp_custom.get_indice_pairs( + indices, + batch_size, + spatial_shape, + algo, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.output_padding, + self.subm, + self.transposed, + ) + + except Exception as e: + msg = "[Exception|native_pair]" + msg += f"indices={indices.shape},bs={batch_size},ss={spatial_shape}," + msg += f"algo={algo},ksize={self.kernel_size},stride={self.stride}," + msg += f"padding={self.padding},dilation={self.dilation},subm={self.subm}," + msg += f"transpose={self.transposed}" + print(msg, file=sys.stderr) + spconv_save_debug_data(indices) + raise e + if input.benchmark: + torch.cuda.synchronize() + interval = time.time() - t + out_tensor.benchmark_record[name]["indice_gen_time"].append(interval) + + indice_data = IndiceData( + outids, + indices, + indice_pairs, + indice_pair_num, + spatial_shape, + out_spatial_shape, + is_subm=self.subm, + algo=algo, + ksize=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + ) + if self.indice_key is not None: + msg = f"your indice key {self.indice_key} already exists in this sparse tensor." 
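+                    # `indice_dict` caches pair data per indice_key so later layers that
+                    # share the key (the subm reuse branch above) can skip regeneration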
+ assert self.indice_key not in indice_dict, msg + indice_dict[self.indice_key] = indice_data + if input.benchmark: + torch.cuda.synchronize() + t = time.time() + indice_pairs_calc = indice_pairs + if indice_pairs.device != features.device: + indice_pairs_calc = indice_pairs.to(features.device) + + assert not self.inverse, "We are unlikely to ever use this" + assert not training, "This code is for inference only" + + out_features = Fsp_custom.indice_conv( + features, + weight, + indice_pairs_calc, + indice_pair_num, + num_act_out, + algo, + training, + self.subm, + input._timer, + bias_for_infer, + act_alpha, + act_beta, + act_type, + ) + else: data = input.find_indice_pair(self.indice_key) if data is not None: @@ -147,6 +250,7 @@ def _conv_forward( mask_argsort_fwd_splits = data.mask_argsort_fwd_splits mask_argsort_bwd_splits = data.mask_argsort_bwd_splits masks = data.masks + num_act_out = data.out_voxel_num assert self.subm, "only support reuse subm indices" self._check_subm_reuse_valid(input, spatial_shape, data) else: @@ -157,21 +261,23 @@ def _conv_forward( # we need to gen bwd indices for regular conv # because it may be inversed. try: - res = Fsp_custom.get_indice_pairs_implicit_gemm( - indices, - batch_size, - spatial_shape, - algo, - self.kernel_size, - self.stride, - self.padding, - self.dilation, - self.output_padding, - self.subm, - self.transposed, - (not self.subm) or training, - input.thrust_allocator, - input._timer, + outids, pair_fwd, pair_mask_fwd_splits, mask_argsort_fwd_splits, num_act_out = ( + Fsp_custom.get_indice_pairs_implicit_gemm( + indices, + batch_size, + spatial_shape, + algo, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.output_padding, + self.subm, + self.transposed, + (not self.subm) or training, + input.thrust_allocator, + input._timer, + ) ) except Exception as e: msg = "[Exception|implicit_gemm_pair]" @@ -186,13 +292,8 @@ def _conv_forward( torch.cuda.synchronize() interval = time.time() - t out_tensor.benchmark_record[name]["indice_gen_time"].append(interval) - outids = res[0] - # num_inds_per_loc = None #res[1] - pair_fwd = res[1] # res[2] pair_bwd = None # res[3] - pair_mask_fwd_splits = res[2] # res[4] pair_mask_bwd_splits = None # res[5] - mask_argsort_fwd_splits = res[3] # res[6] mask_argsort_bwd_splits = None # res[7] # masks = res[8] we should not use this for test masks = [None] @@ -215,6 +316,7 @@ def _conv_forward( stride=self.stride, padding=self.padding, dilation=self.dilation, + out_voxel_num=num_act_out, ) msg = f"your indice key {self.indice_key} " "already exists in this sparse tensor." 
assert self.indice_key not in indice_dict, msg @@ -222,7 +324,8 @@ def _conv_forward( if input.benchmark: torch.cuda.synchronize() t = time.time() - num_activate_out = outids.shape[0] # TODO(knzo25): should use the output of res to force the graph + # num_activate_out = outids.shape[ + # 0] # TODO(knzo25): should use the output of res to force the graph weight_cur = weight bias_cur = bias_for_infer # if self.enable_int8_test_mode: @@ -242,13 +345,13 @@ def _conv_forward( pair_fwd, pair_mask_fwd_splits, mask_argsort_fwd_splits, - num_activate_out, + num_act_out, masks, training, self.subm, input._timer, self.fp32_accum, - bias_cur, + None, act_alpha, act_beta, act_type, @@ -261,6 +364,9 @@ def _conv_forward( output_dtype, ) + if bias_cur is not None: + out_features = out_features + bias_cur + if bias_for_training is not None: out_features += bias_for_training if input.benchmark: diff --git a/projects/SparseConvolution/sparse_functional.py b/projects/SparseConvolution/sparse_functional.py index 1965051c..efb72f7e 100644 --- a/projects/SparseConvolution/sparse_functional.py +++ b/projects/SparseConvolution/sparse_functional.py @@ -1,3 +1,4 @@ +import sys from typing import Any, List, Optional import numpy as np @@ -9,6 +10,7 @@ from spconv.core import ConvAlgo from spconv.core_cc.csrc.sparse.all import SpconvOps from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps +from spconv.pytorch import ops from spconv.pytorch.core import ThrustSortAllocator from spconv.pytorch.cppcore import _TORCH_DTYPE_TO_TV, TorchAllocator, get_arch, get_current_stream, torch_tensor_to_tv from spconv.tools import CUDAKernelTimer @@ -16,6 +18,199 @@ from torch.onnx.symbolic_helper import _get_tensor_sizes +class GetIndicePairs(Function): + + @staticmethod + def symbolic( + g, + indices: torch.Tensor, + batch_size: int, + spatial_shape: List[int], + algo: ConvAlgo, + ksize: List[int], + stride: List[int], + padding: List[int], + dilation: List[int], + out_padding: List[int], + subm: bool, + transpose: bool, + ): + outputs = g.op( + "autoware::GetIndicePairs", + indices, + batch_size_i=batch_size, + spatial_shape_i=spatial_shape, + algo_i=algo.value, + ksize_i=ksize, + stride_i=stride, + padding_i=padding, + dilation_i=dilation, + out_padding_i=out_padding, + subm_i=subm, + transpose_i=transpose, + outputs=4, + ) + indices_shape = _get_tensor_sizes(indices) + if indices_shape is not None and hasattr(indices.type(), "with_sizes"): + output_type_1 = indices.type().with_sizes([None, indices_shape[1]]) + output_type_2 = indices.type().with_sizes([2, np.prod(ksize), None]) + output_type_3 = indices.type().with_sizes([np.prod(ksize)]) + output_type_4 = indices.type().with_sizes([]) + + outputs[0].setType(output_type_1) + outputs[1].setType(output_type_2) + outputs[2].setType(output_type_3) + outputs[3].setType(output_type_4) + return outputs + + @staticmethod + def forward( + ctx, + indices: torch.Tensor, + batch_size: int, + spatial_shape: List[int], + algo: ConvAlgo, + ksize: List[int], + stride: List[int], + padding: List[int], + dilation: List[int], + out_padding: List[int], + subm: bool, + transpose: bool, + ) -> torch.Tensor: + + alloc = TorchAllocator(indices.device) + stream = 0 + if indices.is_cuda: + stream = get_current_stream() + + num_act_out = SpconvOps.get_indice_pairs( + alloc, + torch_tensor_to_tv(indices), + batch_size, + spatial_shape, + algo.value, + ksize, + stride, + padding, + dilation, + out_padding, + subm, + transpose, + stream, + ) + if subm: + out_inds = indices + else: + out_inds = 
alloc.allocated[AllocKeys.OutIndices]
+
+        pair = alloc.allocated[AllocKeys.PairFwd]
+        indice_num_per_loc = alloc.allocated[AllocKeys.IndiceNumPerLoc]
+
+        # kept as a 1-element tensor so the op's fourth output exists in the traced graph
+        num_act_out = torch.tensor([num_act_out], dtype=torch.int32).to(out_inds.device)
+
+        return out_inds[:num_act_out], pair, indice_num_per_loc, num_act_out
+
+    @staticmethod
+    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
+        # one gradient slot per forward input (11); this path is inference-only
+        return None, None, None, None, None, None, None, None, None, None, None
+
+
+class IndiceConvFunction(Function):
+
+    @staticmethod
+    def symbolic(
+        g,
+        features,
+        filters,
+        indice_pairs,
+        indice_pair_num,
+        num_activate_out,
+        algo,
+        is_train,
+        is_subm,
+        timer: CUDAKernelTimer = CUDAKernelTimer(False),
+        bias: Optional[torch.Tensor] = None,
+        act_alpha: float = 0.0,
+        act_beta: float = 0.0,
+        act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
+    ):
+
+        output = g.op(
+            "autoware::IndiceConv",
+            features,
+            filters,
+            indice_pairs,
+            indice_pair_num,
+            num_activate_out,
+            is_subm_i=is_subm,
+            outputs=1,
+        )
+
+        features_shape = _get_tensor_sizes(features)
+        filters_shape = _get_tensor_sizes(filters)
+        if features_shape is not None and hasattr(features.type(), "with_sizes"):
+            output_type = features.type().with_sizes([features_shape[0], filters_shape[0]])
+            output.setType(output_type)
+
+        return output
+
+    @staticmethod
+    def forward(
+        ctx,
+        features,
+        filters,
+        indice_pairs,
+        indice_pair_num,
+        num_activate_out,
+        algo,
+        is_train: bool = False,
+        is_subm: bool = True,
+        timer: CUDAKernelTimer = CUDAKernelTimer(False),
+        bias: Optional[torch.Tensor] = None,
+        act_alpha: float = 0.0,
+        act_beta: float = 0.0,
+        act_type: tv.gemm.Activation = tv.gemm.Activation.None_,
+    ):
+
+        assert bias is None, "bias is not supported"
+        assert act_alpha == 0.0
+        assert act_beta == 0.0
+        assert act_type == tv.gemm.Activation.None_
+
+        ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
+        ctx.algo = algo
+        ctx.timer = timer
+        try:
+            out = ops.indice_conv(
+                features,
+                filters,
+                indice_pairs,
+                indice_pair_num,
+                num_activate_out,
+                is_train,
+                is_subm,
+                algo=algo,
+                timer=timer,
+                bias=bias,
+                act_alpha=act_alpha,
+                act_beta=act_beta,
+                act_type=act_type,
+            )
+
+            return out
+        except Exception as e:
+            msg = f"[Exception|indice_conv|{'subm' if is_subm else 'not_subm'}]"
+            msg += f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape},"
+            msg += f"pairnum={indice_pair_num},act={num_activate_out},algo={algo}"
+            print(msg, file=sys.stderr)
+            raise e
+
+    @staticmethod
+    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
+        # one gradient slot per forward input (13); this path is inference-only
+        return None, None, None, None, None, None, None, None, None, None, None, None, None
+
+
 class GetIndicePairsImplicitGemm(Function):
 
     @staticmethod
@@ -136,6 +331,8 @@ def forward(
             do_sort=do_sort,
         )
 
+        num_act_out = torch.tensor([num_act_out], dtype=torch.int32).to(indices.device)
+
         mask_split_count = mask_tensor.dim(0)
         # NOTE(knzo25): we support only the simplest case
         assert mask_split_count == 1
@@ -320,5 +517,9 @@ def forward(
         return out_features
 
 
+get_indice_pairs = GetIndicePairs.apply
+indice_conv = IndiceConvFunction.apply
+
+
 get_indice_pairs_implicit_gemm = GetIndicePairsImplicitGemm.apply
 implicit_gemm = ImplicitGemm.apply
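
The pattern used throughout `sparse_functional.py` above is to wrap each spconv kernel in a `torch.autograd.Function` whose `symbolic` method emits a single custom-domain ONNX node (`autoware::GetIndicePairs`, `autoware::IndiceConv`) while `forward` runs the real kernel, so tracing never descends into spconv internals. Below is a minimal, self-contained sketch of that mechanism; the `TimesTwo` op and the `custom::` domain are illustrative only, not part of this codebase:

```python
import torch


class TimesTwo(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x):
        # During torch.onnx.export, this emits one opaque custom-domain node,
        # analogous to autoware::GetIndicePairs / autoware::IndiceConv above.
        return g.op("custom::TimesTwo", x)

    @staticmethod
    def forward(ctx, x):
        # Eager execution (and tracing) runs the actual computation.
        return x * 2


class Model(torch.nn.Module):
    def forward(self, x):
        return TimesTwo.apply(x)


print(Model()(torch.ones(3)))  # tensor([2., 2., 2.])
# torch.onnx.export(Model(), torch.randn(3), "times_two.onnx",
#                   custom_opsets={"custom": 1})
```

The exported node is opaque to ONNX shape inference, which is why the `symbolic` methods above annotate output shapes via `setType`/`with_sizes`, and why the runtime (TensorRT in this project) must supply a plugin implementing each `autoware::` op.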