From 4762de7460ce4e9a2ae3f8e8680d5fecdde1e6a1 Mon Sep 17 00:00:00 2001 From: koide3 <31344317+koide3@users.noreply.github.com> Date: Wed, 12 Jun 2024 10:28:33 +0900 Subject: [PATCH] radix sort (#60) --- include/small_gicp/util/sort_tbb.hpp | 143 +++++++++++++++++++++++++++ src/test/sort_tbb_test.cpp | 71 +++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 include/small_gicp/util/sort_tbb.hpp create mode 100644 src/test/sort_tbb_test.cpp diff --git a/include/small_gicp/util/sort_tbb.hpp b/include/small_gicp/util/sort_tbb.hpp new file mode 100644 index 0000000..0c3beec --- /dev/null +++ b/include/small_gicp/util/sort_tbb.hpp @@ -0,0 +1,143 @@ +// SPDX-FileCopyrightText: Copyright 2024 Kenji Koide +// SPDX-License-Identifier: MIT +#pragma once + +#include +#include +#include +#include + +namespace small_gicp { + +/// @brief Temporal buffers for radix sort. +template +struct RadixSortBuffers { + std::vector tile_buckets; //< Tiled buckets + std::vector global_offsets; //< Global offsets + std::vector sorted_buffer; //< Sorted objects +}; + +/// @brief Radix sort with TBB parallelization. +/// @note This function outperforms tbb::parallel_sort only in case with many elements and threads. +/// For usual data size and number of threads, use tbb::parallel_sort. +/// @tparam T Data type (must be unsigned integral type) +/// @tparam KeyFunc Key function +/// @tparam bits Number of bits per step +/// @tparam tile_size Tile size +/// @param first_ [in/out] First iterator +/// @param last_ [in/out] Last iterator +/// @param key_ Key function (T => uint) +/// @param buffers Temporal buffers +template +void radix_sort_tbb(T* first_, T* last_, const KeyFunc& key_, RadixSortBuffers& buffers) { + if (first_ == last_) { + return; + } + + auto first = first_; + auto last = last_; + + using Key = decltype(key_(*first)); + static_assert(std::is_unsigned_v, "Key must be unsigned integral type"); + + // Number of total radix sort steps. + constexpr int num_steps = (sizeof(Key) * 8 + bits - 1) / bits; + + constexpr int num_bins = 1 << bits; + const std::uint64_t N = std::distance(first, last); + const std::uint64_t num_tiles = (N + tile_size - 1) / tile_size; + + // Allocate buffers. + auto& tile_buckets = buffers.tile_buckets; + auto& global_offsets = buffers.global_offsets; + auto& sorted_buffer = buffers.sorted_buffer; + tile_buckets.resize(num_bins * num_tiles); + global_offsets.resize(num_bins); + sorted_buffer.resize(N); + + auto sorted = sorted_buffer.data(); + + // Radix sort. + for (int step = 0; step < num_steps; step++) { + const auto key = [&](const auto& x) { return ((key_(x) >> (step * bits))) & ((1 << bits) - 1); }; + + // Create per-tile histograms. + std::fill(tile_buckets.begin(), tile_buckets.end(), 0); + tbb::parallel_for(tbb::blocked_range(0, num_tiles, 4), [&](const tbb::blocked_range& r) { + for (std::uint64_t tile = r.begin(); tile < r.end(); tile++) { + std::uint64_t data_begin = tile * tile_size; + std::uint64_t data_end = std::min((tile + 1) * tile_size, N); + + for (int i = data_begin; i < data_end; ++i) { + auto buckets = tile_buckets.data() + key(*(first + i)) * num_tiles; + ++buckets[tile]; + } + } + }); + + // Store the number of elements of the last tile, which will be overwritten by the next step, in global_offsets. + std::fill(global_offsets.begin(), global_offsets.end(), 0); + for (int i = 1; i < num_bins; i++) { + global_offsets[i] = tile_buckets[i * num_tiles - 1]; + } + + // Calculate per-tile offsets. + tbb::parallel_for(tbb::blocked_range(0, num_bins, 1), [&](const tbb::blocked_range& r) { + for (std::uint64_t bin = r.begin(); bin < r.end(); bin++) { + auto buckets = tile_buckets.data() + bin * num_tiles; + std::uint64_t last = buckets[0]; + buckets[0] = 0; + + for (std::uint64_t tile = 1; tile < num_tiles; tile++) { + std::uint64_t tmp = buckets[tile]; + buckets[tile] = buckets[tile - 1] + last; + last = tmp; + } + } + }); + + // Calculate global offsets for each sorting bin. + for (int i = 1; i < num_bins; i++) { + global_offsets[i] += global_offsets[i - 1] + tile_buckets[i * num_tiles - 1]; + } + + // Sort elements. + tbb::parallel_for(tbb::blocked_range(0, num_tiles, 8), [&](const tbb::blocked_range& r) { + for (std::uint64_t tile = r.begin(); tile < r.end(); ++tile) { + std::uint64_t data_begin = tile * tile_size; + std::uint64_t data_end = std::min((tile + 1) * tile_size, static_cast(N)); + + for (std::uint64_t i = data_begin; i < data_end; ++i) { + const T x = *(first + i); + const int bin = key(x); + auto offset = tile_buckets.data() + bin * num_tiles + tile; + sorted[global_offsets[bin] + ((*offset)++)] = x; + } + } + }); + + // Swap input and output buffers. + std::swap(first, sorted); + } + + // Copy the result to the original buffer. + if (num_steps % 2 == 1) { + std::copy(sorted_buffer.begin(), sorted_buffer.end(), first_); + } +} + +/// @brief Radix sort with TBB parallelization. +/// @tparam T Data type (must be unsigned integral type) +/// @tparam KeyFunc Key function +/// @tparam bits Number of bits per step +/// @tparam tile_size Tile size +/// @param first_ [in/out] First iterator +/// @param last_ [in/out] Last iterator +/// @param key_ Key function (T => uint) +template +void radix_sort_tbb(T* first_, T* last_, const KeyFunc& key_) { + RadixSortBuffers buffers; + radix_sort_tbb(first_, last_, key_, buffers); +} + +} // namespace small_gicp \ No newline at end of file diff --git a/src/test/sort_tbb_test.cpp b/src/test/sort_tbb_test.cpp new file mode 100644 index 0000000..5772e2a --- /dev/null +++ b/src/test/sort_tbb_test.cpp @@ -0,0 +1,71 @@ +#include +#include +#include +#include +#include +#include + +#include + +using namespace small_gicp; + +// Check if two vectors are identical +template +bool identical(const std::vector& arr1, const std::vector& arr2) { + if (arr1.size() != arr2.size()) { + return false; + } + + for (size_t i = 0; i < arr1.size(); i++) { + if (arr1[i] != arr2[i]) { + return false; + } + } + return true; +} + +template +void test_radix_sort(std::mt19937& mt) { + std::uniform_int_distribution<> size_dist(0, 8192); + + for (int i = 0; i < 20; i++) { + std::vector data(size_dist(mt)); + std::generate(data.begin(), data.end(), [&] { return mt(); }); + + std::vector sorted = data; + std::stable_sort(sorted.begin(), sorted.end()); + + std::vector sorted_tbb = data; + radix_sort_tbb(sorted_tbb.data(), sorted_tbb.data() + sorted_tbb.size(), [](const T x) { return x; }); + + EXPECT_TRUE(identical(sorted, sorted_tbb)) << fmt::format("i={} N={}", i, data.size()); + } + + for (int i = 0; i < 20; i++) { + std::vector> data(size_dist(mt)); + std::generate(data.begin(), data.end(), [&] { return std::make_pair(mt(), mt()); }); + + std::vector> sorted = data; + std::stable_sort(sorted.begin(), sorted.end(), [](const auto& lhs, const auto& rhs) { return lhs.first < rhs.first; }); + + std::vector> sorted_tbb = data; + radix_sort_tbb(sorted_tbb.data(), sorted_tbb.data() + sorted_tbb.size(), [](const auto& x) -> T { return x.first; }); + + EXPECT_TRUE(identical(sorted, sorted_tbb)) << fmt::format("i={} N={}", i, data.size()); + } +} + +// Test radix_sort_tbb +TEST(SortTBB, RadixSortTest) { + std::mt19937 mt; + + test_radix_sort(mt); + test_radix_sort(mt); + test_radix_sort(mt); + test_radix_sort(mt); + + // empty + std::vector empty_vector; + radix_sort_tbb(empty_vector.data(), empty_vector.data() + empty_vector.size(), [](const std::uint64_t x) { return x; }); + EXPECT_TRUE(empty_vector.empty()) << "Empty vector check"; +}