Change prefix-based operations into the iterator versions

This commit is contained in:
kampersanda 2017-11-11 21:00:42 +09:00
parent cc965a853e
commit f720e3b039

View file

@ -1,6 +1,7 @@
#ifndef XCDAT_TRIE_HPP_ #ifndef XCDAT_TRIE_HPP_
#define XCDAT_TRIE_HPP_ #define XCDAT_TRIE_HPP_
#include <Trie.hpp>
#include "DacBc.hpp" #include "DacBc.hpp"
#include "FastDacBc.hpp" #include "FastDacBc.hpp"
@ -14,38 +15,185 @@ constexpr auto kNotFound = kIdMax;
template<bool Fast> template<bool Fast>
class Trie { class Trie {
public: public:
using Type = Trie<Fast>; using type = Trie<Fast>;
using BcType = typename std::conditional<Fast, FastDacBc, DacBc>::type; using bc_type = typename std::conditional<Fast, FastDacBc, DacBc>::type;
// Generic constructor. // Generic constructor.
Trie() = default; Trie() = default;
// Reads the dictionary from an std::istream. // Reads the dictionary from an std::istream.
explicit Trie(std::istream& is); explicit Trie(std::istream& is) {
bc_ = bc_type(is);
terminal_flags_ = BitVector(is);
tail_ = Vector<uint8_t>(is);
boundary_flags_ = BitVector(is);
alphabet_ = Vector<uint8_t>(is);
is.read(reinterpret_cast<char*>(table_), 512);
num_keys_ = read_value<size_t>(is);
max_length_ = read_value<size_t>(is);
binary_mode_ = read_value<bool>(is);
}
// Generic destructor. // Generic destructor.
~Trie() = default; ~Trie() = default;
// Lookups the ID of a given key. If the key is not registered, otherwise // Lookups the ID of a given key. If the key is not registered, otherwise
// returns kNotFound. // returns kNotFound.
id_type lookup(const uint8_t* key, size_t length) const; id_type lookup(const uint8_t* key, size_t length) const {
size_t pos = 0;
id_type node_id = 0;
while (!bc_.is_leaf(node_id)) {
if (pos == length) {
return terminal_flags_[node_id] ? to_key_id_(node_id) : kNotFound;
}
const auto child_id = bc_.base(node_id) ^table_[key[pos++]];
if (bc_.check(child_id) != node_id) {
return kNotFound;
}
node_id = child_id;
}
size_t tail_pos = bc_.link(node_id);
if (!match_suffix_(key, length, pos, tail_pos)) {
return kNotFound;
}
return to_key_id_(node_id);
}
// Decodes the key associated with a given ID. The decoded key is appended to // Decodes the key associated with a given ID. The decoded key is appended to
// 'ret' and its length is returned. // 'ret' and its length is returned.
size_t access(id_type id, std::vector<uint8_t>& ret) const; size_t access(id_type id, std::vector<uint8_t>& ret) const {
if (num_keys_ <= id) {
return 0;
}
// Returns the IDs of keys included as prefixes of a given key. The IDs are auto orig_size = ret.size();
// appended to 'ids' and the number is returned. By using 'limit', you can ret.reserve(orig_size + max_length_);
// restrict the maximum number of returned IDs.
size_t common_prefix_lookup(const uint8_t* key, size_t length,
std::vector<id_type>& ids,
size_t limit = std::numeric_limits<size_t>::max()) const;
// Returns the IDs of keys starting with a given key. The IDs are appended to auto node_id = to_node_id_(id);
// 'ids' and the number is returned. By using 'limit', you can restrict the auto tail_pos = bc_.is_leaf(node_id) ? bc_.link(node_id) : kNotFound;
// maximum number of returned IDs.
size_t predictive_lookup(const uint8_t* key, size_t length, std::vector<id_type>& ids, while (node_id) {
size_t limit = std::numeric_limits<size_t>::max()) const; const auto parent_id = bc_.check(node_id);
ret.push_back(edge_(parent_id, node_id));
node_id = parent_id;
}
std::reverse(std::begin(ret) + orig_size, std::end(ret));
if (tail_pos != 0 && tail_pos != kNotFound) {
if (binary_mode_) {
do {
ret.push_back(tail_[tail_pos]);
} while (!boundary_flags_[tail_pos++]);
} else {
do {
ret.push_back(tail_[tail_pos++]);
} while (tail_[tail_pos]);
}
}
return ret.size() - orig_size;
}
// Iterator class for enumerating the keys and IDs included as prefixes of a
// given key, that is, supporting so-called common prefix lookup. It is
// created by using make_prefix_iterator().
class PrefixIterator {
public:
PrefixIterator() = default;
// Scans the next key and ID
bool next() {
return trie_ != nullptr && trie_->next_prefix_(this);
}
// Gets the key
std::pair<const uint8_t*, size_t> key() const {
return {key_, pos_};
};
// Gets the ID
id_type id() const {
return trie_->to_key_id_(node_id_);
}
private:
const type* trie_ {};
const uint8_t* key_ {};
const size_t length_ {};
size_t pos_ {0};
id_type node_id_ {0};
bool begin_flag_ {true};
bool end_flag_ {false};
PrefixIterator(const type* trie, const uint8_t* key, size_t length)
: trie_{trie}, key_{key}, length_{length} {}
friend class Trie;
};
// Makes PrefixIterator from a given key
PrefixIterator make_prefix_iterator(const uint8_t* key, size_t length) const {
return PrefixIterator{this, key, length};
}
// Iterator class for enumerating the keys and IDs starting with prefixes of a
// given key, that is, supporting so-called predictive lookup. It is created
// by using make_predictive_iterator().
class PredictiveIterator {
public:
PredictiveIterator() = default;
// Scans the next key and ID
bool next() {
return trie_ != nullptr && trie_->next_predictive_(this);
}
// Gets the key
std::pair<const uint8_t*, size_t> key() const {
return {buf_.data(), buf_.size()};
};
// Gets the ID
id_type id() const {
return id_;
}
private:
const type* trie_ {};
const uint8_t* key_ {};
const size_t length_ {};
bool begin_flag_ {true};
bool end_flag_ {false};
struct entry {
id_type node_id;
size_t depth;
uint8_t c;
};
std::vector<entry> stack_ {};
std::vector<uint8_t> buf_ {};
id_type id_ {};
PredictiveIterator(const type* trie, const uint8_t* key, size_t length)
: trie_{trie}, key_{key}, length_{length} {
buf_.reserve(trie->max_length());
}
friend class Trie;
};
PredictiveIterator
make_predictive_iterator(const uint8_t* key, size_t length) const {
return {this, key, length};
}
// Gets the number of registered keys in the dictionary // Gets the number of registered keys in the dictionary
size_t num_keys() const { size_t num_keys() const {
@ -83,13 +231,50 @@ public:
} }
// Computes the output dictionary size in bytes. // Computes the output dictionary size in bytes.
size_t size_in_bytes() const; size_t size_in_bytes() const {
size_t ret = 0;
ret += bc_.size_in_bytes();
ret += terminal_flags_.size_in_bytes();
ret += tail_.size_in_bytes();
ret += boundary_flags_.size_in_bytes();
ret += alphabet_.size_in_bytes();
ret += sizeof(table_);
ret += sizeof(num_keys_);
ret += sizeof(max_length_);
ret += sizeof(binary_mode_);
return ret;
}
// Reports the dictionary statistics into an ostream. // Reports the dictionary statistics into an ostream.
void show_stat(std::ostream& os) const; void show_stat(std::ostream& os) const {
const auto total_size = size_in_bytes();
os << "basic statistics of xcdat::Trie" << std::endl;
show_size("\tnum keys: ", num_keys(), os);
show_size("\talphabet size: ", alphabet_size(), os);
show_size("\tnum nodes: ", num_nodes(), os);
show_size("\tnum used nodes:", num_used_nodes(), os);
show_size("\tnum free nodes:", num_free_nodes(), os);
show_size("\tsize in bytes: ", size_in_bytes(), os);
os << "member size statistics of xcdat::Trie" << std::endl;
show_size_ratio("\tbc: ", bc_.size_in_bytes(), total_size, os);
show_size_ratio("\tterminal_flags:", terminal_flags_.size_in_bytes(), total_size, os);
show_size_ratio("\ttail: ", tail_.size_in_bytes(), total_size, os);
show_size_ratio("\tboundary_flags:", boundary_flags_.size_in_bytes(), total_size, os);
bc_.show_stat(os);
}
// Writes the dictionary into an ostream. // Writes the dictionary into an ostream.
void write(std::ostream& os) const; void write(std::ostream& os) const {
bc_.write(os);
terminal_flags_.write(os);
tail_.write(os);
boundary_flags_.write(os);
alphabet_.write(os);
os.write(reinterpret_cast<const char*>(table_), 512);
write_value(num_keys_, os);
write_value(max_length_, os);
write_value(binary_mode_, os);
}
// Disallows copy and assignment. // Disallows copy and assignment.
Trie(const Trie&) = delete; Trie(const Trie&) = delete;
@ -99,7 +284,7 @@ public:
Trie& operator=(Trie&&) noexcept = default; Trie& operator=(Trie&&) noexcept = default;
private: private:
BcType bc_ {}; bc_type bc_ {};
BitVector terminal_flags_ {}; BitVector terminal_flags_ {};
Vector<uint8_t> tail_ {}; Vector<uint8_t> tail_ {};
BitVector boundary_flags_ {}; // used if binary_mode_ == true BitVector boundary_flags_ {}; // used if binary_mode_ == true
@ -113,16 +298,15 @@ private:
id_type to_key_id_(id_type node_id) const { id_type to_key_id_(id_type node_id) const {
return terminal_flags_.rank(node_id); return terminal_flags_.rank(node_id);
}; };
id_type to_node_id_(id_type string_id) const { id_type to_node_id_(id_type string_id) const {
return terminal_flags_.select(string_id); return terminal_flags_.select(string_id);
}; };
uint8_t edge_(id_type node_id, id_type child_id) const { uint8_t edge_(id_type node_id, id_type child_id) const {
return table_[static_cast<uint8_t>(bc_.base(node_id) ^ child_id) + 256]; return table_[static_cast<uint8_t>(bc_.base(node_id) ^ child_id) + 256];
} }
bool match_(const uint8_t* key, size_t length, size_t pos, size_t tail_pos) const { bool match_suffix_(const uint8_t* key, size_t length, size_t pos,
size_t tail_pos) const {
assert(pos <= length); assert(pos <= length);
if (pos == length) { if (pos == length) {
@ -153,266 +337,166 @@ private:
} }
} }
bool prefix_match_(const uint8_t* key, size_t length, void extract_suffix_(size_t tail_pos, std::vector<uint8_t>& dec) const {
size_t pos, size_t tail_pos) const { if (binary_mode_) {
assert(pos < length); if (tail_pos != 0) {
do {
dec.push_back(tail_[tail_pos]);
} while (!boundary_flags_[tail_pos++]);
}
} else {
while (tail_[tail_pos] != '\0') {
dec.push_back(tail_[tail_pos]);
++tail_pos;
}
}
}
if (tail_pos == 0) { bool next_prefix_(PrefixIterator* it) const {
if (it->end_flag_) {
return false; return false;
} }
if (binary_mode_) { if (it->begin_flag_) {
do { it->begin_flag_ = false;
if (key[pos] != tail_[tail_pos]) { if (terminal_flags_[it->node_id_]) {
return false; return true;
} }
++pos;
if (boundary_flags_[tail_pos]) {
return pos == length;
}
++tail_pos;
} while (pos < length);
} else {
do {
if (key[pos] != tail_[tail_pos] || !tail_[tail_pos]) {
return false;
}
++pos;
++tail_pos;
} while (pos < length);
} }
while (!bc_.is_leaf(it->node_id_)) {
id_type child_id = bc_.base(it->node_id_) ^table_[it->key_[it->pos_++]];
if (bc_.check(child_id) != it->node_id_) {
it->end_flag_ = true;
it->node_id_ = kNotFound;
return false;
}
it->node_id_ = child_id;
if (!bc_.is_leaf(it->node_id_) && terminal_flags_[it->node_id_]) {
return true;
}
}
it->end_flag_ = true;
size_t tail_pos = bc_.link(it->node_id_);
if (!match_suffix_(it->key_, it->length_, it->pos_, tail_pos)) {
it->node_id_ = kNotFound;
return false;
}
it->pos_ = it->length_;
return true; return true;
} }
bool next_predictive_(PredictiveIterator* it) const {
if (it->end_flag_) {
return false;
}
if (it->begin_flag_) {
it->begin_flag_ = false;
id_type node_id = 0;
size_t pos = 0;
for (; pos < it->length_; ++pos) {
if (bc_.is_leaf(node_id)) {
it->end_flag_ = true;
size_t tail_pos = bc_.link(node_id);
if (tail_pos == 0) {
return false;
}
if (binary_mode_) {
do {
if (it->key_[pos] != tail_[tail_pos]) {
return false;
}
it->buf_.push_back(it->key_[pos++]);
if (boundary_flags_[tail_pos]) {
if (pos == it->length_) {
it->id_ = to_key_id_(node_id);
return true;
}
return false;
}
++tail_pos;
} while (pos < it->length_);
} else {
do {
if (it->key_[pos] != tail_[tail_pos] || !tail_[tail_pos]) {
return false;
}
it->buf_.push_back(it->key_[pos++]);
++tail_pos;
} while (pos < it->length_);
}
it->id_ = to_key_id_(node_id);
extract_suffix_(tail_pos, it->buf_);
return true;
}
id_type child_id = bc_.base(node_id) ^table_[it->key_[pos]];
if (bc_.check(child_id) != node_id) {
it->end_flag_ = true;
return false;
}
node_id = child_id;
it->buf_.push_back(it->key_[pos]);
}
if (!it->buf_.empty()) {
it->stack_.push_back({node_id, pos, it->buf_.back()});
} else {
it->stack_.push_back({node_id, pos});
}
}
while (!it->stack_.empty()) {
id_type node_id = it->stack_.back().node_id;
size_t depth = it->stack_.back().depth;
uint8_t c = it->stack_.back().c;
it->stack_.pop_back();
if (0 < depth) {
it->buf_.resize(depth);
it->buf_.back() = c;
}
if (bc_.is_leaf(node_id)) {
it->id_ = to_key_id_(node_id);
extract_suffix_(bc_.link(node_id), it->buf_);
return true;
}
const id_type base = bc_.base(node_id);
// For lex sort
for (int i = static_cast<int>(alphabet_.size()) - 1; i >= 0; --i) {
const id_type child_id = base ^table_[alphabet_[i]];
if (bc_.check(child_id) == node_id) {
it->stack_.push_back({child_id, depth + 1, alphabet_[i]});
}
}
if (terminal_flags_[node_id]) {
it->id_ = to_key_id_(node_id);
return true;
}
}
it->end_flag_ = true;
return false;
}
friend class TrieBuilder; friend class TrieBuilder;
}; };
template<bool Fast>
Trie<Fast>::Trie(std::istream& is) {
bc_ = BcType(is);
terminal_flags_ = BitVector(is);
tail_ = Vector<uint8_t>(is);
boundary_flags_ = BitVector(is);
alphabet_ = Vector<uint8_t>(is);
is.read(reinterpret_cast<char*>(table_), 512);
num_keys_ = read_value<size_t>(is);
max_length_ = read_value<size_t>(is);
binary_mode_ = read_value<bool>(is);
}
template<bool Fast>
id_type Trie<Fast>::lookup(const uint8_t* key, size_t length) const {
size_t pos = 0;
id_type node_id = 0;
while (!bc_.is_leaf(node_id)) {
if (pos == length) {
return terminal_flags_[node_id] ? to_key_id_(node_id) : kNotFound;
}
const auto child_id = bc_.base(node_id) ^table_[key[pos++]];
if (bc_.check(child_id) != node_id) {
return kNotFound;
}
node_id = child_id;
}
size_t tail_pos = bc_.link(node_id);
if (!match_(key, length, pos, tail_pos)) {
return kNotFound;
}
return to_key_id_(node_id);
}
template<bool Fast>
size_t Trie<Fast>::access(id_type id, std::vector<uint8_t>& ret) const {
if (num_keys_ <= id) {
return 0;
}
auto orig_size = ret.size();
ret.reserve(orig_size + max_length_);
auto node_id = to_node_id_(id);
auto tail_pos = bc_.is_leaf(node_id) ? bc_.link(node_id) : kNotFound;
while (node_id) {
const auto parent_id = bc_.check(node_id);
ret.push_back(edge_(parent_id, node_id));
node_id = parent_id;
}
std::reverse(std::begin(ret) + orig_size, std::end(ret));
if (tail_pos != 0 && tail_pos != kNotFound) {
if (binary_mode_) {
do {
ret.push_back(tail_[tail_pos]);
} while (!boundary_flags_[tail_pos++]);
} else {
do {
ret.push_back(tail_[tail_pos++]);
} while (tail_[tail_pos]);
}
}
return ret.size() - orig_size;
}
template<bool Fast>
size_t Trie<Fast>::common_prefix_lookup(const uint8_t* key, size_t length,
std::vector<id_type>& ids,
size_t limit) const {
if (limit == 0) {
return 0;
}
size_t pos = 0, count = 0;
id_type node_id = 0;
while (!bc_.is_leaf(node_id)) {
if (terminal_flags_[node_id]) {
ids.push_back(to_key_id_(node_id));
if (limit <= ++count) {
return count;
}
}
if (pos == length) {
return count;
}
const auto child_id = bc_.base(node_id) ^table_[key[pos++]];
if (bc_.check(child_id) != node_id) {
return count;
}
node_id = child_id;
}
size_t tail_pos = bc_.link(node_id);
if (match_(key, length, pos, tail_pos)) {
ids.push_back(to_key_id_(node_id));
++count;
}
return count;
}
template<bool Fast>
size_t Trie<Fast>::predictive_lookup(const uint8_t* key, size_t length,
std::vector<id_type>& ids,
size_t limit) const {
if (limit == 0) {
return 0;
}
size_t pos = 0;
id_type node_id = 0;
for (; pos < length; ++pos) {
if (bc_.is_leaf(node_id)) {
size_t tail_pos = bc_.link(node_id);
if (!prefix_match_(key, length, pos, tail_pos)) {
return 0;
}
ids.push_back(to_key_id_(node_id));
return 1;
}
const auto child_id = bc_.base(node_id) ^table_[key[pos]];
if (bc_.check(child_id) != node_id) {
return 0;
}
node_id = child_id;
}
size_t count = 0;
std::vector<std::pair<id_type, size_t>> stack;
stack.emplace_back(std::make_pair(node_id, pos));
while (!stack.empty()) {
node_id = stack.back().first;
pos = stack.back().second;
stack.pop_back();
if (bc_.is_leaf(node_id)) {
ids.push_back(to_key_id_(node_id));
if (limit <= ++count) {
break;
}
} else {
if (terminal_flags_[node_id]) {
ids.push_back(to_key_id_(node_id));
if (limit <= ++count) {
break;
}
}
const auto base = bc_.base(node_id);
for (const auto label : alphabet_) {
const auto child_id = base ^table_[label];
if (bc_.check(child_id) == node_id) {
stack.push_back({child_id, pos + 1});
}
}
}
}
return count;
}
template<bool Fast>
size_t Trie<Fast>::size_in_bytes() const {
size_t ret = 0;
ret += bc_.size_in_bytes();
ret += terminal_flags_.size_in_bytes();
ret += tail_.size_in_bytes();
ret += boundary_flags_.size_in_bytes();
ret += alphabet_.size_in_bytes();
ret += sizeof(table_);
ret += sizeof(num_keys_);
ret += sizeof(max_length_);
ret += sizeof(binary_mode_);
return ret;
}
template<bool Fast>
void Trie<Fast>::show_stat(std::ostream& os) const {
const auto total_size = size_in_bytes();
os << "basic statistics of xcdat::Trie" << std::endl;
show_size("\tnum keys: ", num_keys(), os);
show_size("\talphabet size: ", alphabet_size(), os);
show_size("\tnum nodes: ", num_nodes(), os);
show_size("\tnum used nodes:", num_used_nodes(), os);
show_size("\tnum free nodes:", num_free_nodes(), os);
show_size("\tsize in bytes: ", size_in_bytes(), os);
os << "member size statistics of xcdat::Trie" << std::endl;
show_size_ratio("\tbc: ", bc_.size_in_bytes(), total_size, os);
show_size_ratio("\tterminal_flags:", terminal_flags_.size_in_bytes(), total_size, os);
show_size_ratio("\ttail: ", tail_.size_in_bytes(), total_size, os);
show_size_ratio("\tboundary_flags:", boundary_flags_.size_in_bytes(), total_size, os);
bc_.show_stat(os);
}
template<bool Fast>
void Trie<Fast>::write(std::ostream& os) const {
bc_.write(os);
terminal_flags_.write(os);
tail_.write(os);
boundary_flags_.write(os);
alphabet_.write(os);
os.write(reinterpret_cast<const char*>(table_), 512);
write_value(num_keys_, os);
write_value(max_length_, os);
write_value(binary_mode_, os);
}
} //namespace - xcdat } //namespace - xcdat
#endif //XCDAT_TRIE_HPP_ #endif //XCDAT_TRIE_HPP_