add functions
This commit is contained in:
parent
d0bd44652e
commit
a6e7f6a07b
|
@ -36,4 +36,6 @@ include_directories(include)
|
|||
enable_testing()
|
||||
add_subdirectory(test)
|
||||
|
||||
add_subdirectory(sample)
|
||||
|
||||
file(COPY ${CMAKE_SOURCE_DIR}/test/keys.txt DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
|
|
@ -86,6 +86,14 @@ class code_table {
|
|||
inline auto end() const {
|
||||
return m_alphabet.end();
|
||||
}
|
||||
|
||||
inline auto rbegin() const {
|
||||
return m_alphabet.rbegin();
|
||||
}
|
||||
|
||||
inline auto rend() const {
|
||||
return m_alphabet.rend();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace xcdat
|
||||
|
|
|
@ -41,6 +41,14 @@ class mm_vector {
|
|||
return m_vec.end();
|
||||
}
|
||||
|
||||
inline auto rbegin() const {
|
||||
return m_vec.rbegin();
|
||||
}
|
||||
|
||||
inline auto rend() const {
|
||||
return m_vec.rend();
|
||||
}
|
||||
|
||||
// Read-only access to the i-th element; no bounds check is performed here.
const T& operator[](std::uint64_t i) const {
    return m_vec[i];
}
|
||||
|
|
|
@ -131,7 +131,7 @@ class tail_vector {
|
|||
return m_term_flags.size() != 0;
|
||||
}
|
||||
|
||||
inline bool match(std::string_view key, size_t tpos) const {
|
||||
inline bool match(std::string_view key, std::uint64_t tpos) const {
|
||||
if (key.size() == 0) {
|
||||
return tpos == 0;
|
||||
}
|
||||
|
@ -162,7 +162,38 @@ class tail_vector {
|
|||
}
|
||||
}
|
||||
|
||||
inline void decode(size_t tpos, const std::function<void(char)>& fn) const {
|
||||
inline bool prefix_match(std::string_view key, std::uint64_t& tpos) const {
|
||||
if (key.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::uint64_t kpos = 0;
|
||||
|
||||
if (bin_mode()) {
|
||||
do {
|
||||
if (key[kpos] != m_chars[tpos]) {
|
||||
return false;
|
||||
}
|
||||
kpos += 1;
|
||||
if (m_term_flags[tpos]) {
|
||||
return kpos == key.size();
|
||||
}
|
||||
tpos += 1;
|
||||
} while (kpos < key.size());
|
||||
return true;
|
||||
} else {
|
||||
do {
|
||||
if (!m_chars[tpos] || key[kpos] != m_chars[tpos]) {
|
||||
return false;
|
||||
}
|
||||
kpos += 1;
|
||||
tpos += 1;
|
||||
} while (kpos < key.size());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
inline void decode(std::uint64_t tpos, const std::function<void(char)>& fn) const {
|
||||
if (bin_mode()) {
|
||||
do {
|
||||
fn(m_chars[tpos]);
|
||||
|
|
|
@ -6,10 +6,14 @@
|
|||
|
||||
#include "bc_vector.hpp"
|
||||
#include "trie_builder.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace xcdat {
|
||||
|
||||
class trie {
|
||||
public:
|
||||
using type = trie;
|
||||
|
||||
private:
|
||||
std::uint64_t m_num_keys = 0;
|
||||
code_table m_table;
|
||||
|
@ -22,7 +26,8 @@ class trie {
|
|||
|
||||
virtual ~trie() = default;
|
||||
|
||||
static trie build(const std::vector<std::string>& keys, bool bin_mode = false) {
|
||||
template <class Strings>
|
||||
static trie build(const Strings& keys, bool bin_mode = false) {
|
||||
trie_builder b(keys, bc_vector::l1_bits, bin_mode);
|
||||
return trie(b);
|
||||
}
|
||||
|
@ -52,7 +57,7 @@ class trie {
|
|||
}
|
||||
return npos_to_id(npos);
|
||||
}
|
||||
const auto cpos = m_bcvec.base(npos) ^ m_table.get_code(key[kpos++]);
|
||||
const std::uint64_t cpos = m_bcvec.base(npos) ^ m_table.get_code(key[kpos++]);
|
||||
if (m_bcvec.check(cpos) != npos) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
@ -60,7 +65,7 @@ class trie {
|
|||
}
|
||||
|
||||
const std::uint64_t tpos = m_bcvec.link(npos);
|
||||
if (!m_tvec.match(key.substr(kpos, key.size() - kpos), tpos)) {
|
||||
if (!m_tvec.match(utils::get_suffix(key, kpos), tpos)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return npos_to_id(npos);
|
||||
|
@ -84,15 +89,90 @@ class trie {
|
|||
}
|
||||
|
||||
std::reverse(decoded.begin(), decoded.end());
|
||||
|
||||
if (tpos != 0 && tpos != UINT64_MAX) {
|
||||
m_tvec.decode(tpos, [&](char c) { decoded.push_back(c); });
|
||||
}
|
||||
return decoded;
|
||||
}
|
||||
|
||||
class prefix_iterator {
|
||||
private:
|
||||
const type* m_obj = nullptr;
|
||||
std::string_view m_key;
|
||||
std::uint64_t m_id = 0;
|
||||
std::uint64_t m_kpos = 0;
|
||||
std::uint64_t m_npos = 0;
|
||||
bool is_beg = true;
|
||||
bool is_end = false;
|
||||
|
||||
public:
|
||||
prefix_iterator() = default;
|
||||
|
||||
inline bool next() {
|
||||
return m_obj != nullptr && m_obj->next_prefix(this);
|
||||
}
|
||||
|
||||
inline std::uint64_t id() const {
|
||||
return m_id;
|
||||
}
|
||||
inline std::string_view prefix() const {
|
||||
return {m_key.data(), m_kpos};
|
||||
}
|
||||
|
||||
private:
|
||||
prefix_iterator(const type* obj, std::string_view key) : m_obj(obj), m_key(key) {}
|
||||
|
||||
friend class trie;
|
||||
};
|
||||
|
||||
// Creates an iterator over the stored keys that are prefixes of `key`.
// The iterator keeps only a view of `key`, not a copy.
prefix_iterator make_prefix_iterator(std::string_view key) const {
    prefix_iterator itr(this, key);
    return itr;
}
|
||||
|
||||
class predictive_iterator {
|
||||
public:
|
||||
struct cursor_type {
|
||||
char label;
|
||||
std::uint64_t kpos;
|
||||
std::uint64_t npos;
|
||||
};
|
||||
|
||||
private:
|
||||
const type* m_obj = nullptr;
|
||||
std::string_view m_key;
|
||||
std::uint64_t m_id = 0;
|
||||
std::string m_prefix;
|
||||
std::vector<cursor_type> m_stack;
|
||||
bool is_beg = true;
|
||||
bool is_end = false;
|
||||
|
||||
public:
|
||||
predictive_iterator() = default;
|
||||
|
||||
inline bool next() {
|
||||
return m_obj != nullptr && m_obj->next_predictive(this);
|
||||
}
|
||||
|
||||
inline std::uint64_t id() const {
|
||||
return m_id;
|
||||
}
|
||||
inline std::string_view prefix() const {
|
||||
return m_prefix;
|
||||
}
|
||||
|
||||
private:
|
||||
predictive_iterator(const type* obj, std::string_view key) : m_obj(obj), m_key(key) {}
|
||||
|
||||
friend class trie;
|
||||
};
|
||||
|
||||
// Creates an iterator over the stored keys that start with `key`.
// The iterator keeps only a view of `key`, not a copy.
predictive_iterator make_predictive_iterator(std::string_view key) const {
    predictive_iterator itr(this, key);
    return itr;
}
|
||||
|
||||
private:
|
||||
trie(trie_builder& b)
|
||||
template <class Strings>
|
||||
trie(trie_builder<Strings>& b)
|
||||
: m_num_keys(b.m_keys.size()), m_table(b.m_table), m_terms(b.m_terms, true, true),
|
||||
m_bcvec(b.m_units, std::move(b.m_leaves)), m_tvec(b.m_suffixes) {}
|
||||
|
||||
|
@ -103,6 +183,130 @@ class trie {
|
|||
inline std::uint64_t id_to_npos(std::uint64_t id) const {
|
||||
return m_terms.select(id);
|
||||
};
|
||||
|
||||
// Advances `itr` to the next stored key that is a prefix of the iterator's
// query key. Returns true and updates itr->m_id / itr->m_kpos on success.
// Returns false once no further prefix exists; the iterator is then marked
// finished and, on the failure paths below, m_id is set to num_keys().
inline bool next_prefix(prefix_iterator* itr) const {
    if (itr->is_end) {
        return false;
    }

    if (itr->is_beg) {
        itr->is_beg = false;
        // The starting node itself may already be flagged in m_terms.
        if (m_terms[itr->m_npos]) {
            itr->m_id = npos_to_id(itr->m_npos);
            return true;
        }
    }

    while (!m_bcvec.is_leaf(itr->m_npos)) {
        // The query is exhausted at a non-leaf node: no longer prefix can
        // exist, and indexing m_key[m_kpos] would read past the view.
        if (itr->m_kpos >= itr->m_key.size()) {
            itr->is_end = true;
            itr->m_id = num_keys();
            return false;
        }

        const std::uint64_t cpos = m_bcvec.base(itr->m_npos) ^ m_table.get_code(itr->m_key[itr->m_kpos++]);
        if (m_bcvec.check(cpos) != itr->m_npos) {
            itr->is_end = true;
            itr->m_id = num_keys();
            return false;
        }
        itr->m_npos = cpos;
        // Leaves are handled after the loop, where the tail is verified.
        if (!m_bcvec.is_leaf(itr->m_npos) && m_terms[itr->m_npos]) {
            itr->m_id = npos_to_id(itr->m_npos);
            return true;
        }
    }

    // A leaf can yield at most one more result.
    itr->is_end = true;

    const std::uint64_t tpos = m_bcvec.link(itr->m_npos);
    if (!m_tvec.match(utils::get_suffix(itr->m_key, itr->m_kpos), tpos)) {
        itr->m_id = num_keys();
        return false;
    }

    // The tail matched the whole remainder, so the prefix is the full query.
    itr->m_kpos = itr->m_key.size();
    itr->m_id = npos_to_id(itr->m_npos);
    return true;
}
|
||||
|
||||
// Advances `itr` to the next stored key that starts with the iterator's
// query key. On success returns true, sets itr->m_id and leaves the key
// string in itr->m_prefix. Returns false when the enumeration is finished.
//
// The first call walks the query key from node 0; later calls continue a
// depth-first traversal driven by itr->m_stack.
inline bool next_predictive(predictive_iterator* itr) const {
    if (itr->is_end) {
        return false;
    }

    if (itr->is_beg) {
        itr->is_beg = false;

        std::uint64_t kpos = 0;
        std::uint64_t npos = 0;

        for (; kpos < itr->m_key.size(); ++kpos) {
            if (m_bcvec.is_leaf(npos)) {
                // The rest of the query has to be checked against the tail
                // of this leaf, so at most one key can be reported.
                itr->is_end = true;

                std::uint64_t tpos = m_bcvec.link(npos);
                if (tpos == 0) {
                    return false;
                }

                const std::string_view rest = utils::get_suffix(itr->m_key, kpos);
                if (!m_tvec.prefix_match(rest, tpos)) {
                    return false;
                }

                itr->m_id = npos_to_id(npos);
                // m_prefix only holds key[0, kpos) at this point; add the part
                // of the query that was matched inside the tail before
                // decoding what follows it.
                itr->m_prefix.append(rest.data(), rest.size());
                // NOTE(review): prefix_match() moves tpos forward; what decode()
                // emits when the query consumed the whole tail depends on
                // tail_vector and should be confirmed there.
                m_tvec.decode(tpos, [&](char c) { itr->m_prefix.push_back(c); });

                return true;
            }

            const std::uint64_t cpos = m_bcvec.base(npos) ^ m_table.get_code(itr->m_key[kpos]);

            if (m_bcvec.check(cpos) != npos) {
                itr->is_end = true;
                return false;
            }

            npos = cpos;
            itr->m_prefix.push_back(itr->m_key[kpos]);
        }

        // Seed the traversal with the node reached by the whole query.
        if (!itr->m_prefix.empty()) {
            itr->m_stack.push_back({itr->m_prefix.back(), kpos, npos});
        } else {
            itr->m_stack.push_back({'\0', kpos, npos});
        }
    }

    while (!itr->m_stack.empty()) {
        const char label = itr->m_stack.back().label;
        const std::uint64_t kpos = itr->m_stack.back().kpos;
        const std::uint64_t npos = itr->m_stack.back().npos;

        itr->m_stack.pop_back();

        // Rewind the result string to this node's depth and set its label.
        if (0 < kpos) {
            itr->m_prefix.resize(kpos);
            itr->m_prefix.back() = label;
        }

        if (m_bcvec.is_leaf(npos)) {
            itr->m_id = npos_to_id(npos);
            // A link of 0 denotes an empty tail (match() and access() treat
            // it that way), so there is nothing to decode in that case.
            const std::uint64_t tpos = m_bcvec.link(npos);
            if (tpos != 0) {
                m_tvec.decode(tpos, [&](char c) { itr->m_prefix.push_back(c); });
            }
            return true;
        }

        const std::uint64_t base = m_bcvec.base(npos);

        // Children are pushed in reverse table order so that they are popped
        // in forward table order.
        for (auto cit = m_table.rbegin(); cit != m_table.rend(); ++cit) {
            const std::uint64_t cpos = base ^ m_table.get_code(*cit);
            if (m_bcvec.check(cpos) == npos) {
                itr->m_stack.push_back({static_cast<char>(*cit), kpos + 1, cpos});
            }
        }

        if (m_terms[npos]) {
            itr->m_id = npos_to_id(npos);
            return true;
        }
    }

    itr->is_end = true;
    return false;
}
|
||||
};
|
||||
|
||||
} // namespace xcdat
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
namespace xcdat {
|
||||
|
||||
template <class Strings>
|
||||
class trie_builder {
|
||||
public:
|
||||
struct unit_type {
|
||||
|
@ -22,7 +23,7 @@ class trie_builder {
|
|||
static constexpr std::uint64_t taboo_npos = 1;
|
||||
static constexpr std::uint64_t free_blocks = 16;
|
||||
|
||||
const std::vector<std::string>& m_keys;
|
||||
const Strings& m_keys;
|
||||
const std::uint32_t m_l1_bits; // # of bits for L1 layer of DACs
|
||||
const std::uint64_t m_l1_size;
|
||||
|
||||
|
@ -38,7 +39,7 @@ class trie_builder {
|
|||
tail_vector::builder m_suffixes;
|
||||
|
||||
public:
|
||||
trie_builder(const std::vector<std::string>& keys, std::uint32_t l1_bits, bool bin_mode)
|
||||
trie_builder(const Strings& keys, std::uint32_t l1_bits, bool bin_mode)
|
||||
: m_keys(keys), m_l1_bits(l1_bits), m_l1_size(1ULL << l1_bits), m_bin_mode(bin_mode) {
|
||||
XCDAT_THROW_IF(m_keys.size() == 0, "The input dataset is empty.");
|
||||
|
||||
|
|
|
@ -22,4 +22,9 @@ inline std::uint64_t needed_bits(std::uint64_t x) {
|
|||
return bit_tools::msb(x) + 1;
|
||||
}
|
||||
|
||||
// Returns the part of `s` that starts at offset `i` and runs to the end.
// `String` must provide substr(); the offset is checked by substr() itself.
template <class String>
inline String get_suffix(const String& s, std::uint64_t i) {
    return s.substr(i);
}
|
||||
|
||||
} // namespace xcdat::utils
|
1
sample/CMakeLists.txt
Normal file
1
sample/CMakeLists.txt
Normal file
|
@ -0,0 +1 @@
|
|||
add_executable(sample sample.cpp)
|
44
sample/sample.cpp
Normal file
44
sample/sample.cpp
Normal file
|
@ -0,0 +1,44 @@
|
|||
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

#include <xcdat.hpp>
|
||||
|
||||
int main() {
|
||||
std::vector<std::string> keys = {
|
||||
"Mac", "MacBook", "MacBook_Air", "MacBook_Pro", "Mac_Pro", "iMac", "Mac_Mini",
|
||||
};
|
||||
|
||||
// The dataset must be sorted and unique.
|
||||
std::sort(keys.begin(), keys.end());
|
||||
keys.erase(std::unique(keys.begin(), keys.end()), keys.end());
|
||||
|
||||
auto trie = xcdat::trie::build(keys);
|
||||
|
||||
std::cout << "Basic operations" << std::endl;
|
||||
{
|
||||
const auto id = trie.lookup("MacBook_Pro");
|
||||
if (id.has_value()) {
|
||||
std::cout << trie.access(id.value()) << " -> " << id.has_value() << std::endl;
|
||||
} else {
|
||||
std::cout << "Not found" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Common prefix search" << std::endl;
|
||||
{
|
||||
auto itr = trie.make_prefix_iterator("MacBook_Air");
|
||||
while (itr.next()) {
|
||||
std::cout << itr.prefix() << " -> " << itr.id() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Predictive search" << std::endl;
|
||||
{
|
||||
auto itr = trie.make_predictive_iterator("Mac");
|
||||
while (itr.next()) {
|
||||
std::cout << itr.prefix() << " -> " << itr.id() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
Loading…
Reference in a new issue