diff --git a/include/xcdat/tail_vector.hpp b/include/xcdat/tail_vector.hpp index fea2033..b51d295 100644 --- a/include/xcdat/tail_vector.hpp +++ b/include/xcdat/tail_vector.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -166,31 +167,42 @@ class tail_vector { } } - inline bool prefix_match(std::string_view key, std::uint64_t tpos) const { - assert(key.size() != 0); - std::uint64_t kpos = 0; + // Returns epos-tpos+1 if TAIL[tpos..epos] is a prefix of key. + inline std::optional prefix_match(std::string_view key, std::uint64_t tpos) const { + if (tpos == 0) { + // suffix is empty, always matched. + return 0; + } + if (key.size() == 0) { + // When key is empty, match fails since the suffix is not empty here. + return std::nullopt; + } + std::uint64_t kpos = 0; if (bin_mode()) { do { if (key[kpos] != m_chars[tpos]) { - return false; + return std::nullopt; } kpos += 1; if (m_terms[tpos]) { - return kpos == key.size(); + return kpos; } tpos += 1; } while (kpos < key.size()); - return true; + return kpos; } else { do { - if (!m_chars[tpos] || key[kpos] != m_chars[tpos]) { - return false; + if (!m_chars[tpos]) { + return kpos; + } + if (key[kpos] != m_chars[tpos]) { + return std::nullopt; } kpos += 1; tpos += 1; } while (kpos < key.size()); - return true; + return kpos; } } diff --git a/include/xcdat/trie.hpp b/include/xcdat/trie.hpp index af7b9d5..07e21bc 100644 --- a/include/xcdat/trie.hpp +++ b/include/xcdat/trie.hpp @@ -166,7 +166,7 @@ class trie { class prefix_iterator { private: const trie_type* m_obj = nullptr; - std::string m_key; + std::string_view m_key; std::uint64_t m_id = 0; std::uint64_t m_kpos = 0; std::uint64_t m_npos = 0; @@ -231,7 +231,7 @@ class trie { private: const trie_type* m_obj = nullptr; - std::string m_key; + std::string_view m_key; std::uint64_t m_id = 0; std::string m_decoded; std::vector m_stack; @@ -317,8 +317,7 @@ class trie { : m_num_keys(b.m_keys.size()), m_table(std::move(b.m_table)), m_terms(b.m_terms, true, true), m_bcvec(b.m_units, std::move(b.m_leaves)), m_tvec(std::move(b.m_suffixes)) {} - template - static constexpr String get_suffix(const String& s, std::uint64_t i) { + static constexpr std::string_view get_suffix(std::string_view s, std::uint64_t i) { assert(i <= s.size()); return s.substr(i, s.size() - i); } @@ -376,12 +375,12 @@ class trie { itr->is_end = true; const std::uint64_t tpos = m_bcvec.link(itr->m_npos); - if (!m_tvec.match(get_suffix(itr->m_key, itr->m_kpos), tpos)) { + const auto matched = m_tvec.prefix_match(get_suffix(itr->m_key, itr->m_kpos), tpos); + if (!matched.has_value()) { itr->m_id = num_keys(); return false; } - - itr->m_kpos = itr->m_key.size(); + itr->m_kpos += matched.value(); itr->m_id = npos_to_id(itr->m_npos); return true; } diff --git a/tests/test_trie.cpp b/tests/test_trie.cpp index f1f48ee..ac5a4a5 100644 --- a/tests/test_trie.cpp +++ b/tests/test_trie.cpp @@ -167,7 +167,7 @@ TEST_CASE("Test " TRIE_NAME " (tiny)") { test_basic_operations(trie, keys, others); { - auto itr = trie.make_prefix_iterator("MacBook_Pro"); + auto itr = trie.make_prefix_iterator("MacBook_Pro_13inch"); std::vector expected = {"Mac", "MacBook", "MacBook_Pro"}; for (const auto& exp : expected) { REQUIRE(itr.next());