# drum / dataset.py
# Kazuto Nakashima
# Update README, add core functionality, and set up dependencies for DRUM project
# 1cd38c4
import re
import zipfile
from collections import defaultdict
from pathlib import Path
import numba
import numpy as np
import torch
import torchvision.transforms.functional as TF
from huggingface_hub import snapshot_download
from tqdm.auto import tqdm
@numba.jit(nopython=True)
def scatter(array, index, value):
    """Write ``value[k]`` into ``array`` at the (row, col) pair ``index[k]``.

    Entries are written in order, so when several entries target the same
    pixel the last one wins. ``array`` is modified in place and also returned.
    """
    for k in range(len(index)):
        row, col = index[k]
        array[row, col] = value[k]
    return array
class SynLiDAR(torch.utils.data.Dataset):
    """SynLiDAR scans projected onto 2D range images.

    The ``AR-X/SynLiDAR`` dataset repository is downloaded from the Hugging
    Face Hub and indexed in ``__init__``. Samples are read from zip archives
    matching ``*/sequences/*.zip``. Inside each archive, points are stored at
    ``NN/velodyne/<frame>.bin`` and labels at ``NN/labels/<frame>.label``.
    ``__getitem__`` projects a scan onto an ``(H, W) = shape`` grid and
    returns per-pixel xyz, reflectance, depth, validity mask and a semantic
    label remapped to the indices of ``class_list``.

    Args:
        root: forwarded to ``snapshot_download`` as ``local_dir``.
        revision: Hub revision forwarded to ``snapshot_download``.
        split: stored and shown by ``repr``; it does not filter samples here.
        shape: ``(H, W)`` of the projected image.
        min_depth: points with range below this get ``mask == 0``.
        max_depth: points with range above this get ``mask == 0``.
        flip: stored but not used by this class.
        scan_unfolding: if True, rows come from scan-line wraps in the point
            order instead of from the elevation angle.
        force_download: forwarded to ``snapshot_download``.
    """

    # Raw SynLiDAR semantic id (lower 16 bits of a label) -> index into
    # ``class_list``; classes without a counterpart map to 0 ("unlabeled").
    _LABEL_MAP_SynLiDAR = {
        0: 0,  # "unlabeled"
        1: 1,  # "car"
        2: 4,  # "pick-up"
        3: 4,  # "truck"
        4: 5,  # "bus"
        5: 2,  # "bicycle"
        6: 3,  # "motorcycle"
        7: 5,  # "other-vehicle"
        8: 9,  # "road"
        9: 11,  # "sidewalk"
        10: 10,  # "parking"
        11: 12,  # "other-ground"
        12: 6,  # "female"
        13: 6,  # "male"
        14: 6,  # "kid"
        15: 6,  # "crowd"
        16: 7,  # "bicyclist"
        17: 8,  # "motorcyclist"
        18: 13,  # "building"
        19: 0,  # "other-structure"
        20: 15,  # "vegetation"
        21: 16,  # "trunk"
        22: 17,  # "terrain"
        23: 19,  # "traffic-sign"
        24: 18,  # "pole"
        25: 0,  # "traffic-cone"
        26: 14,  # "fence"
        27: 0,  # "garbage-can"
        28: 0,  # "electric-box"
        29: 0,  # "table"
        30: 0,  # "chair"
        31: 0,  # "bench"
        32: 0,  # "other-object"
    }

    # Archive member names of point / label files; groups are (sequence, frame).
    _POINTS_PATTERN = re.compile(r"^(\d{2})/velodyne/(\d+)\.bin$")
    _LABELS_PATTERN = re.compile(r"^(\d{2})/labels/(\d+)\.label$")

    # Vertical field of view (degrees) used by the elevation-based row mapping.
    _FOV_UP_DEG = 3.0
    _FOV_DOWN_DEG = -25.0

    def __init__(
        self,
        root=None,
        revision="sub",
        split="all",
        shape=(64, 1024),
        min_depth=0.0,
        max_depth=80.0,
        flip=False,
        scan_unfolding=False,
        force_download=False,
    ):
        super().__init__()
        self.root = root
        self.revision = revision
        self.split = split
        self.shape = tuple(shape)
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.flip = flip
        self.scan_unfolding = scan_unfolding
        self.force_download = force_download
        self.data_list = []
        self.dataset_dir = None
        self._download()

    def _download(self):
        """Fetch the dataset snapshot and index every sample it contains."""
        dataset_dir = snapshot_download(
            repo_id="AR-X/SynLiDAR",
            repo_type="dataset",
            revision=self.revision,
            local_dir=self.root,
            force_download=self.force_download,
        )
        self.dataset_dir = dataset_dir
        self.data_list = self._index_archives(dataset_dir)

    @classmethod
    def _index_archives(cls, dataset_dir):
        """Pair point and label files found in ``*/sequences/*.zip`` archives.

        Args:
            dataset_dir: directory containing the downloaded snapshot.

        Returns:
            A list of ``{"points": (zip_path, member), "labels": (zip_path,
            member)}`` dicts, sorted by sample id ``"<sequence>-<frame>"`` so the
            index order does not depend on filesystem or archive ordering.
            Directory entries and members that match neither naming pattern
            are ignored.

        Raises:
            AssertionError: if a sample has points but no labels, or vice versa.
        """
        patterns = (("points", cls._POINTS_PATTERN), ("labels", cls._LABELS_PATTERN))
        samples = defaultdict(dict)
        for zip_path in sorted(Path(dataset_dir).glob("*/sequences/*.zip")):
            with zipfile.ZipFile(zip_path) as zf:
                for info in zf.infolist():
                    if info.is_dir():
                        continue
                    for key, pattern in patterns:
                        match = pattern.match(info.filename)
                        if match:
                            sample_id = f"{match[1]}-{match[2]}"
                            samples[sample_id][key] = (zip_path, info.filename)
                            break
        # integrity check: every sample needs both halves
        for sample_id, entry in samples.items():
            assert "points" in entry, f"Missing points in key={sample_id}"
            assert "labels" in entry, f"Missing labels in key={sample_id}"
        return [samples[sample_id] for sample_id in sorted(samples)]

    def _read_data_from_zip(self, index):
        """Load the raw scan of sample ``index`` from its archives.

        Returns:
            points: float32 array of shape (N, 4); columns are consumed as
                x, y, z, reflectance by ``__getitem__``.
            labels: uint32 array of shape (N, 1).
        """
        info = self.data_list[index]
        point_zipfile, point_filename = info["points"]
        with zipfile.ZipFile(point_zipfile) as zf:
            points = np.frombuffer(zf.read(point_filename), dtype=np.float32)
        label_zipfile, label_filename = info["labels"]
        with zipfile.ZipFile(label_zipfile) as zf:
            labels = np.frombuffer(zf.read(label_filename), dtype=np.uint32)
        return points.reshape((-1, 4)), labels.reshape((-1, 1))

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of (C, H, W) tensors.

        Keys: ``xyz`` (3 channels), ``reflectance``, ``depth``, ``mask`` (1
        channel each, float32) and ``label`` (1 channel, int64). Where the mask
        is 0, xyz/reflectance/depth are zeroed; labels are not masked.
        """
        points_flat, labels_flat = self._read_data_from_zip(index)
        xyzrdm, labels = self.projection(points_flat, labels_flat)
        xyzrdm = xyzrdm.transpose(2, 0, 1)
        labels = labels.transpose(2, 0, 1)
        # channel 5 is the in-range mask; zero every channel outside it
        xyzrdm *= xyzrdm[[5]]
        xyzrdm = torch.from_numpy(xyzrdm).float()
        labels = torch.from_numpy(labels).long()
        return {
            "xyz": xyzrdm[:3],
            "reflectance": xyzrdm[[3]],
            "depth": xyzrdm[[4]],
            "mask": xyzrdm[[5]],
            "label": labels,
        }

    def __len__(self):
        return len(self.data_list)

    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = [f"Number of datapoints: {self.__len__()}"]
        body += [f"Root location: {self.dataset_dir}"]
        body += [f"Split: {self.split}"]
        body += [f"Scan unfolding: {self.scan_unfolding}"]
        lines = [head] + [" " + line for line in body]
        return "\n".join(lines)

    def normalize(self, item):
        """Standardize the xyz/reflectance/depth entries of ``item``.

        Uses ``TF.normalize`` with the ``mean``/``std`` properties; all other
        keys are copied through unchanged. Returns a new dict.
        """
        new_item = {}
        for key, value in item.items():
            if key in ("xyz", "reflectance", "depth"):
                value = TF.normalize(value, self.mean[key], self.std[key])
            new_item[key] = value
        return new_item

    @property
    def mean(self):
        return {
            "xyz": [-0.01506443, 0.45959818, -0.89225304],
            "reflectance": 0.24130844,
            "depth": 9.689281,
        }

    @property
    def std(self):
        return {
            "xyz": [11.224804, 8.237693, 0.88183135],
            "reflectance": 0.16860831,
            "depth": 10.08752,
        }

    @property
    def class_list(self):
        return [
            "unlabeled",  # 0
            "car",  # 1
            "bicycle",  # 2
            "motorcycle",  # 3
            "truck",  # 4
            "other-vehicle",  # 5
            "person",  # 6
            "bicyclist",  # 7
            "motorcyclist",  # 8
            "road",  # 9
            "parking",  # 10
            "sidewalk",  # 11
            "other-ground",  # 12
            "building",  # 13
            "fence",  # 14
            "vegetation",  # 15
            "trunk",  # 16
            "terrain",  # 17
            "pole",  # 18
            "traffic-sign",  # 19
        ]

    @classmethod
    def _map_labels(cls, labels):
        """Map raw labels to ``class_list`` indices via ``_LABEL_MAP_SynLiDAR``.

        Only the lower 16 bits of each label are used. Returns an int64 array
        with the same shape as ``labels`` (empty input is allowed).

        Raises:
            KeyError: if a masked label has no entry in the map.
        """
        raw = np.asarray(labels) & 0xFFFF
        table = cls._LABEL_MAP_SynLiDAR
        # A dense lookup table replaces a per-element Python dict lookup.
        lut = np.full(max(table) + 1, -1, dtype=np.int64)
        lut[list(table)] = list(table.values())
        in_table = raw < lut.size
        mapped = np.full(raw.shape, -1, dtype=np.int64)
        mapped[in_table] = lut[raw[in_table]]
        unknown = mapped < 0
        if unknown.any():
            raise KeyError(int(raw[unknown][0]))
        return mapped

    @staticmethod
    def _unfolded_rows(x, y, height):
        """Assign rows by counting scan-line wraps in the point order.

        Assumes points are stored ring by ring, sweeping counterclockwise. A
        new ring starts where a point in the 4th quadrant (x >= 0, y < 0) is
        followed by one in the 1st (x >= 0, y >= 0). The last ring gets row
        ``height - 1`` and earlier rings get successively smaller rows. Points
        before the first detected wrap, and rings beyond the last ``height``
        ones, keep row 0.

        Args:
            x, y: (N, 1) coordinate arrays in stored point order.
            height: number of image rows.

        Returns:
            (N, 1) int32 row indices in ``[0, height - 1]``.
        """
        quads = np.zeros_like(x, dtype=np.int32)
        quads[(x >= 0) & (y >= 0)] = 0  # 1st
        quads[(x < 0) & (y >= 0)] = 1  # 2nd
        quads[(x < 0) & (y < 0)] = 2  # 3rd
        quads[(x >= 0) & (y < 0)] = 3  # 4th
        # quads[i - 1] - quads[i] == 3  <=>  point i-1 is in the 4th quadrant
        # and point i in the 1st, i.e. point i starts a new ring.
        diff = np.roll(quads, shift=1, axis=0) - quads
        delim_inds, _ = np.where(diff == 3)
        bounds = [*delim_inds.tolist(), len(x)]
        grid_h = np.zeros_like(x, dtype=np.int32)
        row = height - 1
        for i in reversed(range(len(delim_inds))):
            if row < 0:
                # More rings than rows; row -1 would wrap to the last row
                # when scattered, so stop instead.
                break
            grid_h[bounds[i] : bounds[i + 1]] = row
            row -= 1
        return grid_h

    @classmethod
    def _elevation_rows(cls, z, depth, height):
        """Assign rows from the elevation angle ``asin(z / depth)``.

        The vertical FOV ``[_FOV_DOWN_DEG, _FOV_UP_DEG]`` is split into
        ``height`` equal bins with row 0 at the top; points outside the FOV are
        clipped to the first/last row. Zero-range points are treated as
        elevation 0.

        Returns:
            int32 row indices in ``[0, height - 1]``, same shape as ``z``.
        """
        h_up, h_down = np.deg2rad(cls._FOV_UP_DEG), np.deg2rad(cls._FOV_DOWN_DEG)
        # Guard 0/0: NaN cast to int32 is not a valid row index for scatter.
        sin_elev = np.divide(z, depth, out=np.zeros_like(z), where=depth > 0)
        elevation = np.arcsin(np.clip(sin_elev, -1.0, 1.0)) + abs(h_down)
        grid_h = 1 - elevation / (h_up - h_down)
        return np.floor(grid_h * height).clip(0, height - 1).astype(np.int32)

    @staticmethod
    def _azimuth_cols(x, y, width):
        """Assign columns from the azimuth angle.

        Column 0 is the -x direction and columns increase as
        ``atan2(y, x)`` decreases.

        Returns:
            int32 column indices in ``[0, width - 1]``, same shape as ``x``.
        """
        azimuth = -np.arctan2(y, x)  # [-pi, pi]
        grid_w = (azimuth / np.pi + 1) / 2 % 1  # [0, 1)
        return np.floor(grid_w * width).clip(0, width - 1).astype(np.int32)

    def projection(
        self,
        points,
        labels,
    ):
        """Project a point cloud and its labels onto a ``self.shape`` grid.

        Args:
            points: (N, C) array whose first three columns are x, y, z
                (C == 4 for data read by this class).
            labels: raw labels, (N, 1) for data read by this class; see
                ``_map_labels``.

        Returns:
            proj_points: float32 (H, W, C + 2) image holding the input columns
                followed by range and in-range mask of the nearest point that
                fell in each pixel (zeros where no point fell).
            proj_labels: int64 (H, W, 1) image of remapped class indices.
        """
        H, W = self.shape
        xyz = points[:, :3]
        x = xyz[:, [0]]
        y = xyz[:, [1]]
        z = xyz[:, [2]]
        depth = np.linalg.norm(xyz, ord=2, axis=1, keepdims=True)
        mask = (depth >= self.min_depth) & (depth <= self.max_depth)
        points = np.concatenate([points, depth, mask], axis=1)

        if self.scan_unfolding:
            grid_h = self._unfolded_rows(x, y, H)
        else:
            grid_h = self._elevation_rows(z, depth, H)
        grid_w = self._azimuth_cols(x, y, W)
        grid = np.concatenate((grid_h, grid_w), axis=1)

        # Far-to-near order: scatter keeps the last write, i.e. the nearest point.
        order = np.argsort(-depth.squeeze(1))
        proj_points = np.zeros((H, W, points.shape[1]), dtype=points.dtype)
        proj_points = scatter(proj_points, grid[order], points[order])
        labels = self._map_labels(labels)
        proj_labels = np.zeros((H, W, 1), dtype=labels.dtype)
        proj_labels = scatter(proj_labels, grid[order], labels[order])
        return proj_points.astype(np.float32), proj_labels.astype(np.int64)