add functions

This commit is contained in:
Shunsuke Kanda 2021-06-27 01:40:11 +09:00
parent d0bd44652e
commit a6e7f6a07b
9 changed files with 313 additions and 9 deletions

View file

@ -36,4 +36,6 @@ include_directories(include)
enable_testing()
add_subdirectory(test)
add_subdirectory(sample)
file(COPY ${CMAKE_SOURCE_DIR}/test/keys.txt DESTINATION ${CMAKE_CURRENT_BINARY_DIR})

View file

@ -86,6 +86,14 @@ class code_table {
// Returns m_alphabet.end(), i.e. the past-the-end iterator of the alphabet container.
inline auto end() const {
return m_alphabet.end();
}
// Returns m_alphabet.rbegin(), the starting point for walking the alphabet in reverse order.
inline auto rbegin() const {
return m_alphabet.rbegin();
}
// Returns m_alphabet.rend(), the end marker that pairs with rbegin() for reverse iteration.
inline auto rend() const {
return m_alphabet.rend();
}
};
} // namespace xcdat

View file

@ -41,6 +41,14 @@ class mm_vector {
return m_vec.end();
}
// Returns m_vec.rbegin(), the starting point for walking the stored elements in reverse order.
inline auto rbegin() const {
return m_vec.rbegin();
}
// Returns m_vec.rend(), the end marker that pairs with rbegin() for reverse iteration.
inline auto rend() const {
return m_vec.rend();
}
// Read-only access to the i-th element of m_vec.
// No range check is performed in this function; `i` is passed straight to m_vec[i].
inline const T& operator[](std::uint64_t i) const {
return m_vec[i];
}

View file

