diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..bff7192 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,75 @@ +{ + "files.associations": { + "array": "cpp", + "__bit_reference": "cpp", + "__config": "cpp", + "__debug": "cpp", + "__errc": "cpp", + "__functional_base": "cpp", + "__hash_table": "cpp", + "__locale": "cpp", + "__mutex_base": "cpp", + "__node_handle": "cpp", + "__nullptr": "cpp", + "__split_buffer": "cpp", + "__string": "cpp", + "__threading_support": "cpp", + "__tree": "cpp", + "__tuple": "cpp", + "algorithm": "cpp", + "atomic": "cpp", + "bit": "cpp", + "bitset": "cpp", + "cctype": "cpp", + "chrono": "cpp", + "cmath": "cpp", + "complex": "cpp", + "csignal": "cpp", + "cstdarg": "cpp", + "cstddef": "cpp", + "cstdint": "cpp", + "cstdio": "cpp", + "cstdlib": "cpp", + "cstring": "cpp", + "ctime": "cpp", + "cwchar": "cpp", + "cwctype": "cpp", + "deque": "cpp", + "exception": "cpp", + "fstream": "cpp", + "functional": "cpp", + "initializer_list": "cpp", + "iomanip": "cpp", + "ios": "cpp", + "iosfwd": "cpp", + "iostream": "cpp", + "istream": "cpp", + "iterator": "cpp", + "limits": "cpp", + "locale": "cpp", + "map": "cpp", + "memory": "cpp", + "mutex": "cpp", + "new": "cpp", + "numeric": "cpp", + "optional": "cpp", + "ostream": "cpp", + "random": "cpp", + "ratio": "cpp", + "set": "cpp", + "sstream": "cpp", + "stack": "cpp", + "stdexcept": "cpp", + "streambuf": "cpp", + "string": "cpp", + "string_view": "cpp", + "system_error": "cpp", + "tuple": "cpp", + "type_traits": "cpp", + "typeinfo": "cpp", + "unordered_map": "cpp", + "utility": "cpp", + "vector": "cpp" + }, + "C_Cpp.errorSquiggles": "Disabled" +} \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 75c5a1a..94f4e3d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,6 +21,7 @@ else () message(STATUS "Compiler is recent enough to support C++17.") endif () +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++1z -pthread -msse4.2 -mbmi2 -Wall") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++1z -pthread -Wall") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -march=native -O3") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address -fno-omit-frame-pointer -O0 -g -DDEBUG") diff --git a/include/xcdat/bit_tools.hpp b/include/xcdat/bit_tools.hpp index 559993f..b814b40 100644 --- a/include/xcdat/bit_tools.hpp +++ b/include/xcdat/bit_tools.hpp @@ -3,8 +3,13 @@ #include #include -#include +#ifdef __SSE4_2__ #include +#endif + +#ifdef __BMI2__ +#include +#endif // From https://github.com/ot/succinct namespace xcdat::bit_tools { @@ -28,8 +33,37 @@ inline std::uint64_t popcount(std::uint64_t x) { #endif } +static constexpr std::uint8_t debruijn64_mapping[64] = { + 63, 0, 58, 1, 59, 47, 53, 2, 60, 39, 48, 27, 54, 33, 42, 3, 61, 51, 37, 40, 49, 18, + 28, 20, 55, 30, 34, 11, 43, 14, 22, 4, 62, 57, 46, 52, 38, 26, 32, 41, 50, 36, 17, 19, + 29, 10, 13, 21, 56, 45, 25, 31, 35, 16, 9, 12, 44, 24, 15, 8, 23, 7, 6, 5, +}; + +static constexpr std::uint64_t debruijn64 = 0x07EDD5E59A4E28C2ULL; + +// return the position of the single bit set in the word x +inline std::uint8_t bit_position(std::uint64_t x) { + return debruijn64_mapping[(x * debruijn64) >> 58]; +} + inline std::uint64_t msb(std::uint64_t x) { +#ifdef __SSE4_2__ return x == 0 ? 0 : 63 - __builtin_clzll(x); +#else + if (x == 0) { + return 0; + } + // right-saturate the word + x |= x >> 1; + x |= x >> 2; + x |= x >> 4; + x |= x >> 8; + x |= x >> 16; + x |= x >> 32; + // isolate the MSB + x ^= x >> 1; + return bit_position(x); +#endif } inline std::uint64_t uleq_step_9(std::uint64_t x, std::uint64_t y) { diff --git a/include/xcdat/code_table.hpp b/include/xcdat/code_table.hpp new file mode 100644 index 0000000..4604c2b --- /dev/null +++ b/include/xcdat/code_table.hpp @@ -0,0 +1,85 @@ +#pragma once + +#include +#include + +#include "mm_vector.hpp" + +namespace xcdat { + +class code_table { + private: + std::uint64_t m_max_length = 0; + std::array m_table; + mm_vector m_alphabet; + + struct cf_type { + std::uint8_t ch; + std::uint64_t freq; + }; + + public: + code_table() = default; + + virtual ~code_table() = default; + + void build(const std::vector& keys) { + std::array counter; + for (std::uint32_t ch = 0; ch < 256; ++ch) { + counter[ch] = {static_cast(ch), 0}; + } + + m_max_length = 0; + for (const auto& key : keys) { + for (std::uint8_t ch : key) { + counter[ch].freq += 1; + } + m_max_length = std::max(m_max_length, key.length()); + } + + { + std::vector alphabet; + for (const auto& cf : counter) { + if (cf.freq != 0) { + alphabet.push_back(cf.ch); + } + } + m_alphabet.steal(alphabet); + } + + std::sort(counter.begin(), counter.end(), [](const cf_type& a, const cf_type& b) { return a.freq > b.freq; }); + + for (std::uint32_t ch = 0; ch < 256; ++ch) { + m_table[counter[ch].ch] = static_cast(ch); + } + for (std::uint32_t ch = 0; ch < 256; ++ch) { + m_table[m_table[ch] + 256] = static_cast(ch); + } + } + + inline std::uint64_t alphabet_size() const { + return m_alphabet.size(); + } + + inline std::uint64_t max_length() const { + return m_max_length; + } + + inline std::uint8_t get_code(char ch) const { + return m_table[static_cast(ch)]; + } + + inline char get_char(std::uint8_t cd) const { + return static_cast(m_table[cd + 256]); + } + + inline auto begin() const { + return m_alphabet.begin(); + } + + inline auto end() const { + return m_alphabet.end(); + } +}; + +} // namespace xcdat diff --git a/include/xcdat/compact_vector.hpp b/include/xcdat/compact_vector.hpp new file mode 100644 index 0000000..fac1d1b --- /dev/null +++ b/include/xcdat/compact_vector.hpp @@ -0,0 +1,71 @@ +#pragma once + +#include "mm_vector.hpp" +#include "utils.hpp" + +namespace xcdat { + +class compact_vector { + private: + std::uint64_t m_size = 0; + std::uint64_t m_bits = 0; + std::uint64_t m_mask = 0; + mm_vector m_chunks; + + public: + compact_vector() = default; + + template + compact_vector(const Vec& vec) { + build(vec); + } + + virtual ~compact_vector() = default; + + template + void build(const Vec& vec) { + const std::uint64_t maxv = *std::max_element(vec.begin(), vec.end()); + + m_size = vec.size(); + m_bits = utils::bits_for_int(maxv); + m_mask = (1ULL << m_bits) - 1; + + std::vector chunks(utils::words_for_bits(m_size * m_bits)); + + for (std::uint64_t i = 0; i < m_size; i++) { + const auto [quo, mod] = utils::decompose<64>(i * m_bits); + chunks[quo] &= ~(m_mask << mod); + chunks[quo] |= (vec[i] & m_mask) << mod; + if (64 < mod + m_bits) { + const std::uint64_t diff = 64ULL - mod; + chunks[quo + 1] &= ~(m_mask >> diff); + chunks[quo + 1] |= (vec[i] & m_mask) >> diff; + } + } + m_chunks.steal(chunks); + } + + inline std::uint64_t operator[](std::uint64_t i) const { + assert(i < m_size); + const auto [quo, mod] = utils::decompose<64>(i * m_bits); + if (mod + m_bits <= 64) { + return (m_chunks[quo] >> mod) & m_mask; + } else { + return ((m_chunks[quo] >> mod) | (m_chunks[quo + 1] << (64 - mod))) & m_mask; + } + } + + inline std::uint64_t size() const { + return m_size; + } + + inline std::uint64_t bits() const { + return m_bits; + } + + inline std::uint64_t memory_in_bytes() const { + return m_chunks.size() * sizeof(std::uint64_t); + } +}; + +} // namespace xcdat \ No newline at end of file diff --git a/include/xcdat/dac_bc.hpp b/include/xcdat/dac_bc.hpp new file mode 100644 index 0000000..b30da58 --- /dev/null +++ b/include/xcdat/dac_bc.hpp @@ -0,0 +1,123 @@ +#pragma once + +#include + +#include "bit_vector.hpp" +#include "compact_vector.hpp" + +namespace xcdat { + +class dac_bc { + public: + static constexpr std::uint32_t l1_bits = 8; + static constexpr std::uint32_t max_levels = sizeof(std::uint64_t); + + private: + std::uint32_t m_num_levels = 0; + std::uint64_t m_num_frees = 0; + std::array, max_levels> m_bytes; + std::array m_next_flags; + compact_vector m_links; + bit_vector m_leaf_flags; + + public: + dac_bc() = default; + + template + dac_bc(const std::vector& bc_units, bit_vector::builder&& leaf_flags) { + std::array, max_levels> bytes; + std::array next_flags; + std::vector links; + + bytes[0].reserve(bc_units.size() * 2); + next_flags[0].reserve(bc_units.size() * 2); + links.reserve(bc_units.size()); + + m_num_levels = 0; + + auto append_unit = [&](std::uint64_t x) { + std::uint32_t j = 0; + bytes[j].push_back(static_cast(x & 0xFF)); + next_flags[j].push_back(true); + x >>= 8; + while (x) { + ++j; + bytes[j].push_back(static_cast(x & 0xFF)); + next_flags[j].push_back(true); + x >>= 8; + } + next_flags[j].set_bit(next_flags[j].size() - 1, false); + m_num_levels = std::max(m_num_levels, j); + }; + + auto append_leaf = [&](std::uint64_t x) { + bytes[0].push_back(static_cast(x & 0xFF)); + next_flags[0].push_back(false); + links.push_back(x >> 8); + }; + + for (std::uint64_t i = 0; i < bc_units.size(); ++i) { + if (leaf_flags[i]) { + append_leaf(bc_units[i].base); + } else { + append_unit(bc_units[i].base ^ i); + } + append_unit(bc_units[i].check ^ i); + if (bc_units[i].check == i) { + m_num_frees += 1; + } + } + + // release + for (uint8_t i = 0; i < m_num_levels; ++i) { + m_bytes[i].steal(bytes[i]); + m_next_flags[i].build(next_flags[i], true, false); + } + m_bytes[m_num_levels].steal(bytes[m_num_levels]); + m_links.build(links); + m_leaf_flags.build(leaf_flags, true, false); + } + + virtual ~dac_bc() = default; + + inline std::uint64_t base(std::uint64_t i) const { + return access(i * 2) ^ i; + } + + inline std::uint64_t check(std::uint64_t i) const { + return access(i * 2 + 1) ^ i; + } + + inline std::uint64_t link(std::uint64_t i) const { + return m_bytes[0][i * 2] | (m_links[m_leaf_flags.rank(i)] << 8); + } + + inline bool is_leaf(std::uint64_t i) const { + return m_leaf_flags[i]; + } + + inline bool is_used(std::uint64_t i) const { + return check(i) != i; + } + + inline std::uint64_t num_units() const { + return m_bytes[0].size() / 2; + } + + inline std::uint64_t num_leaves() const { + return m_leaf_flags.num_ones(); + } + + private: + inline std::uint64_t access(std::uint64_t i) const { + std::uint32_t j = 0; + std::uint64_t x = m_bytes[j][i]; + while (j < m_num_levels and m_next_flags[j][i]) { + i = m_next_flags[j++].rank(i); + x |= static_cast(m_bytes[j][i]) << (j * 8); + } + return x; + } +}; + +} // namespace xcdat \ No newline at end of file diff --git a/include/xcdat/shared_tail.hpp b/include/xcdat/shared_tail.hpp new file mode 100644 index 0000000..6c7b7c0 --- /dev/null +++ b/include/xcdat/shared_tail.hpp @@ -0,0 +1,147 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "bit_vector.hpp" +#include "mm_vector.hpp" + +namespace xcdat { + +//! TAIL implementation with suffix merge +class shared_tail { + public: + struct suffix_type { + std::string_view str; + std::uint64_t npos; + + inline char operator[](std::uint64_t i) const { + return str[length() - i - 1]; + } + inline std::uint64_t length() const { + return str.length(); + } + + inline const char* begin() const { + return str.data(); + } + inline const char* end() const { + return str.data() + str.length(); + } + + inline std::reverse_iterator rbegin() const { + return std::make_reverse_iterator(str.data() + str.length()); + } + inline std::reverse_iterator rend() const { + return std::make_reverse_iterator(str.data()); + } + }; + + class builder { + private: + // Buffer + std::vector m_suffixes; + + // Released + std::vector m_chars; + bit_vector::builder m_term_flags; + + public: + builder() = default; + + virtual ~builder() = default; + + void set_suffix(std::string_view str, std::uint64_t npos) { + m_suffixes.push_back({str, npos}); + } + + // setter(npos, tpos): Set units[npos].base = tpos. + void complete(bool bin_mode, const std::function& setter) { + std::sort(m_suffixes.begin(), m_suffixes.end(), [](const suffix_type& a, const suffix_type& b) { + return std::lexicographical_compare(std::rbegin(a), std::rend(a), std::rbegin(b), std::rend(b)); + }); + + // Sentinel for an empty suffix + m_chars.emplace_back('\0'); + if (bin_mode) { + m_term_flags.push_back(false); + } + + const suffix_type dmmy_suffix = {{nullptr, 0}, 0}; + const suffix_type* prev_suffix = &dmmy_suffix; + + std::uint64_t prev_tpos = 0; + + for (std::uint64_t i = m_suffixes.size(); i > 0; --i) { + const suffix_type& curr_suffix = m_suffixes[i - 1]; + if (curr_suffix.length() == 0) { + // throw TrieBuilder::Exception("A suffix is empty."); + } + + std::uint64_t match = 0; + while ((match < curr_suffix.length()) && (match < prev_suffix->length()) && + ((*prev_suffix)[match] == curr_suffix[match])) { + ++match; + } + + if ((match == curr_suffix.length()) && (prev_suffix->length() != 0)) { // sharable + setter(curr_suffix.npos, prev_tpos + (prev_suffix->length() - match)); + prev_tpos += prev_suffix->length() - match; + } else { // append + setter(curr_suffix.npos, m_chars.size()); + prev_tpos = m_chars.size(); + std::copy(curr_suffix.begin(), curr_suffix.end(), std::back_inserter(m_chars)); + if (bin_mode) { + for (std::uint64_t j = 1; j < curr_suffix.length(); ++j) { + m_term_flags.push_back(false); + } + m_term_flags.push_back(true); + } else { + m_chars.emplace_back('\0'); + } + } + + prev_suffix = &curr_suffix; + } + } + + friend class shared_tail; + }; + + private: + mm_vector m_chars; + bit_vector m_term_flags; + + public: + shared_tail() = default; + + explicit shared_tail(builder& b) { + m_chars.steal(b.m_chars); + m_term_flags.build(b.m_term_flags); + } + + ~shared_tail() = default; + + inline bool bin_mode() const { + return m_term_flags.size() == 0; + } + + inline bool match(std::string_view key, size_t tpos) const {} + + inline void decode(std::string& decoded, size_t tpos) const { + if (bin_mode()) { + do { + decoded.push_back(m_chars[tpos]); + } while (!m_term_flags[tpos++]); + } else { + do { + decoded.push_back(m_chars[tpos++]); + } while (m_chars[tpos]); + } + } +}; + +} // namespace xcdat \ No newline at end of file diff --git a/include/xcdat/trie.hpp b/include/xcdat/trie.hpp new file mode 100644 index 0000000..58adafe --- /dev/null +++ b/include/xcdat/trie.hpp @@ -0,0 +1,107 @@ +#pragma once + +#include +#include + +#include "dac_bc.hpp" +#include "trie_builder.hpp" + +namespace xcdat { + +class trie { + public: + private: + std::uint64_t m_num_keys = 0; + code_table m_table; + dac_bc m_bc; + bit_vector m_term_flags; + shared_tail m_tail; + + public: + trie() = default; + + virtual ~trie() = default; + + static trie build(const std::vector& keys, bool bin_mode = false) { + trie_builder b(keys, 8, bin_mode); + return trie(b); + } + + inline std::uint64_t num_keys() const { + return m_num_keys; + } + + inline bool bin_mode() const { + return m_tail.bin_mode(); + } + + inline std::uint64_t alphabet_size() const { + return m_table.alphabet_size(); + } + + inline std::uint64_t max_length() const { + return m_table.max_length(); + } + + inline std::optional lookup(std::string_view key) const { + std::uint64_t kpos = 0, npos = 0; + while (!m_bc.is_leaf(npos)) { + if (kpos == key.length()) { + if (m_term_flags[npos]) { + return npos_to_id(npos); + } + return std::nullopt; + } + const auto cpos = m_bc.base(npos) ^ m_table.get_code(key[kpos++]); + if (m_bc.check(cpos) != npos) { + return std::nullopt; + } + npos = cpos; + } + const std::uint64_t tpos = m_bc.link(npos); + if (!m_tail.match(key.substr(kpos, key.length() - kpos), tpos)) { + return std::nullopt; + } + return npos_to_id(npos); + } + + inline std::string access(std::uint64_t id) const { + if (num_keys() <= id) { + return {}; + } + + std::string decoded; + decoded.reserve(max_length()); + + auto npos = id_to_npos(id); + auto tpos = m_bc.is_leaf(npos) ? m_bc.link(npos) : UINT64_MAX; + + while (npos != 0) { + const auto ppos = m_bc.check(npos); + decoded.push_back(m_table.get_char(m_bc.base(ppos) ^ npos)); + npos = ppos; + } + + std::reverse(decoded.begin(), decoded.end()); + + if (tpos != 0 && tpos != UINT64_MAX) { + m_tail.decode(decoded, tpos); + } + return decoded; + } + + private: + trie(trie_builder& b) + : m_num_keys(b.m_keys.size()), m_table(b.m_table), m_bc(b.m_units, b.m_leaf_flags), + m_term_flags(b.m_term_flags, true, true), m_tail(b.m_suffixes) {} + + inline std::uint64_t npos_to_id(std::uint64_t npos) const { + return m_term_flags.rank(npos); + }; + + inline std::uint64_t id_to_npos(std::uint64_t id) const { + return m_term_flags.select(id); + }; +}; + +} // namespace xcdat \ No newline at end of file diff --git a/include/xcdat/trie_builder.hpp b/include/xcdat/trie_builder.hpp new file mode 100644 index 0000000..b712b11 --- /dev/null +++ b/include/xcdat/trie_builder.hpp @@ -0,0 +1,246 @@ +#pragma once + +#include +#include + +#include "code_table.hpp" +#include "dac_bc.hpp" +#include "shared_tail.hpp" + +namespace xcdat { + +class trie_builder { + public: + struct unit_type { + std::uint64_t base; + std::uint64_t check; + }; + + private: + static constexpr std::uint64_t taboo_npos = 1; + static constexpr std::uint64_t free_blocks = 16; + + const std::vector& m_keys; + const std::uint32_t m_l1_bits; // # of bits for L1 layer of DACs + const std::uint64_t m_l1_size; + + bool m_bin_mode = false; + + code_table m_table; + std::vector m_units; + bit_vector::builder m_leaf_flags; + bit_vector::builder m_term_flags; + bit_vector::builder m_used_flags; + std::vector m_heads; // for L1 blocks + std::vector m_edges; + shared_tail::builder m_suffixes; + + public: + trie_builder(const std::vector& keys, std::uint32_t l1_bits, bool bin_mode) + : m_keys(keys), m_l1_bits(l1_bits), m_l1_size(1ULL << l1_bits), m_bin_mode(bin_mode) { + if (m_keys.empty()) { + // throw TrieBuilder::Exception("The input data is empty."); + } + + // Reserve + { + std::uint64_t init_capa = 1; + while (init_capa < m_keys.size()) { + init_capa <<= 1; + } + m_units.reserve(init_capa); + m_leaf_flags.reserve(init_capa); + m_term_flags.reserve(init_capa); + m_used_flags.reserve(init_capa); + m_heads.reserve(init_capa >> m_l1_bits); + m_edges.reserve(256); + } + + // Initialize an empty list. + for (std::uint64_t npos = 0; npos < 256; ++npos) { + m_units.push_back(unit_type{npos + 1, npos - 1}); + m_leaf_flags.push_back(false); + m_term_flags.push_back(false); + m_used_flags.push_back(false); + } + m_units[255].base = 0; + m_units[0].check = 255; + + for (std::uint64_t npos = 0; npos < 256; npos += m_l1_size) { + m_heads.push_back(npos); + } + + // Fix the root + use_unit(0); + m_units[0].check = taboo_npos; + m_used_flags.set_bit(taboo_npos, true); + m_heads[taboo_npos >> m_l1_bits] = m_units[taboo_npos].base; + + // Build the code table + m_table.build(keys); + m_bin_mode |= (*m_table.begin() == '\0'); + + // Build the BC unites + arrange(0, m_keys.size(), 0, 0); + + m_suffixes.complete(m_bin_mode, [&](std::uint64_t npos, std::uint64_t tpos) { m_units[npos].base = tpos; }); + } + + virtual ~trie_builder() = default; + + private: + inline void use_unit(std::uint64_t npos) { + m_used_flags.set_bit(npos); + + const auto next = m_units[npos].base; + const auto prev = m_units[npos].check; + m_units[prev].base = next; + m_units[next].check = prev; + + const auto lpos = npos >> m_l1_bits; + if (m_heads[lpos] == npos) { + m_heads[lpos] = (lpos != next >> m_l1_bits) ? taboo_npos : next; + } + } + + inline void close_block(std::uint64_t bpos) { + const auto beg_npos = bpos * 256; + const auto end_npos = beg_npos + 256; + + for (auto npos = beg_npos; npos < end_npos; ++npos) { + if (!m_used_flags[npos]) { + use_unit(npos); + m_used_flags.set_bit(npos, false); + m_units[npos].base = npos; + m_units[npos].check = npos; + } + } + + for (auto npos = beg_npos; npos < end_npos; npos += m_l1_size) { + m_heads[npos >> m_l1_bits] = taboo_npos; + } + } + + void expand() { + const auto old_size = static_cast(m_units.size()); + const auto new_size = old_size + 256; + + for (auto npos = old_size; npos < new_size; ++npos) { + m_units.push_back({npos + 1, npos - 1}); + m_leaf_flags.push_back(false); + m_term_flags.push_back(false); + m_used_flags.push_back(false); + } + + { + const auto last_npos = m_units[taboo_npos].check; + m_units[old_size].check = last_npos; + m_units[last_npos].base = old_size; + m_units[new_size - 1].base = taboo_npos; + m_units[taboo_npos].check = new_size - 1; + } + + for (auto npos = old_size; npos < new_size; npos += m_l1_size) { + m_heads.push_back(npos); + } + + const auto bpos = old_size / 256; + if (free_blocks <= bpos) { + close_block(bpos - free_blocks); + } + } + + void arrange(std::uint64_t beg, std::uint64_t end, std::uint64_t depth, std::uint64_t npos) { + if (m_keys[beg].length() == depth) { + m_term_flags.set_bit(npos, true); + if (++beg == end) { // without link? + m_units[npos].base = 0; // with an empty suffix + m_leaf_flags.set_bit(npos, true); + return; + } + } else if (beg + 1 == end) { // leaf? + m_term_flags.set_bit(npos, true); + m_leaf_flags.set_bit(npos, true); + m_suffixes.set_suffix({m_keys[beg].data() + depth, m_keys[beg].length() - depth}, npos); + return; + } + + // fetching edges + { + m_edges.clear(); + auto ch = static_cast(m_keys[beg][depth]); + for (auto i = beg + 1; i < end; ++i) { + const auto next_ch = static_cast(m_keys[i][depth]); + if (ch != next_ch) { + if (next_ch < ch) { + // throw TrieBuilder::Exception("The input data is not in lexicographical order."); + } + m_edges.push_back(ch); + ch = next_ch; + } + } + m_edges.push_back(ch); + } + + const auto base = xcheck(npos >> m_l1_bits); + if (m_units.size() <= base) { + expand(); + } + + // defining new edges + m_units[npos].base = base; + for (const auto ch : m_edges) { + const auto child_id = base ^ m_table.get_code(ch); + use_unit(child_id); + m_units[child_id].check = npos; + } + + // following the children + auto i = beg; + auto ch = static_cast(m_keys[beg][depth]); + for (auto j = beg + 1; j < end; ++j) { + const auto next_ch = static_cast(m_keys[j][depth]); + if (ch != next_ch) { + arrange(i, j, depth + 1, base ^ m_table.get_code(ch)); + ch = next_ch; + i = j; + } + } + arrange(i, end, depth + 1, base ^ m_table.get_code(ch)); + } + + inline std::uint64_t xcheck(std::uint64_t lpos) const { + if (m_units[taboo_npos].base == taboo_npos) { // Full? + return m_units.size() ^ m_table.get_code(m_edges[0]); + } + + // search in the same L1 block + for (auto i = m_heads[lpos]; i != taboo_npos && i >> m_l1_bits == lpos; i = m_units[i].base) { + const auto base = i ^ m_table.get_code(m_edges[0]); + if (is_target(base)) { + return base; // base / block_size_ == lpos + } + } + + for (auto i = m_units[taboo_npos].base; i != taboo_npos; i = m_units[i].base) { + const auto base = i ^ m_table.get_code(m_edges[0]); + if (is_target(base)) { + return base; // base / block_size_ != lpos + } + } + return m_units.size() ^ m_table.get_code(m_edges[0]); + } + + inline bool is_target(std::uint64_t base) const { + for (const auto ch : m_edges) { + if (m_used_flags[base ^ m_table.get_code(ch)]) { + return false; + } + } + return true; + } + + friend class trie; +}; + +} // namespace xcdat \ No newline at end of file diff --git a/include/xcdat/utils.hpp b/include/xcdat/utils.hpp index d3e4d5e..8febf65 100644 --- a/include/xcdat/utils.hpp +++ b/include/xcdat/utils.hpp @@ -8,7 +8,7 @@ namespace xcdat::utils { template constexpr std::tuple decompose(std::uint64_t x) { - return std::make_tuple(x / N, x % N); + return {x / N, x % N}; } template diff --git a/test/test_bc.cpp b/test/test_bc.cpp new file mode 100644 index 0000000..522ce16 --- /dev/null +++ b/test/test_bc.cpp @@ -0,0 +1,72 @@ +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN + +#include +#include + +#include + +#include "doctest/doctest.h" + +struct bc_unit { + std::uint64_t base; + std::uint64_t check; +}; + +std::vector make_random_units(std::uint64_t n) { + static constexpr std::uint64_t seed = 13; + + std::mt19937_64 engine(seed); + std::uniform_int_distribution dist(0, n - 1); + + std::vector bc_units(n); + for (std::uint64_t i = 0; i < n; i++) { + bc_units[i].base = dist(engine); + bc_units[i].check = dist(engine); + } + return bc_units; +} + +std::vector make_random_bits(std::uint64_t n, double dens) { + static constexpr std::uint64_t seed = 17; + + std::mt19937_64 engine(seed); + std::uniform_real_distribution dist(0.0, 1.0); + + std::vector bits(n); + for (std::uint64_t i = 0; i < n; i++) { + bits[i] = dist(engine) < dens; + } + return bits; +} + +xcdat::bit_vector::builder to_bit_vector_builder(const std::vector& bits) { + xcdat::bit_vector::builder bvb(bits.size()); + for (std::uint64_t i = 0; i < bits.size(); i++) { + bvb.set_bit(i, bits[i]); + } + return bvb; +} + +std::uint64_t get_num_ones(const std::vector& bits) { + return std::accumulate(bits.begin(), bits.end(), 0ULL); +} + +TEST_CASE("Test xcdat::dac_bc") { + auto bc_units = make_random_units(10000); + auto leaf_flags = make_random_bits(10000, 0.2); + + xcdat::dac_bc bc(bc_units, to_bit_vector_builder(leaf_flags)); + + REQUIRE_EQ(bc.num_units(), bc_units.size()); + REQUIRE_EQ(bc.num_leaves(), get_num_ones(leaf_flags)); + + for (std::uint64_t i = 0; i < bc.num_units(); i++) { + REQUIRE_EQ(bc.is_leaf(i), leaf_flags[i]); + if (leaf_flags[i]) { + REQUIRE_EQ(bc.link(i), bc_units[i].base); + } else { + REQUIRE_EQ(bc.base(i), bc_units[i].base); + } + REQUIRE_EQ(bc.check(i), bc_units[i].check); + } +} \ No newline at end of file diff --git a/test/test_bit_vector.cpp b/test/test_bit_vector.cpp index 4ff425b..8046855 100644 --- a/test/test_bit_vector.cpp +++ b/test/test_bit_vector.cpp @@ -7,10 +7,12 @@ #include "doctest/doctest.h" -std::vector generate_random_bits(std::uint64_t n) { +std::vector make_random_bits(std::uint64_t n) { static constexpr std::uint64_t seed = 13; + std::vector bits; - std::mt19937 engine(seed); + std::mt19937_64 engine(seed); + for (std::uint64_t i = 0; i < n; i++) { bits.push_back(engine() & 1); } @@ -31,16 +33,15 @@ std::uint64_t select_naive(const std::vector& bits, std::uint64_t n) { if (bits[i]) { if (n == 0) { break; - } else { - n -= 1; } + n -= 1; } } return i; } TEST_CASE("Test bit_vector::builder with resize") { - const auto bits = generate_random_bits(10000); + const auto bits = make_random_bits(10000); xcdat::bit_vector::builder b; b.resize(bits.size()); @@ -56,7 +57,7 @@ TEST_CASE("Test bit_vector::builder with resize") { } TEST_CASE("Test bit_vector::builder with push_back") { - const auto bits = generate_random_bits(10000); + const auto bits = make_random_bits(10000); xcdat::bit_vector::builder b; b.reserve(bits.size()); @@ -73,7 +74,7 @@ TEST_CASE("Test bit_vector::builder with push_back") { } TEST_CASE("Test bit_vector") { - const auto bits = generate_random_bits(10000); + const auto bits = make_random_bits(10000); xcdat::bit_vector bv; { @@ -93,6 +94,7 @@ TEST_CASE("Test bit_vector") { static constexpr std::uint64_t seed = 17; std::mt19937_64 engine(seed); + { std::uniform_int_distribution dist(0, bv.size()); for (std::uint64_t r = 0; r < 100; r++) { diff --git a/test/test_compact_vector.cpp b/test/test_compact_vector.cpp new file mode 100644 index 0000000..55f0daa --- /dev/null +++ b/test/test_compact_vector.cpp @@ -0,0 +1,43 @@ +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN + +#include +#include + +#include + +#include "doctest/doctest.h" + +std::vector make_random_ints(std::uint64_t n) { + static constexpr std::uint64_t seed = 13; + + std::mt19937_64 engine(seed); + std::uniform_int_distribution dist(0, UINT16_MAX); + + std::vector ints(n); + for (std::uint64_t i = 0; i < n; i++) { + ints[i] = dist(engine); + } + return ints; +} + +TEST_CASE("Test xcdat::compact_vector (tiny)") { + std::vector ints = {2, 0, 14, 456, 32, 5544, 23}; + xcdat::compact_vector cv(ints); + + REQUIRE_EQ(cv.size(), ints.size()); + + for (std::uint64_t i = 0; i < ints.size(); i++) { + REQUIRE_EQ(cv[i], ints[i]); + } +} + +TEST_CASE("Test xcdat::compact_vector (random)") { + std::vector ints = make_random_ints(10000); + xcdat::compact_vector cv(ints); + + REQUIRE_EQ(cv.size(), ints.size()); + + for (std::uint64_t i = 0; i < ints.size(); i++) { + REQUIRE_EQ(cv[i], ints[i]); + } +}