radix sort (#60)

This commit is contained in:
koide3 2024-06-12 10:28:33 +09:00 committed by GitHub
parent 11f5a304df
commit 4762de7460
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 214 additions and 0 deletions

View File

@ -0,0 +1,143 @@
// SPDX-FileCopyrightText: Copyright 2024 Kenji Koide
// SPDX-License-Identifier: MIT
#pragma once
#include <algorithm>
#include <cstdint>
#include <functional>
#include <type_traits>
#include <vector>
#include <tbb/tbb.h>
namespace small_gicp {
/// @brief Temporary (scratch) buffers for radix sort.
/// @tparam T Type of the elements being sorted
template <typename T>
struct RadixSortBuffers {
  /// Tiled buckets (one counter per sorting bin and tile)
  std::vector<std::uint64_t> tile_buckets{};
  /// Global offsets (one entry per sorting bin)
  std::vector<std::uint64_t> global_offsets{};
  /// Sorted objects (output side of each sorting step)
  std::vector<T> sorted_buffer{};
};
/// @brief Radix sort with TBB parallelization.
///        Elements are ordered by ascending key, processing the key from its least significant
///        digit; elements with equal keys keep their input order.
/// @note  This function outperforms tbb::parallel_sort only in case with many elements and threads.
///        For usual data size and number of threads, use tbb::parallel_sort.
/// @tparam T          Element type (must be default constructible and copy assignable)
/// @tparam KeyFunc    Key function type; its result must be an unsigned integral type
/// @tparam bits       Number of key bits consumed per step
/// @tparam tile_size  Number of consecutive elements per tile
/// @param first_   [in/out] Pointer to the first element
/// @param last_    [in/out] Pointer to one past the last element
/// @param key_     Key function (T => unsigned integer)
/// @param buffers  Temporary buffers; they are resized and overwritten by every call
template <typename T, typename KeyFunc, int bits = 8, int tile_size = 256>
void radix_sort_tbb(T* first_, T* last_, const KeyFunc& key_, RadixSortBuffers<T>& buffers) {
  static_assert(bits > 0 && bits < 31, "bits must be in [1, 30]");
  static_assert(tile_size > 0, "tile_size must be positive");

  if (first_ == last_) {
    return;
  }

  // decay_t so that key functions returning (const) references are accepted as well.
  using Key = std::decay_t<decltype(key_(*first_))>;
  static_assert(std::is_unsigned_v<Key>, "Key must be unsigned integral type");

  // Number of total radix sort steps.
  constexpr int num_steps = static_cast<int>((sizeof(Key) * 8 + bits - 1) / bits);
  constexpr std::uint64_t num_bins = std::uint64_t{1} << bits;
  constexpr std::uint64_t bin_mask = num_bins - 1;
  constexpr std::uint64_t tile_len = tile_size;

  const std::uint64_t N = static_cast<std::uint64_t>(last_ - first_);
  const std::uint64_t num_tiles = (N + tile_len - 1) / tile_len;

  // Allocate buffers.
  auto& tile_buckets = buffers.tile_buckets;      // Layout: tile_buckets[bin * num_tiles + tile]
  auto& global_offsets = buffers.global_offsets;  // Layout: global_offsets[bin]
  auto& sorted_buffer = buffers.sorted_buffer;
  tile_buckets.resize(num_bins * num_tiles);
  global_offsets.resize(num_bins);
  sorted_buffer.resize(N);

  // Input and output of the current step; they are exchanged after every step.
  T* src = first_;
  T* dst = sorted_buffer.data();

  // Radix sort.
  for (int step = 0; step < num_steps; ++step) {
    const int shift = step * bits;
    const auto bin_of = [&](const T& x) -> std::uint64_t { return static_cast<std::uint64_t>(key_(x) >> shift) & bin_mask; };

    // Create per-tile histograms.
    std::fill(tile_buckets.begin(), tile_buckets.end(), std::uint64_t{0});
    tbb::parallel_for(tbb::blocked_range<std::uint64_t>(0, num_tiles, 4), [&](const tbb::blocked_range<std::uint64_t>& r) {
      for (std::uint64_t tile = r.begin(); tile < r.end(); ++tile) {
        const std::uint64_t data_begin = tile * tile_len;
        const std::uint64_t data_end = std::min<std::uint64_t>(data_begin + tile_len, N);
        // 64-bit index: an int index would overflow for ranges with 2^31 or more elements.
        for (std::uint64_t i = data_begin; i < data_end; ++i) {
          ++tile_buckets[bin_of(src[i]) * num_tiles + tile];
        }
      }
    });

    // Turn the per-tile counts of each bin into per-tile offsets (exclusive prefix sum over tiles)
    // and remember the total number of elements of the bin in global_offsets.
    tbb::parallel_for(tbb::blocked_range<std::uint64_t>(0, num_bins, 1), [&](const tbb::blocked_range<std::uint64_t>& r) {
      for (std::uint64_t bin = r.begin(); bin < r.end(); ++bin) {
        std::uint64_t* counts = tile_buckets.data() + bin * num_tiles;
        std::uint64_t running = 0;
        for (std::uint64_t tile = 0; tile < num_tiles; ++tile) {
          const std::uint64_t count = counts[tile];
          counts[tile] = running;
          running += count;
        }
        global_offsets[bin] = running;
      }
    });

    // Calculate global offsets for each sorting bin (exclusive prefix sum over the bin totals).
    std::uint64_t bin_begin = 0;
    for (std::uint64_t bin = 0; bin < num_bins; ++bin) {
      const std::uint64_t total = global_offsets[bin];
      global_offsets[bin] = bin_begin;
      bin_begin += total;
    }

    // Sort elements. Each tile only touches its own counters, and the destination ranges of
    // different (bin, tile) pairs do not overlap.
    tbb::parallel_for(tbb::blocked_range<std::uint64_t>(0, num_tiles, 8), [&](const tbb::blocked_range<std::uint64_t>& r) {
      for (std::uint64_t tile = r.begin(); tile < r.end(); ++tile) {
        const std::uint64_t data_begin = tile * tile_len;
        const std::uint64_t data_end = std::min<std::uint64_t>(data_begin + tile_len, N);
        for (std::uint64_t i = data_begin; i < data_end; ++i) {
          const T& x = src[i];
          const std::uint64_t bin = bin_of(x);
          std::uint64_t& cursor = tile_buckets[bin * num_tiles + tile];
          dst[global_offsets[bin] + cursor] = x;
          ++cursor;
        }
      }
    });

    // Swap input and output buffers.
    std::swap(src, dst);
  }

  // Copy the result to the original buffer.
  // After an odd number of steps the latest output lives in sorted_buffer.
  if (num_steps % 2 == 1) {
    std::copy(sorted_buffer.begin(), sorted_buffer.end(), first_);
  }
}
/// @brief Radix sort with TBB parallelization.
///        Convenience overload that allocates its own temporary buffers for the call.
/// @tparam T          Element type
/// @tparam KeyFunc    Key function type (its result must be an unsigned integral type)
/// @tparam bits       Number of bits per step
/// @tparam tile_size  Tile size
/// @param first_   [in/out] First iterator
/// @param last_    [in/out] Last iterator
/// @param key_     Key function (T => uint)
template <typename T, typename KeyFunc, int bits = 4, int tile_size = 256>
void radix_sort_tbb(T* first_, T* last_, const KeyFunc& key_) {
  RadixSortBuffers<T> buffers;
  // Forward bits and tile_size explicitly. Without the template arguments the buffer-taking
  // overload would be instantiated with its own defaults and the values chosen here ignored.
  radix_sort_tbb<T, KeyFunc, bits, tile_size>(first_, last_, key_, buffers);
}
} // namespace small_gicp

View File

@ -0,0 +1,71 @@
#include <random>
#include <algorithm>
#include <fmt/format.h>
#include <small_gicp/util/sort_omp.hpp>
#include <small_gicp/util/sort_tbb.hpp>
#include <small_gicp/benchmark/benchmark.hpp>
#include <gtest/gtest.h>
using namespace small_gicp;
// Returns true when both vectors have the same length and no pair of
// corresponding elements compares unequal.
template <typename T>
bool identical(const std::vector<T>& arr1, const std::vector<T>& arr2) {
  // Keep the comparison expressed through operator!= of T.
  const auto same = [](const T& lhs, const T& rhs) { return !(lhs != rhs); };
  return arr1.size() == arr2.size() && std::equal(arr1.begin(), arr1.end(), arr2.begin(), same);
}
// Compares radix_sort_tbb with std::stable_sort on randomly generated inputs.
// The first loop sorts plain values of type T (the value itself is the key).
// The second loop sorts (key, payload) pairs by key only and requires the result to match
// std::stable_sort exactly, i.e. elements with equal keys must keep their input order.
// @tparam T  Unsigned integral key type
// @param mt  Random number generator used for the input sizes and values
template <typename T>
void test_radix_sort(std::mt19937& mt) {
  std::uniform_int_distribution<> size_dist(0, 8192);

  // std::mt19937 yields only 32 random bits per draw. Two draws are combined so that all bits of
  // a 64-bit key are randomized; for narrower T the cast keeps the low bits.
  const auto random_value = [&mt] {
    const std::uint64_t hi = mt();
    const std::uint64_t lo = mt();
    return static_cast<T>((hi << 32) | lo);
  };

  for (int i = 0; i < 20; i++) {
    std::vector<T> data(size_dist(mt));
    std::generate(data.begin(), data.end(), random_value);

    std::vector<T> sorted = data;
    std::stable_sort(sorted.begin(), sorted.end());

    std::vector<T> sorted_tbb = data;
    radix_sort_tbb(sorted_tbb.data(), sorted_tbb.data() + sorted_tbb.size(), [](const T x) { return x; });

    EXPECT_TRUE(identical(sorted, sorted_tbb)) << fmt::format("i={} N={}", i, data.size());
  }

  for (int i = 0; i < 20; i++) {
    std::vector<std::pair<T, std::uint64_t>> data(size_dist(mt));
    std::generate(data.begin(), data.end(), [&] {
      // Separate statements fix the order of the draws (function arguments are unsequenced).
      const T key = random_value();
      const std::uint64_t payload = mt();
      return std::make_pair(key, payload);
    });

    std::vector<std::pair<T, std::uint64_t>> sorted = data;
    std::stable_sort(sorted.begin(), sorted.end(), [](const auto& lhs, const auto& rhs) { return lhs.first < rhs.first; });

    std::vector<std::pair<T, std::uint64_t>> sorted_tbb = data;
    radix_sort_tbb(sorted_tbb.data(), sorted_tbb.data() + sorted_tbb.size(), [](const auto& x) -> T { return x.first; });

    EXPECT_TRUE(identical(sorted, sorted_tbb)) << fmt::format("i={} N={}", i, data.size());
  }
}
// Runs the randomized radix_sort_tbb checks for 8-, 16-, 32- and 64-bit unsigned element types,
// then makes sure that an empty range is accepted.
TEST(SortTBB, RadixSortTest) {
  std::mt19937 rng;  // Default-seeded generator shared by all runs below

  test_radix_sort<std::uint8_t>(rng);
  test_radix_sort<std::uint16_t>(rng);
  test_radix_sort<std::uint32_t>(rng);
  test_radix_sort<std::uint64_t>(rng);

  // Empty input: first and last pointers are equal.
  std::vector<std::uint64_t> nothing;
  std::uint64_t* const begin = nothing.data();
  radix_sort_tbb(begin, begin + nothing.size(), [](const std::uint64_t x) { return x; });
  EXPECT_TRUE(nothing.empty()) << "Empty vector check";
}