@ -131,7 +131,7 @@ class tail_vector {
return m_term_flags.size() != 0;
}
inline bool match(std::string_view key, size_t tpos) const {
inline bool match(std::string_view key, std::uint64_t tpos) const {
if (key.size() == 0) {
return tpos == 0;
}
@ -162,7 +162,38 @@ class tail_vector {
}
}
inline void decode(size_t tpos, const std::function<void(char)>& fn) const {
// Checks whether `key` is a prefix of the tail string that starts at `tpos`.
//
// An empty key always matches. `tpos` is an in/out argument: it is advanced
// while characters are compared, so after the call it may no longer hold the
// value that was passed in. Callers that still need the start position must
// keep their own copy.
//
// In binary mode a set bit in m_term_flags ends the tail; otherwise a zero
// element in m_chars does.
inline bool prefix_match(std::string_view key, std::uint64_t& tpos) const {
    const std::uint64_t klen = key.size();
    if (klen == 0) {
        return true;
    }

    if (bin_mode()) {
        // `tpos` moves forward only after a character matched and was not the
        // last character of the tail.
        for (std::uint64_t i = 0; i < klen; ++tpos) {
            if (key[i] != m_chars[tpos]) {
                return false;
            }
            ++i;
            if (m_term_flags[tpos]) {
                // The tail ends here, so the key must end here as well.
                return i == klen;
            }
        }
        return true;
    }

    for (std::uint64_t i = 0; i < klen; ++i, ++tpos) {
        // A zero element means the tail ended before the key did.
        if (!m_chars[tpos] || key[i] != m_chars[tpos]) {
            return false;
        }
    }
    return true;
}
inline void decode(std::uint64_t tpos, const std::function<void(char)>& fn) const {
if (bin_mode()) {
do {
fn(m_chars[tpos]);

View file

@ -6,10 +6,14 @@
#include "bc_vector.hpp"
#include "trie_builder.hpp"
#include "utils.hpp"
namespace xcdat {
class trie {
public:
using type = trie;
private:
std::uint64_t m_num_keys = 0;
code_table m_table;
@ -22,7 +26,8 @@ class trie {
virtual ~trie() = default;
static trie build(const std::vector<std::string>& keys, bool bin_mode = false) {
// Builds a trie from the key collection `keys`.
//
// `keys` and `bin_mode` are handed to trie_builder together with
// bc_vector::l1_bits; the finished builder is then converted into a trie.
// `Strings` is whatever container type trie_builder accepts for its key set.
template <class Strings>
static trie build(const Strings& keys, bool bin_mode = false) {
    trie_builder<Strings> builder(keys, bc_vector::l1_bits, bin_mode);
    return trie(builder);
}
@ -52,7 +57,7 @@ class trie {
}
return npos_to_id(npos);
}
const auto cpos = m_bcvec.base(npos) ^ m_table.get_code(key[kpos++]);
const std::uint64_t cpos = m_bcvec.base(npos) ^ m_table.get_code(key[kpos++]);
if (m_bcvec.check(cpos) != npos) {
return std::nullopt;
}
@ -60,7 +65,7 @@ class trie {
}
const std::uint64_t tpos = m_bcvec.link(npos);
if (!m_tvec.match(key.substr(kpos, key.size() - kpos), tpos)) {
if (!m_tvec.match(utils::get_suffix(key, kpos), tpos)) {
return std::nullopt;
}
return npos_to_id(npos);
@ -84,15 +89,90 @@ class trie {
}
std::reverse(decoded.begin(), decoded.end());
if (tpos != 0 && tpos != UINT64_MAX) {
m_tvec.decode(tpos, [&](char c) { decoded.push_back(c); });
}
return decoded;
}
// Cursor for the common-prefix search started by trie::make_prefix_iterator().
//
// Each successful next() exposes one result through id() and prefix().
// The iterator stores only a std::string_view of the query, so the query's
// characters must stay alive for as long as the iterator is used.
class prefix_iterator {
  private:
    const type* m_obj = nullptr;  // trie being searched; null when default-constructed
    std::string_view m_key;       // query key (not owned)
    std::uint64_t m_id = 0;       // id of the most recent result
    std::uint64_t m_kpos = 0;     // number of query characters consumed so far
    std::uint64_t m_npos = 0;     // current node position in the trie
    bool is_beg = true;           // true until next() has been called once
    bool is_end = false;          // set by the trie once the search is exhausted

  public:
    prefix_iterator() = default;

    // Moves to the next result. Returns false when no trie is attached or
    // when the trie reports that there is nothing more to visit.
    inline bool next() {
        if (m_obj == nullptr) {
            return false;
        }
        return m_obj->next_prefix(this);
    }

    // Id stored by the most recent next() call.
    inline std::uint64_t id() const {
        return m_id;
    }

    // The first m_kpos characters of the query key.
    inline std::string_view prefix() const {
        return std::string_view(m_key.data(), m_kpos);
    }

  private:
    prefix_iterator(const type* obj, std::string_view key) : m_obj(obj), m_key(key) {}

    friend class trie;
};
// Creates a prefix_iterator bound to this trie for the query `key`.
// The iterator keeps a std::string_view of `key`, so the characters it refers to
// must remain valid while the iterator is in use.
inline prefix_iterator make_prefix_iterator(std::string_view key) const {
return prefix_iterator(this, key);
}
// Cursor for the predictive search started by trie::make_predictive_iterator().
//
// Each successful next() exposes one result through id() and prefix().
// The iterator stores only a std::string_view of the query, so the query's
// characters must stay alive for as long as the iterator is used.
class predictive_iterator {
  public:
    // One pending node of the depth-first traversal.
    struct cursor_type {
        char label;          // character written at the end of the decoded string for this node
        std::uint64_t kpos;  // length the decoded string is cut to before `label` is written
        std::uint64_t npos;  // node position in the trie
    };

  private:
    const type* m_obj = nullptr;       // trie being searched; null when default-constructed
    std::string_view m_key;            // query key (not owned)
    std::uint64_t m_id = 0;            // id of the most recent result
    std::string m_prefix;              // decoded key of the most recent result
    std::vector<cursor_type> m_stack;  // nodes still to be visited
    bool is_beg = true;                // true until next() has been called once
    bool is_end = false;               // set by the trie once the search is exhausted

  public:
    predictive_iterator() = default;

    // Moves to the next result. Returns false when no trie is attached or
    // when the trie reports that there is nothing more to visit.
    inline bool next() {
        if (m_obj == nullptr) {
            return false;
        }
        return m_obj->next_predictive(this);
    }

    // Id stored by the most recent next() call.
    inline std::uint64_t id() const {
        return m_id;
    }

    // View of the internal decoded-key buffer. The buffer is rewritten by
    // next(), so copy the result if it has to survive the next call.
    inline std::string_view prefix() const {
        return std::string_view(m_prefix);
    }

  private:
    predictive_iterator(const type* obj, std::string_view key) : m_obj(obj), m_key(key) {}

    friend class trie;
};
// Creates a predictive_iterator bound to this trie for the query `key`.
// The iterator keeps a std::string_view of `key`, so the characters it refers to
// must remain valid while the iterator is in use.
inline predictive_iterator make_predictive_iterator(std::string_view key) const {
return predictive_iterator(this, key);
}
private:
trie(trie_builder& b)
// Assembles the trie from the data held by a trie_builder.
// The builder is taken by non-const reference because b.m_leaves is moved out of it;
// the other builder members are only read.
template <class Strings>
trie(trie_builder<Strings>& b)
: m_num_keys(b.m_keys.size()), m_table(b.m_table), m_terms(b.m_terms, true, true),
m_bcvec(b.m_units, std::move(b.m_leaves)), m_tvec(b.m_suffixes) {}
@ -103,6 +183,130 @@ class trie {
// Converts a key id into the corresponding node position by asking
// m_terms for the position selected by `id`.
inline std::uint64_t id_to_npos(std::uint64_t id) const {
    const std::uint64_t npos = m_terms.select(id);
    return npos;
}
// Advances `itr` to the next stored key that is a prefix of itr->m_key.
//
// On success the function returns true, itr->m_id holds the id of the key
// found and itr->m_kpos the number of query characters it covers. When the
// search is exhausted it returns false, sets itr->is_end and (on the failure
// paths below) stores num_keys() in itr->m_id.
inline bool next_prefix(prefix_iterator* itr) const {
    if (itr->is_end) {
        return false;
    }

    if (itr->is_beg) {
        itr->is_beg = false;
        // The starting node itself may be terminal; m_kpos is still 0 here.
        if (m_terms[itr->m_npos]) {
            itr->m_id = npos_to_id(itr->m_npos);
            return true;
        }
    }

    while (!m_bcvec.is_leaf(itr->m_npos)) {
        // The query ended at an inner node: there is no further character to
        // follow. Without this check m_key would be indexed past its end.
        if (itr->m_kpos >= itr->m_key.size()) {
            itr->is_end = true;
            itr->m_id = num_keys();
            return false;
        }

        const std::uint64_t cpos = m_bcvec.base(itr->m_npos) ^ m_table.get_code(itr->m_key[itr->m_kpos++]);
        if (m_bcvec.check(cpos) != itr->m_npos) {
            // No child for this character.
            itr->is_end = true;
            itr->m_id = num_keys();
            return false;
        }

        itr->m_npos = cpos;
        // Leaves are handled after the loop because their tail still has to be compared.
        if (!m_bcvec.is_leaf(itr->m_npos) && m_terms[itr->m_npos]) {
            itr->m_id = npos_to_id(itr->m_npos);
            return true;
        }
    }

    // A leaf is the last possible result on this path.
    itr->is_end = true;

    // NOTE(review): match() receives the whole remaining query. If it requires
    // the tail to equal that remainder exactly, a stored key whose tail ends
    // before the query does is not reported here - confirm against match().
    const std::uint64_t tpos = m_bcvec.link(itr->m_npos);
    if (!m_tvec.match(utils::get_suffix(itr->m_key, itr->m_kpos), tpos)) {
        itr->m_id = num_keys();
        return false;
    }

    itr->m_kpos = itr->m_key.size();
    itr->m_id = npos_to_id(itr->m_npos);
    return true;
}
// Advances `itr` to the next stored key that starts with itr->m_key.
//
// The first call walks the trie along the query. If the walk ends inside a
// leaf's tail, that single key is the only candidate. Otherwise the node
// reached is pushed on itr->m_stack and the subtree below it is enumerated
// depth-first, one result per call. On success the function returns true with
// the id in itr->m_id and the full key in itr->m_prefix; it returns false and
// sets itr->is_end once nothing is left.
inline bool next_predictive(predictive_iterator* itr) const {
    if (itr->is_end) {
        return false;
    }

    if (itr->is_beg) {
        itr->is_beg = false;

        std::uint64_t kpos = 0;
        std::uint64_t npos = 0;

        for (; kpos < itr->m_key.size(); ++kpos) {
            if (m_bcvec.is_leaf(npos)) {
                // The rest of the query has to be found inside this leaf's tail.
                itr->is_end = true;

                const std::uint64_t tpos = m_bcvec.link(npos);
                if (tpos == 0) {
                    return false;
                }

                // prefix_match() advances the position it is given, so hand it a
                // scratch copy; decoding below must start at the beginning of the
                // tail or the matched query characters would be missing from m_prefix.
                std::uint64_t scratch = tpos;
                if (!m_tvec.prefix_match(utils::get_suffix(itr->m_key, kpos), scratch)) {
                    return false;
                }

                itr->m_id = npos_to_id(npos);
                m_tvec.decode(tpos, [&](char c) { itr->m_prefix.push_back(c); });
                return true;
            }

            const std::uint64_t cpos = m_bcvec.base(npos) ^ m_table.get_code(itr->m_key[kpos]);
            if (m_bcvec.check(cpos) != npos) {
                itr->is_end = true;
                return false;
            }

            npos = cpos;
            itr->m_prefix.push_back(itr->m_key[kpos]);
        }

        // Seed the traversal with the node reached by the whole query. The label is
        // the last character already in m_prefix ('\0' for an empty query, where it is unused).
        const char label = itr->m_prefix.empty() ? '\0' : itr->m_prefix.back();
        itr->m_stack.push_back({label, kpos, npos});
    }

    while (!itr->m_stack.empty()) {
        const char label = itr->m_stack.back().label;
        const std::uint64_t kpos = itr->m_stack.back().kpos;
        const std::uint64_t npos = itr->m_stack.back().npos;
        itr->m_stack.pop_back();

        // Rewind the decoded string to this node's depth and write its label.
        if (0 < kpos) {
            itr->m_prefix.resize(kpos);
            itr->m_prefix.back() = label;
        }

        if (m_bcvec.is_leaf(npos)) {
            itr->m_id = npos_to_id(npos);
            // Position 0 is treated as "no tail", matching the guard used in access().
            const std::uint64_t tpos = m_bcvec.link(npos);
            if (tpos != 0) {
                m_tvec.decode(tpos, [&](char c) { itr->m_prefix.push_back(c); });
            }
            return true;
        }

        // Push children in reverse alphabet order so they are popped in forward order.
        const std::uint64_t base = m_bcvec.base(npos);
        for (auto cit = m_table.rbegin(); cit != m_table.rend(); ++cit) {
            const std::uint64_t cpos = base ^ m_table.get_code(*cit);
            if (m_bcvec.check(cpos) == npos) {
                itr->m_stack.push_back({static_cast<char>(*cit), kpos + 1, cpos});
            }
        }

        // An inner node that is itself a key is reported before its children.
        if (m_terms[npos]) {
            itr->m_id = npos_to_id(npos);
            return true;
        }
    }

    itr->is_end = true;
    return false;
}
};
} // namespace xcdat

View file

@ -11,6 +11,7 @@
namespace xcdat {
template <class Strings>
class trie_builder {
public:
struct unit_type {
@ -22,7 +23,7 @@ class trie_builder {
static constexpr std::uint64_t taboo_npos = 1;
static constexpr std::uint64_t free_blocks = 16;
const std::vector<std::string>& m_keys;
const Strings& m_keys;
const std::uint32_t m_l1_bits; // # of bits for L1 layer of DACs
const std::uint64_t m_l1_size;
@ -38,7 +39,7 @@ class trie_builder {
tail_vector::builder m_suffixes;
public:
trie_builder(const std::vector<std::string>& keys, std::uint32_t l1_bits, bool bin_mode)
trie_builder(const Strings& keys, std::uint32_t l1_bits, bool bin_mode)
: m_keys(keys), m_l1_bits(l1_bits), m_l1_size(1ULL << l1_bits), m_bin_mode(bin_mode) {
XCDAT_THROW_IF(m_keys.size() == 0, "The input dataset is empty.");

View file

@ -22,4 +22,9 @@ inline std::uint64_t needed_bits(std::uint64_t x) {
return bit_tools::msb(x) + 1;
}
// Returns the part of `s` that starts at index `i` and runs to the end.
// `String` must provide substr(); the result has the same type as `s`.
template <class String>
inline String get_suffix(const String& s, std::uint64_t i) {
    return s.substr(i);
}
} // namespace xcdat::utils

1
sample/CMakeLists.txt Normal file
View file

@ -0,0 +1 @@
add_executable(sample sample.cpp)

44
sample/sample.cpp Normal file
View file

@ -0,0 +1,44 @@
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
#include <xcdat.hpp>
// Sample program: builds a trie from a small key set and demonstrates
// exact lookup, common-prefix search and predictive search.
int main() {
    std::vector<std::string> keys = {
        "Mac", "MacBook", "MacBook_Air", "MacBook_Pro", "Mac_Pro", "iMac", "Mac_Mini",
    };

    // The dataset must be sorted and unique.
    std::sort(keys.begin(), keys.end());
    keys.erase(std::unique(keys.begin(), keys.end()), keys.end());

    auto trie = xcdat::trie::build(keys);

    std::cout << "Basic operations" << std::endl;
    {
        const auto id = trie.lookup("MacBook_Pro");
        if (id.has_value()) {
            // Print the key together with its id (not the has_value() flag, which is always 1 here).
            std::cout << trie.access(id.value()) << " -> " << id.value() << std::endl;
        } else {
            std::cout << "Not found" << std::endl;
        }
    }

    std::cout << "Common prefix search" << std::endl;
    {
        auto itr = trie.make_prefix_iterator("MacBook_Air");
        while (itr.next()) {
            std::cout << itr.prefix() << " -> " << itr.id() << std::endl;
        }
    }

    std::cout << "Predictive search" << std::endl;
    {
        auto itr = trie.make_predictive_iterator("Mac");
        while (itr.next()) {
            std::cout << itr.prefix() << " -> " << itr.id() << std::endl;
        }
    }

    return 0;
}