mirror of https://github.com/koide3/small_gicp.git
parallel batch nearest neighbor search (#68)
This commit is contained in:
parent
ac6c79acb6
commit
7e42a90d27
|
|
@ -58,6 +58,7 @@ void define_kdtree(py::module& m) {
|
|||
k_sq_dist : float
|
||||
The squared distance to the nearest neighbor.
|
||||
)""")
|
||||
|
||||
.def(
|
||||
"knn_search",
|
||||
[](const KdTree<PointCloud>& kdtree, const Eigen::Vector3d& pt, int k) {
|
||||
|
|
@ -85,11 +86,18 @@ void define_kdtree(py::module& m) {
|
|||
k_sq_dists : NDArray, shape (k,)
|
||||
The squared distances to the k nearest neighbors.
|
||||
)""")
|
||||
|
||||
.def(
|
||||
"batch_nearest_neighbor_search",
|
||||
[](const KdTree<PointCloud>& kdtree, const Eigen::MatrixXd& pts) {
|
||||
[](const KdTree<PointCloud>& kdtree, const Eigen::MatrixXd& pts, int num_threads) {
|
||||
if (pts.cols() != 3 && pts.cols() != 4) {
|
||||
throw std::invalid_argument("pts must have shape (n, 3) or (n, 4)");
|
||||
}
|
||||
|
||||
std::vector<size_t> k_indices(pts.rows(), -1);
|
||||
std::vector<double> k_sq_dists(pts.rows(), std::numeric_limits<double>::max());
|
||||
|
||||
#pragma omp parallel for num_threads(num_threads)
|
||||
for (int i = 0; i < pts.rows(); ++i) {
|
||||
const size_t found = traits::nearest_neighbor_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), &k_indices[i], &k_sq_dists[i]);
|
||||
if (!found) {
|
||||
|
|
@ -97,16 +105,20 @@ void define_kdtree(py::module& m) {
|
|||
k_sq_dists[i] = std::numeric_limits<double>::max();
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(k_indices, k_sq_dists);
|
||||
},
|
||||
py::arg("pts"),
|
||||
py::arg("num_threads") = 1,
|
||||
R"""(
|
||||
Find the nearest neighbors for a batch of points.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pts : NDArray, shape (n, 3)
|
||||
pts : NDArray, shape (n, 3) or (n, 4)
|
||||
The input points.
|
||||
num_threads : int, optional
|
||||
The number of threads to use for the search. Default is 1.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
@ -115,11 +127,18 @@ void define_kdtree(py::module& m) {
|
|||
k_sq_dists : NDArray, shape (n,)
|
||||
The squared distances to the nearest neighbors for each input point.
|
||||
)""")
|
||||
|
||||
.def(
|
||||
"batch_knn_search",
|
||||
[](const KdTree<PointCloud>& kdtree, const Eigen::MatrixXd& pts, int k) {
|
||||
[](const KdTree<PointCloud>& kdtree, const Eigen::MatrixXd& pts, int k, int num_threads) {
|
||||
if (pts.cols() != 3 && pts.cols() != 4) {
|
||||
throw std::invalid_argument("pts must have shape (n, 3) or (n, 4)");
|
||||
}
|
||||
|
||||
std::vector<std::vector<size_t>> k_indices(pts.rows(), std::vector<size_t>(k, -1));
|
||||
std::vector<std::vector<double>> k_sq_dists(pts.rows(), std::vector<double>(k, std::numeric_limits<double>::max()));
|
||||
|
||||
#pragma omp parallel for num_threads(num_threads)
|
||||
for (int i = 0; i < pts.rows(); ++i) {
|
||||
const size_t found = traits::knn_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), k, k_indices[i].data(), k_sq_dists[i].data());
|
||||
if (found < k) {
|
||||
|
|
@ -129,19 +148,23 @@ void define_kdtree(py::module& m) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(k_indices, k_sq_dists);
|
||||
},
|
||||
py::arg("pts"),
|
||||
py::arg("k"),
|
||||
py::arg("num_threads") = 1,
|
||||
R"""(
|
||||
Find the k nearest neighbors for a batch of points.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pts : NDArray, shape (n, 3)
|
||||
pts : NDArray, shape (n, 3) or (n, 4)
|
||||
The input points.
|
||||
k : int
|
||||
The number of nearest neighbors to search for.
|
||||
num_threads : int, optional
|
||||
The number of threads to use for the search. Default is 1.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
# SPDX-FileCopyrightText: Copyright 2024 Kenji Koide
|
||||
# SPDX-License-Identifier: MIT
|
||||
import numpy
|
||||
from scipy.spatial import KDTree
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
import small_gicp
|
||||
|
|
@ -188,3 +189,69 @@ def test_registration(load_points):
|
|||
|
||||
result = small_gicp.align(target_voxelmap, source)
|
||||
verify_result(result.T_target_source, gt_T_target_source)
|
||||
|
||||
# KdTree test
|
||||
def test_kdtree(load_points):
|
||||
_, target_raw_numpy, source_raw_numpy = load_points
|
||||
|
||||
target, target_tree = small_gicp.preprocess_points(target_raw_numpy, downsampling_resolution=0.5)
|
||||
source, source_tree = small_gicp.preprocess_points(source_raw_numpy, downsampling_resolution=0.5)
|
||||
|
||||
target_tree_ref = KDTree(target.points())
|
||||
source_tree_ref = KDTree(source.points())
|
||||
|
||||
def batch_test(points, queries, tree, tree_ref, num_threads):
|
||||
# test for batch interface
|
||||
k_dists_ref, k_indices_ref = tree_ref.query(queries, k=1)
|
||||
k_indices, k_sq_dists = tree.batch_nearest_neighbor_search(queries)
|
||||
assert numpy.all(numpy.abs(numpy.square(k_dists_ref) - k_sq_dists) < 1e-6)
|
||||
assert numpy.all(numpy.abs(numpy.linalg.norm(points[k_indices] - queries, axis=1) ** 2 - k_sq_dists) < 1e-6)
|
||||
|
||||
for k in [2, 10]:
|
||||
k_dists_ref, k_indices_ref = tree_ref.query(queries, k=k)
|
||||
k_sq_dists_ref, k_indices_ref = numpy.array(k_dists_ref) ** 2, numpy.array(k_indices_ref)
|
||||
|
||||
k_indices, k_sq_dists = tree.batch_knn_search(queries, k, num_threads=num_threads)
|
||||
k_indices, k_sq_dists = numpy.array(k_indices), numpy.array(k_sq_dists)
|
||||
|
||||
assert(numpy.all(numpy.abs(k_sq_dists_ref - k_sq_dists) < 1e-6))
|
||||
for i in range(k):
|
||||
diff = numpy.linalg.norm(points[k_indices[:, i]] - queries, axis=1) ** 2 - k_sq_dists[:, i]
|
||||
assert(numpy.all(numpy.abs(diff) < 1e-6))
|
||||
|
||||
# test for single query interface
|
||||
if num_threads != 1:
|
||||
return
|
||||
|
||||
k_dists_ref, k_indices_ref = tree_ref.query(queries, k=1)
|
||||
k_indices2, k_sq_dists2 = [], []
|
||||
for query in queries:
|
||||
found, index, sq_dist = tree.nearest_neighbor_search(query[:3])
|
||||
assert found
|
||||
k_indices2.append(index)
|
||||
k_sq_dists2.append(sq_dist)
|
||||
|
||||
assert numpy.all(numpy.abs(numpy.square(k_dists_ref) - k_sq_dists2) < 1e-6)
|
||||
assert numpy.all(numpy.abs(numpy.linalg.norm(points[k_indices2] - queries, axis=1) ** 2 - k_sq_dists2) < 1e-6)
|
||||
|
||||
for k in [2, 10]:
|
||||
k_dists_ref, k_indices_ref = tree_ref.query(queries, k=k)
|
||||
k_sq_dists_ref, k_indices_ref = numpy.array(k_dists_ref) ** 2, numpy.array(k_indices_ref)
|
||||
|
||||
k_indices2, k_sq_dists2 = [], []
|
||||
for query in queries:
|
||||
indices, sq_dists = tree.knn_search(query[:3], k)
|
||||
k_indices2.append(indices)
|
||||
k_sq_dists2.append(sq_dists)
|
||||
k_indices2, k_sq_dists2 = numpy.array(k_indices2), numpy.array(k_sq_dists2)
|
||||
|
||||
assert(numpy.all(numpy.abs(k_sq_dists_ref - k_sq_dists2) < 1e-6))
|
||||
for i in range(k):
|
||||
diff = numpy.linalg.norm(points[k_indices2[:, i]] - queries, axis=1) ** 2 - k_sq_dists2[:, i]
|
||||
assert(numpy.all(numpy.abs(diff) < 1e-6))
|
||||
|
||||
|
||||
for num_threads in [1, 2]:
|
||||
batch_test(target.points(), target.points(), target_tree, target_tree_ref, num_threads=num_threads)
|
||||
batch_test(target.points(), source.points(), target_tree, target_tree_ref, num_threads=num_threads)
|
||||
batch_test(source.points(), target.points(), source_tree, source_tree_ref, num_threads=num_threads)
|
||||
|
|
|
|||
Loading…
Reference in New Issue