parallel batch nearest neighbor search (#68)

This commit is contained in:
koide3 2024-06-20 11:25:51 +09:00 committed by GitHub
parent ac6c79acb6
commit 7e42a90d27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 94 additions and 4 deletions

View File

@ -58,6 +58,7 @@ void define_kdtree(py::module& m) {
k_sq_dist : float
The squared distance to the nearest neighbor.
)""")
.def(
"knn_search",
[](const KdTree<PointCloud>& kdtree, const Eigen::Vector3d& pt, int k) {
@ -85,11 +86,18 @@ void define_kdtree(py::module& m) {
k_sq_dists : NDArray, shape (k,)
The squared distances to the k nearest neighbors.
)""")
.def(
"batch_nearest_neighbor_search",
[](const KdTree<PointCloud>& kdtree, const Eigen::MatrixXd& pts) {
[](const KdTree<PointCloud>& kdtree, const Eigen::MatrixXd& pts, int num_threads) {
if (pts.cols() != 3 && pts.cols() != 4) {
throw std::invalid_argument("pts must have shape (n, 3) or (n, 4)");
}
std::vector<size_t> k_indices(pts.rows(), -1);
std::vector<double> k_sq_dists(pts.rows(), std::numeric_limits<double>::max());
#pragma omp parallel for num_threads(num_threads)
for (int i = 0; i < pts.rows(); ++i) {
const size_t found = traits::nearest_neighbor_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), &k_indices[i], &k_sq_dists[i]);
if (!found) {
@ -97,16 +105,20 @@ void define_kdtree(py::module& m) {
k_sq_dists[i] = std::numeric_limits<double>::max();
}
}
return std::make_pair(k_indices, k_sq_dists);
},
py::arg("pts"),
py::arg("num_threads") = 1,
R"""(
Find the nearest neighbors for a batch of points.
Parameters
----------
pts : NDArray, shape (n, 3)
pts : NDArray, shape (n, 3) or (n, 4)
The input points.
num_threads : int, optional
The number of threads to use for the search. Default is 1.
Returns
-------
@ -115,11 +127,18 @@ void define_kdtree(py::module& m) {
k_sq_dists : NDArray, shape (n,)
The squared distances to the nearest neighbors for each input point.
)""")
.def(
"batch_knn_search",
[](const KdTree<PointCloud>& kdtree, const Eigen::MatrixXd& pts, int k) {
[](const KdTree<PointCloud>& kdtree, const Eigen::MatrixXd& pts, int k, int num_threads) {
if (pts.cols() != 3 && pts.cols() != 4) {
throw std::invalid_argument("pts must have shape (n, 3) or (n, 4)");
}
std::vector<std::vector<size_t>> k_indices(pts.rows(), std::vector<size_t>(k, -1));
std::vector<std::vector<double>> k_sq_dists(pts.rows(), std::vector<double>(k, std::numeric_limits<double>::max()));
#pragma omp parallel for num_threads(num_threads)
for (int i = 0; i < pts.rows(); ++i) {
const size_t found = traits::knn_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), k, k_indices[i].data(), k_sq_dists[i].data());
if (found < k) {
@ -129,19 +148,23 @@ void define_kdtree(py::module& m) {
}
}
}
return std::make_pair(k_indices, k_sq_dists);
},
py::arg("pts"),
py::arg("k"),
py::arg("num_threads") = 1,
R"""(
Find the k nearest neighbors for a batch of points.
Parameters
----------
pts : NDArray, shape (n, 3)
pts : NDArray, shape (n, 3) or (n, 4)
The input points.
k : int
The number of nearest neighbors to search for.
num_threads : int, optional
The number of threads to use for the search. Default is 1.
Returns
-------

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright 2024 Kenji Koide
# SPDX-License-Identifier: MIT
import numpy
from scipy.spatial import KDTree
from scipy.spatial.transform import Rotation
import small_gicp
@ -188,3 +189,69 @@ def test_registration(load_points):
result = small_gicp.align(target_voxelmap, source)
verify_result(result.T_target_source, gt_T_target_source)
# KdTree test
def test_kdtree(load_points):
_, target_raw_numpy, source_raw_numpy = load_points
target, target_tree = small_gicp.preprocess_points(target_raw_numpy, downsampling_resolution=0.5)
source, source_tree = small_gicp.preprocess_points(source_raw_numpy, downsampling_resolution=0.5)
target_tree_ref = KDTree(target.points())
source_tree_ref = KDTree(source.points())
def batch_test(points, queries, tree, tree_ref, num_threads):
# test for batch interface
k_dists_ref, k_indices_ref = tree_ref.query(queries, k=1)
k_indices, k_sq_dists = tree.batch_nearest_neighbor_search(queries)
assert numpy.all(numpy.abs(numpy.square(k_dists_ref) - k_sq_dists) < 1e-6)
assert numpy.all(numpy.abs(numpy.linalg.norm(points[k_indices] - queries, axis=1) ** 2 - k_sq_dists) < 1e-6)
for k in [2, 10]:
k_dists_ref, k_indices_ref = tree_ref.query(queries, k=k)
k_sq_dists_ref, k_indices_ref = numpy.array(k_dists_ref) ** 2, numpy.array(k_indices_ref)
k_indices, k_sq_dists = tree.batch_knn_search(queries, k, num_threads=num_threads)
k_indices, k_sq_dists = numpy.array(k_indices), numpy.array(k_sq_dists)
assert(numpy.all(numpy.abs(k_sq_dists_ref - k_sq_dists) < 1e-6))
for i in range(k):
diff = numpy.linalg.norm(points[k_indices[:, i]] - queries, axis=1) ** 2 - k_sq_dists[:, i]
assert(numpy.all(numpy.abs(diff) < 1e-6))
# test for single query interface
if num_threads != 1:
return
k_dists_ref, k_indices_ref = tree_ref.query(queries, k=1)
k_indices2, k_sq_dists2 = [], []
for query in queries:
found, index, sq_dist = tree.nearest_neighbor_search(query[:3])
assert found
k_indices2.append(index)
k_sq_dists2.append(sq_dist)
assert numpy.all(numpy.abs(numpy.square(k_dists_ref) - k_sq_dists2) < 1e-6)
assert numpy.all(numpy.abs(numpy.linalg.norm(points[k_indices2] - queries, axis=1) ** 2 - k_sq_dists2) < 1e-6)
for k in [2, 10]:
k_dists_ref, k_indices_ref = tree_ref.query(queries, k=k)
k_sq_dists_ref, k_indices_ref = numpy.array(k_dists_ref) ** 2, numpy.array(k_indices_ref)
k_indices2, k_sq_dists2 = [], []
for query in queries:
indices, sq_dists = tree.knn_search(query[:3], k)
k_indices2.append(indices)
k_sq_dists2.append(sq_dists)
k_indices2, k_sq_dists2 = numpy.array(k_indices2), numpy.array(k_sq_dists2)
assert(numpy.all(numpy.abs(k_sq_dists_ref - k_sq_dists2) < 1e-6))
for i in range(k):
diff = numpy.linalg.norm(points[k_indices2[:, i]] - queries, axis=1) ** 2 - k_sq_dists2[:, i]
assert(numpy.all(numpy.abs(diff) < 1e-6))
for num_threads in [1, 2]:
batch_test(target.points(), target.points(), target_tree, target_tree_ref, num_threads=num_threads)
batch_test(target.points(), source.points(), target_tree, target_tree_ref, num_threads=num_threads)
batch_test(source.points(), target.points(), source_tree, source_tree_ref, num_threads=num_threads)