From a6e7f6a07bf4661db33a82fb9f44bb96a8075e84 Mon Sep 17 00:00:00 2001 From: Shunsuke Kanda Date: Sun, 27 Jun 2021 01:40:11 +0900 Subject: [PATCH] add functions --- CMakeLists.txt | 2 + include/xcdat/code_table.hpp | 8 ++ include/xcdat/mm_vector.hpp | 8 ++ include/xcdat/tail_vector.hpp | 35 +++++- include/xcdat/trie.hpp | 214 ++++++++++++++++++++++++++++++++- include/xcdat/trie_builder.hpp | 5 +- include/xcdat/utils.hpp | 5 + sample/CMakeLists.txt | 1 + sample/sample.cpp | 44 +++++++ 9 files changed, 313 insertions(+), 9 deletions(-) create mode 100644 sample/CMakeLists.txt create mode 100644 sample/sample.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 69ba4ba..023fa4c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,4 +36,6 @@ include_directories(include) enable_testing() add_subdirectory(test) +add_subdirectory(sample) + file(COPY ${CMAKE_SOURCE_DIR}/test/keys.txt DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/include/xcdat/code_table.hpp b/include/xcdat/code_table.hpp index 883292a..b90f01d 100644 --- a/include/xcdat/code_table.hpp +++ b/include/xcdat/code_table.hpp @@ -86,6 +86,14 @@ class code_table { inline auto end() const { return m_alphabet.end(); } + + inline auto rbegin() const { + return m_alphabet.rbegin(); + } + + inline auto rend() const { + return m_alphabet.rend(); + } }; } // namespace xcdat diff --git a/include/xcdat/mm_vector.hpp b/include/xcdat/mm_vector.hpp index 04a7f5f..b1b095d 100644 --- a/include/xcdat/mm_vector.hpp +++ b/include/xcdat/mm_vector.hpp @@ -41,6 +41,14 @@ class mm_vector { return m_vec.end(); } + inline auto rbegin() const { + return m_vec.rbegin(); + } + + inline auto rend() const { + return m_vec.rend(); + } + inline const T& operator[](std::uint64_t i) const { return m_vec[i]; } diff --git a/include/xcdat/tail_vector.hpp b/include/xcdat/tail_vector.hpp index 1764417..f495973 100644 --- a/include/xcdat/tail_vector.hpp +++ b/include/xcdat/tail_vector.hpp @@ -131,7 +131,7 @@ class tail_vector { return 
m_term_flags.size() != 0; } - inline bool match(std::string_view key, size_t tpos) const { + inline bool match(std::string_view key, std::uint64_t tpos) const { if (key.size() == 0) { return tpos == 0; } @@ -162,7 +162,38 @@ class tail_vector { } } - inline void decode(size_t tpos, const std::function& fn) const { + inline bool prefix_match(std::string_view key, std::uint64_t& tpos) const { + if (key.size() == 0) { + return true; + } + + std::uint64_t kpos = 0; + + if (bin_mode()) { + do { + if (key[kpos] != m_chars[tpos]) { + return false; + } + kpos += 1; + if (m_term_flags[tpos]) { + return kpos == key.size(); + } + tpos += 1; + } while (kpos < key.size()); + return true; + } else { + do { + if (!m_chars[tpos] || key[kpos] != m_chars[tpos]) { + return false; + } + kpos += 1; + tpos += 1; + } while (kpos < key.size()); + return true; + } + } + + inline void decode(std::uint64_t tpos, const std::function& fn) const { if (bin_mode()) { do { fn(m_chars[tpos]); diff --git a/include/xcdat/trie.hpp b/include/xcdat/trie.hpp index 670643c..16db615 100644 --- a/include/xcdat/trie.hpp +++ b/include/xcdat/trie.hpp @@ -6,10 +6,14 @@ #include "bc_vector.hpp" #include "trie_builder.hpp" +#include "utils.hpp" namespace xcdat { class trie { + public: + using type = trie; + private: std::uint64_t m_num_keys = 0; code_table m_table; @@ -22,7 +26,8 @@ class trie { virtual ~trie() = default; - static trie build(const std::vector& keys, bool bin_mode = false) { + template + static trie build(const Strings& keys, bool bin_mode = false) { trie_builder b(keys, bc_vector::l1_bits, bin_mode); return trie(b); } @@ -52,7 +57,7 @@ class trie { } return npos_to_id(npos); } - const auto cpos = m_bcvec.base(npos) ^ m_table.get_code(key[kpos++]); + const std::uint64_t cpos = m_bcvec.base(npos) ^ m_table.get_code(key[kpos++]); if (m_bcvec.check(cpos) != npos) { return std::nullopt; } @@ -60,7 +65,7 @@ class trie { } const std::uint64_t tpos = m_bcvec.link(npos); - if 
(!m_tvec.match(key.substr(kpos, key.size() - kpos), tpos)) { + if (!m_tvec.match(utils::get_suffix(key, kpos), tpos)) { return std::nullopt; } return npos_to_id(npos); @@ -84,15 +89,90 @@ class trie { } std::reverse(decoded.begin(), decoded.end()); - if (tpos != 0 && tpos != UINT64_MAX) { m_tvec.decode(tpos, [&](char c) { decoded.push_back(c); }); } return decoded; } + class prefix_iterator { + private: + const type* m_obj = nullptr; + std::string_view m_key; + std::uint64_t m_id = 0; + std::uint64_t m_kpos = 0; + std::uint64_t m_npos = 0; + bool is_beg = true; + bool is_end = false; + + public: + prefix_iterator() = default; + + inline bool next() { + return m_obj != nullptr && m_obj->next_prefix(this); + } + + inline std::uint64_t id() const { + return m_id; + } + inline std::string_view prefix() const { + return {m_key.data(), m_kpos}; + } + + private: + prefix_iterator(const type* obj, std::string_view key) : m_obj(obj), m_key(key) {} + + friend class trie; + }; + + inline prefix_iterator make_prefix_iterator(std::string_view key) const { + return prefix_iterator(this, key); + } + + class predictive_iterator { + public: + struct cursor_type { + char label; + std::uint64_t kpos; + std::uint64_t npos; + }; + + private: + const type* m_obj = nullptr; + std::string_view m_key; + std::uint64_t m_id = 0; + std::string m_prefix; + std::vector m_stack; + bool is_beg = true; + bool is_end = false; + + public: + predictive_iterator() = default; + + inline bool next() { + return m_obj != nullptr && m_obj->next_predictive(this); + } + + inline std::uint64_t id() const { + return m_id; + } + inline std::string_view prefix() const { + return m_prefix; + } + + private: + predictive_iterator(const type* obj, std::string_view key) : m_obj(obj), m_key(key) {} + + friend class trie; + }; + + inline predictive_iterator make_predictive_iterator(std::string_view key) const { + return predictive_iterator(this, key); + } + private: - trie(trie_builder& b) + template + 
trie(trie_builder& b) : m_num_keys(b.m_keys.size()), m_table(b.m_table), m_terms(b.m_terms, true, true), m_bcvec(b.m_units, std::move(b.m_leaves)), m_tvec(b.m_suffixes) {} @@ -103,6 +183,130 @@ class trie { inline std::uint64_t id_to_npos(std::uint64_t id) const { return m_terms.select(id); }; + + inline bool next_prefix(prefix_iterator* itr) const { + if (itr->is_end) { + return false; + } + + if (itr->is_beg) { + itr->is_beg = false; + if (m_terms[itr->m_npos]) { + itr->m_id = npos_to_id(itr->m_npos); + return true; + } + } + + while (!m_bcvec.is_leaf(itr->m_npos)) { + const std::uint64_t cpos = m_bcvec.base(itr->m_npos) ^ m_table.get_code(itr->m_key[itr->m_kpos++]); + if (m_bcvec.check(cpos) != itr->m_npos) { + itr->is_end = true; + itr->m_id = num_keys(); + return false; + } + itr->m_npos = cpos; + if (!m_bcvec.is_leaf(itr->m_npos) && m_terms[itr->m_npos]) { + itr->m_id = npos_to_id(itr->m_npos); + return true; + } + } + itr->is_end = true; + + const std::uint64_t tpos = m_bcvec.link(itr->m_npos); + if (!m_tvec.match(utils::get_suffix(itr->m_key, itr->m_kpos), tpos)) { + itr->m_id = num_keys(); + return false; + } + + itr->m_kpos = itr->m_key.size(); + itr->m_id = npos_to_id(itr->m_npos); + return true; + } + + inline bool next_predictive(predictive_iterator* itr) const { + if (itr->is_end) { + return false; + } + + if (itr->is_beg) { + itr->is_beg = false; + + std::uint64_t kpos = 0; + std::uint64_t npos = 0; + + for (; kpos < itr->m_key.size(); ++kpos) { + if (m_bcvec.is_leaf(npos)) { + itr->is_end = true; + + std::uint64_t tpos = m_bcvec.link(npos); + if (tpos == 0) { + return false; + } + + if (!m_tvec.prefix_match(utils::get_suffix(itr->m_key, kpos), tpos)) { + return false; + } + + itr->m_id = npos_to_id(npos); + m_tvec.decode(tpos, [&](char c) { itr->m_prefix.push_back(c); }); + + return true; + } + + const std::uint64_t cpos = m_bcvec.base(npos) ^ m_table.get_code(itr->m_key[kpos]); + + if (m_bcvec.check(cpos) != npos) { + itr->is_end = true; + return 
false; + } + + npos = cpos; + itr->m_prefix.push_back(itr->m_key[kpos]); + } + + if (!itr->m_prefix.empty()) { + itr->m_stack.push_back({itr->m_prefix.back(), kpos, npos}); + } else { + itr->m_stack.push_back({'\0', kpos, npos}); + } + } + + while (!itr->m_stack.empty()) { + const char label = itr->m_stack.back().label; + const std::uint64_t kpos = itr->m_stack.back().kpos; + const std::uint64_t npos = itr->m_stack.back().npos; + + itr->m_stack.pop_back(); + + if (0 < kpos) { + itr->m_prefix.resize(kpos); + itr->m_prefix.back() = label; + } + + if (m_bcvec.is_leaf(npos)) { + itr->m_id = npos_to_id(npos); + m_tvec.decode(m_bcvec.link(npos), [&](char c) { itr->m_prefix.push_back(c); }); + return true; + } + + const std::uint64_t base = m_bcvec.base(npos); + + for (auto cit = m_table.rbegin(); cit != m_table.rend(); ++cit) { + const std::uint64_t cpos = base ^ m_table.get_code(*cit); + if (m_bcvec.check(cpos) == npos) { + itr->m_stack.push_back({static_cast(*cit), kpos + 1, cpos}); + } + } + + if (m_terms[npos]) { + itr->m_id = npos_to_id(npos); + return true; + } + } + + itr->is_end = true; + return false; + } }; } // namespace xcdat \ No newline at end of file diff --git a/include/xcdat/trie_builder.hpp b/include/xcdat/trie_builder.hpp index 1e157e0..91ced73 100644 --- a/include/xcdat/trie_builder.hpp +++ b/include/xcdat/trie_builder.hpp @@ -11,6 +11,7 @@ namespace xcdat { +template class trie_builder { public: struct unit_type { @@ -22,7 +23,7 @@ class trie_builder { static constexpr std::uint64_t taboo_npos = 1; static constexpr std::uint64_t free_blocks = 16; - const std::vector& m_keys; + const Strings& m_keys; const std::uint32_t m_l1_bits; // # of bits for L1 layer of DACs const std::uint64_t m_l1_size; @@ -38,7 +39,7 @@ class trie_builder { tail_vector::builder m_suffixes; public: - trie_builder(const std::vector& keys, std::uint32_t l1_bits, bool bin_mode) + trie_builder(const Strings& keys, std::uint32_t l1_bits, bool bin_mode) : m_keys(keys), 
m_l1_bits(l1_bits), m_l1_size(1ULL << l1_bits), m_bin_mode(bin_mode) { XCDAT_THROW_IF(m_keys.size() == 0, "The input dataset is empty."); diff --git a/include/xcdat/utils.hpp b/include/xcdat/utils.hpp index bb9ec56..7fe37ca 100644 --- a/include/xcdat/utils.hpp +++ b/include/xcdat/utils.hpp @@ -22,4 +22,9 @@ inline std::uint64_t needed_bits(std::uint64_t x) { return bit_tools::msb(x) + 1; } +template +inline String get_suffix(const String& s, std::uint64_t i) { + return s.substr(i, s.size() - i); +} + } // namespace xcdat::utils \ No newline at end of file diff --git a/sample/CMakeLists.txt b/sample/CMakeLists.txt new file mode 100644 index 0000000..1709b59 --- /dev/null +++ b/sample/CMakeLists.txt @@ -0,0 +1 @@ +add_executable(sample sample.cpp) diff --git a/sample/sample.cpp b/sample/sample.cpp new file mode 100644 index 0000000..c4c64bb --- /dev/null +++ b/sample/sample.cpp @@ -0,0 +1,44 @@ +#include +#include + +#include + +int main() { + std::vector keys = { + "Mac", "MacBook", "MacBook_Air", "MacBook_Pro", "Mac_Pro", "iMac", "Mac_Mini", + }; + + // The dataset must be sorted and unique. + std::sort(keys.begin(), keys.end()); + keys.erase(std::unique(keys.begin(), keys.end()), keys.end()); + + auto trie = xcdat::trie::build(keys); + + std::cout << "Basic operations" << std::endl; + { + const auto id = trie.lookup("MacBook_Pro"); + if (id.has_value()) { + std::cout << trie.access(id.value()) << " -> " << id.value() << std::endl; + } else { + std::cout << "Not found" << std::endl; + } + } + + std::cout << "Common prefix search" << std::endl; + { + auto itr = trie.make_prefix_iterator("MacBook_Air"); + while (itr.next()) { + std::cout << itr.prefix() << " -> " << itr.id() << std::endl; + } + } + + std::cout << "Predictive search" << std::endl; + { + auto itr = trie.make_predictive_iterator("Mac"); + while (itr.next()) { + std::cout << itr.prefix() << " -> " << itr.id() << std::endl; + } + } + + return 0; +}