diff --git a/.gitignore b/.gitignore index 690aeaf..965d9d3 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,7 @@ data fasttext result +build/ +python/fasttext.egg-info/ +python/fasttext_pybind.cpython-37m-x86_64-linux-gnu.so + diff --git a/Makefile b/Makefile index ea5c78f..12c4f5f 100644 --- a/Makefile +++ b/Makefile @@ -7,11 +7,12 @@ # CXX = c++ -CXXFLAGS = -pthread -std=c++0x -march=native +CXXFLAGS = -Wall -pthread -std=c++14 -march=native -ffast-math -Wsuggest-final-methods -Wsuggest-override -Wodr -flto -ftree-loop-linear -floop-strip-mine -floop-block + OBJS = args.o matrix.o dictionary.o loss.o productquantizer.o densematrix.o quantmatrix.o vector.o model.o utils.o meter.o fasttext.o INCLUDES = -I. -opt: CXXFLAGS += -O3 -funroll-loops +opt: CXXFLAGS += -O3 -funroll-loops -DNDEBUG opt: fasttext coverage: CXXFLAGS += -O0 -fno-inline -fprofile-arcs --coverage diff --git a/python/fastText/FastText.py b/python/fastText/FastText.py index afa9e85..9d530b6 100644 --- a/python/fastText/FastText.py +++ b/python/fastText/FastText.py @@ -139,9 +139,26 @@ def check(entry): text = check(text) predictions = self.f.predict(text, k, threshold, on_unicode_error) probs, labels = zip(*predictions) - return labels, np.array(probs, copy=False) + def predict_all(self, text, on_unicode_error='strict'): + def check(entry): + if entry.find('\n') != -1: + raise ValueError( + "predict processes one line at a time (remove \'\\n\')" + ) + entry += "\n" + return entry + + if type(text) is list: + text = [check(entry) for entry in text] + predictions = self.f.multilinePredictAll(text) + return np.array(predictions, dtype=float) + else: + text = check(text) + probs = self.f.predictAll(text) + return np.array(probs, copy=False) + def get_input_matrix(self): """ Get a copy of the full input matrix of a Model. This only diff --git a/python/fastText/pybind/fasttext_pybind.cc b/python/fastText/pybind/fasttext_pybind.cc index 76c25b8..dba074f 100644 --- a/python/fastText/pybind/fasttext_pybind.cc +++ b/python/fastText/pybind/fasttext_pybind.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -39,7 +40,7 @@ py::str castToPythonString(const std::string& s, const char* onUnicodeError) { std::pair, std::vector> getLineText( fasttext::FastText& m, - const std::string text, + const std::string& text, const char* onUnicodeError) { std::shared_ptr d = m.getDictionary(); std::stringstream ioss(text); @@ -60,7 +61,7 @@ std::pair, std::vector> getLineText( if (token == fasttext::Dictionary::EOS) break; } - return std::pair, std::vector>(words, labels); + return {std::move(words), std::move(labels)}; } PYBIND11_MODULE(fasttext_pybind, m) { @@ -159,13 +160,13 @@ PYBIND11_MODULE(fasttext_pybind, m) { }) .def( "loadModel", - [](fasttext::FastText& m, std::string s) { m.loadModel(s); }) + [](fasttext::FastText& m, const std::string& s) { m.loadModel(s); }) .def( "saveModel", - [](fasttext::FastText& m, std::string s) { m.saveModel(s); }) + [](fasttext::FastText& m, const std::string& s) { m.saveModel(s); }) .def( "test", - [](fasttext::FastText& m, const std::string filename, int32_t k) { + [](fasttext::FastText& m, const std::string& filename, int32_t k) { std::ifstream ifs(filename); if (!ifs.is_open()) { throw std::invalid_argument("Test file cannot be opened!"); @@ -180,13 +181,13 @@ PYBIND11_MODULE(fasttext_pybind, m) { "getSentenceVector", [](fasttext::FastText& m, fasttext::Vector& v, - const std::string text) { + const std::string& text) { std::stringstream ioss(text); m.getSentenceVector(ioss, v); }) .def( "tokenize", - [](fasttext::FastText& m, const std::string text) { + [](fasttext::FastText& m, const std::string& text) { std::vector text_split; std::shared_ptr d = m.getDictionary(); std::stringstream ioss(text); @@ -202,53 +203,57 @@ PYBIND11_MODULE(fasttext_pybind, m) { .def( "multilineGetLine", [](fasttext::FastText& m, - const std::vector lines, + const std::vector& lines, const char* onUnicodeError) { std::shared_ptr d = m.getDictionary(); + std::vector> all_words; + all_words.reserve(lines.size()); std::vector> all_labels; + all_labels.reserve(lines.size()); + for (const auto& text : lines) { auto pair = getLineText(m, text, onUnicodeError); - all_words.push_back(pair.first); - all_labels.push_back(pair.second); + all_words.push_back(std::move(pair.first)); + all_labels.push_back(std::move(pair.second)); } return std::pair< std::vector>, - std::vector>>(all_words, all_labels); + std::vector>>(std::move(all_words), std::move(all_labels)); }) .def( "getVocab", [](fasttext::FastText& m, const char* onUnicodeError) { py::str s; - std::vector vocab_list; - std::vector vocab_freq; std::shared_ptr d = m.getDictionary(); - vocab_freq = d->getCounts(fasttext::entry_type::word); + std::vector vocab_freq = d->getCounts(fasttext::entry_type::word); + std::vector vocab_list; + vocab_list.reserve(vocab_freq.size()); for (int32_t i = 0; i < vocab_freq.size(); i++) { - vocab_list.push_back( - castToPythonString(d->getWord(i), onUnicodeError)); + vocab_list.push_back(castToPythonString(d->getWord(i), onUnicodeError)); } return std::pair, std::vector>( - vocab_list, vocab_freq); + std::move(vocab_list), std::move(vocab_freq)); }) .def( "getLabels", [](fasttext::FastText& m, const char* onUnicodeError) { - std::vector labels_list; - std::vector labels_freq; std::shared_ptr d = m.getDictionary(); - labels_freq = d->getCounts(fasttext::entry_type::label); + std::vector labels_freq = d->getCounts(fasttext::entry_type::label); + std::vector labels_list; + labels_list.reserve(labels_freq.size()); + for (int32_t i = 0; i < labels_freq.size(); i++) { labels_list.push_back( castToPythonString(d->getLabel(i), onUnicodeError)); } return std::pair, std::vector>( - labels_list, labels_freq); + std::move(labels_list), std::move(labels_freq)); }) .def( "quantize", [](fasttext::FastText& m, - const std::string input, + const std::string& input, bool qout, int32_t cutoff, bool retrain, @@ -276,7 +281,7 @@ PYBIND11_MODULE(fasttext_pybind, m) { // NOTE: text needs to end in a newline // to exactly mimic the behavior of the cli [](fasttext::FastText& m, - const std::string text, + const std::string& text, int32_t k, fasttext::real threshold, const char* onUnicodeError) { @@ -284,18 +289,40 @@ PYBIND11_MODULE(fasttext_pybind, m) { std::vector> predictions; m.predictLine(ioss, predictions, k, threshold); - std::vector> - transformedPredictions; + std::vector> transformedPredictions; + transformedPredictions.reserve(predictions.size()); for (const auto& prediction : predictions) { - transformedPredictions.push_back(std::make_pair( + transformedPredictions.emplace_back( prediction.first, - castToPythonString(prediction.second, onUnicodeError))); + castToPythonString(prediction.second, onUnicodeError) + ); } return transformedPredictions; }) .def( + "predictAll", + // NOTE: text needs to end in a newline + // to exactly mimic the behavior of the cli + [](fasttext::FastText& m, const std::string& text) { + std::stringstream ioss(text); + std::vector> predictions; + + m.predictLine(ioss, predictions); + std::sort(std::begin(predictions), std::end(predictions), [](const auto& x, const auto &y) { + return x.second < y.second; + }); + + std::vector transformedPredictions; + transformedPredictions.reserve(predictions.size()); + + std::transform(std::begin(predictions), std::end(predictions), std::back_inserter(transformedPredictions), [](const auto& x) { + return x.first; + }); + return transformedPredictions; + }) + .def( "multilinePredict", // NOTE: text needs to end in a newline // to exactly mimic the behavior of the cli @@ -306,26 +333,56 @@ PYBIND11_MODULE(fasttext_pybind, m) { const char* onUnicodeError) { std::vector>> allPredictions; + allPredictions.reserve(lines.size()); std::vector> predictions; for (const std::string& text : lines) { - std::stringstream ioss(text); + std::stringstream ioss(text); /// stringstream is slow m.predictLine(ioss, predictions, k, threshold); - std::vector> - transformedPredictions; + std::vector> transformedPredictions; + transformedPredictions.reserve(predictions.size()); for (const auto& prediction : predictions) { - transformedPredictions.push_back(std::make_pair( + transformedPredictions.emplace_back( prediction.first, - castToPythonString(prediction.second, onUnicodeError))); + castToPythonString(prediction.second, onUnicodeError) + ); } - allPredictions.push_back(transformedPredictions); + allPredictions.push_back(std::move(transformedPredictions)); + } + return allPredictions; + }) + .def( + "multilinePredictAll", + // NOTE: text needs to end in a newline + // to exactly mimic the behavior of the cli + [](fasttext::FastText& m, const std::vector& lines) { + std::vector> allPredictions; + + allPredictions.reserve(lines.size()); + std::vector> predictions; + for (const std::string& text : lines) { + std::stringstream ioss(text); /// stringstream is slow + m.predictLine(ioss, predictions); + + std::sort(std::begin(predictions), std::end(predictions), [](const auto& x, const auto &y) { + return x.second < y.second; + }); + + std::vector transformedPredictions; + transformedPredictions.reserve(predictions.size()); + + std::transform(std::begin(predictions), std::end(predictions), std::back_inserter(transformedPredictions), [](const auto& x) { + return x.first; + }); + + allPredictions.push_back(std::move(transformedPredictions)); } return allPredictions; }) .def( "testLabel", [](fasttext::FastText& m, - const std::string filename, + const std::string& filename, int32_t k, fasttext::real threshold) { std::ifstream ifs(filename); @@ -335,7 +392,7 @@ PYBIND11_MODULE(fasttext_pybind, m) { fasttext::Meter meter; m.test(ifs, k, threshold, meter); std::shared_ptr d = m.getDictionary(); - std::unordered_map returnedValue; + std::unordered_map returnedValue(d->nlabels()); for (int32_t i = 0; i < d->nlabels(); i++) { returnedValue[d->getLabel(i)] = py::dict( "precision"_a = meter.precision(i), @@ -347,12 +404,12 @@ PYBIND11_MODULE(fasttext_pybind, m) { }) .def( "getWordId", - [](fasttext::FastText& m, const std::string word) { + [](fasttext::FastText& m, const std::string& word) { return m.getWordId(word); }) .def( "getSubwordId", - [](fasttext::FastText& m, const std::string word) { + [](fasttext::FastText& m, const std::string& word) { return m.getSubwordId(word); }) .def( @@ -364,25 +421,26 @@ PYBIND11_MODULE(fasttext_pybind, m) { "getWordVector", [](fasttext::FastText& m, fasttext::Vector& vec, - const std::string word) { m.getWordVector(vec, word); }) + const std::string& word) { m.getWordVector(vec, word); }) .def( "getSubwords", [](fasttext::FastText& m, - const std::string word, + const std::string& word, const char* onUnicodeError) { std::vector subwords; std::vector ngrams; std::shared_ptr d = m.getDictionary(); d->getSubwords(word, ngrams, subwords); + std::vector transformedSubwords; + transformedSubwords.reserve(subwords.size()); for (const auto& subword : subwords) { - transformedSubwords.push_back( - castToPythonString(subword, onUnicodeError)); + transformedSubwords.push_back(castToPythonString(subword, onUnicodeError)); } return std::pair, std::vector>( - transformedSubwords, ngrams); + std::move(transformedSubwords), std::move(ngrams)); }) .def("isQuant", [](fasttext::FastText& m) { return m.isQuant(); }); } diff --git a/setup.py b/setup.py index d8061c3..215dbe6 100644 --- a/setup.py +++ b/setup.py @@ -60,6 +60,10 @@ def __str__(self): map(lambda x: str(os.path.join(FASTTEXT_SRC, x)), fasttext_src_cc) ) +extra_compile_args = " -march=native -ffast-math -Wsuggest-final-methods" \ + " -Wsuggest-override -Wodr -flto -ftree-loop-linear" \ + " -floop-strip-mine -floop-block " + ext_modules = [ Extension( str('fasttext_pybind'), @@ -74,8 +78,8 @@ def __str__(self): FASTTEXT_SRC, ], language='c++', - extra_compile_args=["-O0 -fno-inline -fprofile-arcs -pthread -march=native" if coverage else - "-O3 -funroll-loops -pthread -march=native"], + extra_compile_args=[("-O0 -fno-inline -fprofile-arcs -pthread -march=native" if coverage else + "-O3 -funroll-loops -pthread -march=native") + extra_compile_args], ), ] @@ -100,6 +104,7 @@ def cpp_flag(compiler): """Return the -std=c++[0x/11/14] compiler flag. The c++14 is preferred over c++0x/11 (when it is available). """ + return '-std=c++14' standards = ['-std=c++14', '-std=c++11', '-std=c++0x'] for standard in standards: if has_flag(compiler, [standard]): @@ -134,6 +139,11 @@ def build_extensions(self): ct = self.compiler.compiler_type opts = self.c_opts.get(ct, []) extra_link_args = [] + self.c_opts['unix'] += [ + "-flto", "-march=native", "-ffast-math", "-Wsuggest-final-methods", + "-Wsuggest-override", "-Wodr", "-ftree-loop-linear", + "-floop-strip-mine", "-floop-block", "-O3", "-DNDEBUG", + ] if coverage: coverage_option = '--coverage' diff --git a/src/args.cc b/src/args.cc index cc1af1b..29de668 100644 --- a/src/args.cc +++ b/src/args.cc @@ -175,7 +175,7 @@ void Args::parseArgs(const std::vector& args) { printHelp(); exit(EXIT_FAILURE); } - } catch (std::out_of_range) { + } catch (std::out_of_range&) { std::cerr << args[ai] << " is missing an argument" << std::endl; printHelp(); exit(EXIT_FAILURE); diff --git a/src/densematrix.h b/src/densematrix.h index 1296779..be2d024 100644 --- a/src/densematrix.h +++ b/src/densematrix.h @@ -32,7 +32,7 @@ class DenseMatrix : public Matrix { DenseMatrix(DenseMatrix&&) noexcept; DenseMatrix& operator=(const DenseMatrix&) = delete; DenseMatrix& operator=(DenseMatrix&&) = delete; - virtual ~DenseMatrix() noexcept override = default; + virtual ~DenseMatrix() noexcept final override = default; inline real* data() { return data_.data(); @@ -64,12 +64,12 @@ class DenseMatrix : public Matrix { real l2NormRow(int64_t i) const; void l2NormRow(Vector& norms) const; - real dotRow(const Vector&, int64_t) const override; - void addVectorToRow(const Vector&, int64_t, real) override; + real dotRow(const Vector&, int64_t) const override final; + void addVectorToRow(const Vector&, int64_t, real) override final; void addRowToVector(Vector& x, int32_t i) const override; - void addRowToVector(Vector& x, int32_t i, real a) const override; - void save(std::ostream&) const override; - void load(std::istream&) override; - void dump(std::ostream&) const override; + void addRowToVector(Vector& x, int32_t i, real a) const override final; + void save(std::ostream&) const override final; + void load(std::istream&) override final; + void dump(std::ostream&) const override final; }; } // namespace fasttext diff --git a/src/dictionary.cc b/src/dictionary.cc index cb396cd..ea81700 100644 --- a/src/dictionary.cc +++ b/src/dictionary.cc @@ -161,23 +161,27 @@ std::string Dictionary::getWord(int32_t id) const { // using signed char, we fixed the hash function to make models // compatible whatever compiler is used. uint32_t Dictionary::hash(const std::string& str) const { - uint32_t h = 2166136261; - for (size_t i = 0; i < str.size(); i++) { - h = h ^ uint32_t(int8_t(str[i])); - h = h * 16777619; + uint32_t h = 2166136261u; + for (auto x : str) { + h = h ^ uint32_t(int8_t(x)); // TODO: oops + h = h * 16777619u; } return h; } + void Dictionary::computeSubwords( const std::string& word, std::vector& ngrams, std::vector* substrings) const { + + ngrams.reserve(3ul * word.size()); for (size_t i = 0; i < word.size(); i++) { std::string ngram; if ((word[i] & 0xC0) == 0x80) { continue; } + for (size_t j = i, n = 1; j < word.size() && n <= args_->maxn; n++) { ngram.push_back(word[j++]); while (j < word.size() && (word[j] & 0xC0) == 0x80) { @@ -301,11 +305,13 @@ void Dictionary::initTableDiscard() { std::vector Dictionary::getCounts(entry_type type) const { std::vector counts; - for (auto& w : words_) { + // counts.reserve(words_.size()); + for (const auto& w : words_) { if (w.type == type) { counts.push_back(w.count); } } + // counts.shrink_to_fit(); return counts; } @@ -489,11 +495,13 @@ void Dictionary::init() { void Dictionary::prune(std::vector& idx) { std::vector words, ngrams; - for (auto it = idx.cbegin(); it != idx.cend(); ++it) { - if (*it < nwords_) { - words.push_back(*it); + words.reserve(idx.size()); + ngrams.reserve(idx.size()); + for (int32_t x : idx) { + if (x < nwords_) { + words.push_back(x); } else { - ngrams.push_back(*it); + ngrams.push_back(x); } } std::sort(words.begin(), words.end()); @@ -528,12 +536,12 @@ void Dictionary::prune(std::vector& idx) { void Dictionary::dump(std::ostream& out) const { out << words_.size() << std::endl; - for (auto it : words_) { - std::string entryType = "word"; + for (const auto& it : words_) { if (it.type == entry_type::label) { - entryType = "label"; + out << it.word << " " << it.count << " " << "label" << std::endl; + } else { + out << it.word << " " << it.count << " " << "word" << std::endl; } - out << it.word << " " << it.count << " " << entryType << std::endl; } } diff --git a/src/fasttext.cc b/src/fasttext.cc index 2fb9247..454dab2 100644 --- a/src/fasttext.cc +++ b/src/fasttext.cc @@ -89,9 +89,10 @@ int32_t FastText::getSubwordId(const std::string& subword) const { void FastText::getWordVector(Vector& vec, const std::string& word) const { const std::vector& ngrams = dict_->getSubwords(word); vec.zero(); - for (int i = 0; i < ngrams.size(); i++) { - addInputVector(vec, ngrams[i]); + for (int32_t ngram : ngrams) { + addInputVector(vec, ngram); } + if (ngrams.size() > 0) { vec.mul(1.0 / ngrams.size()); } @@ -121,7 +122,6 @@ void FastText::saveVectors(const std::string& filename) { getWordVector(vec, word); ofs << word << " " << vec << std::endl; } - ofs.close(); } void FastText::saveVectors() { @@ -449,6 +449,20 @@ void FastText::predict( model_->predict(words, k, threshold, predictions, state); } +void FastText::predict(const std::vector& words, Predictions& predictions) const { + if (words.empty()) { + return; + } + Model::State state(args_->dim, dict_->nlabels(), 0); + predictions.reserve(dict_->nlabels()); + if (args_->model != model_name::sup) { + throw std::invalid_argument("Model needs to be supervised for prediction!"); + } + model_->predict(words, predictions, state); +} + + + bool FastText::predictLine( std::istream& in, std::vector>& predictions, @@ -463,15 +477,35 @@ bool FastText::predictLine( dict_->getLine(in, words, labels); Predictions linePredictions; predict(k, words, linePredictions, threshold); + predictions.reserve(linePredictions.size()); for (const auto& p : linePredictions) { - predictions.push_back( - std::make_pair(std::exp(p.first), dict_->getLabel(p.second))); + predictions.emplace_back(std::exp(p.first), dict_->getLabel(p.second)); } return true; } +bool FastText::predictLine(std::istream& in, std::vector>& predictions) const { + predictions.clear(); + if (in.peek() == EOF) { + return false; + } + + std::vector words, labels; + dict_->getLine(in, words, labels); + Predictions linePredictions; + predict(words, linePredictions); + predictions.reserve(linePredictions.size()); + for (const auto& p : linePredictions) { + predictions.emplace_back(std::exp(p.first), dict_->getLabel(p.second)); + } + + return true; +} + + void FastText::getSentenceVector(std::istream& in, fasttext::Vector& svec) { + // std::istream is slow svec.zero(); if (args_->model == model_name::sup) { std::vector line, labels; @@ -511,12 +545,13 @@ std::vector> FastText::getNgramVectors( std::vector substrings; dict_->getSubwords(word, ngrams, substrings); assert(ngrams.size() <= substrings.size()); + result.reserve(ngrams.size()); for (int32_t i = 0; i < ngrams.size(); i++) { Vector vec(args_->dim); if (ngrams[i] >= 0) { vec.addRow(*input_, ngrams[i]); } - result.push_back(std::make_pair(substrings[i], std::move(vec))); + result.emplace_back(substrings[i], std::move(vec)); } return result; } @@ -570,30 +605,44 @@ std::vector> FastText::getNN( int32_t k, const std::set& banSet) { std::vector> heap; + heap.reserve(size_t(k + 1)); real queryNorm = query.norm(); if (std::abs(queryNorm) < 1e-8) { queryNorm = 1; } - for (int32_t i = 0; i < dict_->nwords(); i++) { + int32_t i = 0; + while (i < dict_->nwords() && heap.size() < k) { std::string word = dict_->getWord(i); if (banSet.find(word) == banSet.end()) { real dp = wordVectors.dotRow(query, i); real similarity = dp / queryNorm; + heap.emplace_back(similarity, std::move(word)); + } + ++i; + } - if (heap.size() == k && similarity < heap.front().first) { - continue; - } - heap.push_back(std::make_pair(similarity, word)); - std::push_heap(heap.begin(), heap.end(), comparePairs); - if (heap.size() > k) { - std::pop_heap(heap.begin(), heap.end(), comparePairs); - heap.pop_back(); + greater_first cmp; + std::make_heap(std::begin(heap), std::end(heap), cmp); + + while (i < dict_->nwords()) { + std::string word = dict_->getWord(i); + if (banSet.find(word) == banSet.end()) { + real dp = wordVectors.dotRow(query, i); + real similarity = dp / queryNorm; + + if (similarity >= heap.front().first) { + heap.emplace_back(similarity, std::move(word)); + std::push_heap(heap.begin(), heap.end(), cmp); + std::pop_heap(heap.begin(), heap.end(), cmp); + heap.pop_back(); } } + ++i; } - std::sort_heap(heap.begin(), heap.end(), comparePairs); + + std::sort(heap.begin(), heap.end(), cmp); // faster than std::sort_heap return heap; } @@ -686,7 +735,6 @@ std::shared_ptr FastText::getInputMatrixFromFile( const std::string& filename) const { std::ifstream in(filename); std::vector words; - std::shared_ptr mat; // temp. matrix for pretrained vectors int64_t n, dim; if (!in.is_open()) { throw std::invalid_argument(filename + " cannot be opened for loading!"); @@ -697,23 +745,22 @@ std::shared_ptr FastText::getInputMatrixFromFile( "Dimension of pretrained vectors (" + std::to_string(dim) + ") does not match dimension (" + std::to_string(args_->dim) + ")!"); } - mat = std::make_shared(n, dim); + DenseMatrix mat(n, dim); // temp. matrix for pretrained vectors for (size_t i = 0; i < n; i++) { std::string word; in >> word; words.push_back(word); dict_->add(word); for (size_t j = 0; j < dim; j++) { - in >> mat->at(i, j); + in >> mat.at(i, j); } } in.close(); dict_->threshold(1, 0); dict_->init(); - std::shared_ptr input = std::make_shared( - dict_->nwords() + args_->bucket, args_->dim); - input->uniform(1.0 / args_->dim); + DenseMatrix input(dict_->nwords() + args_->bucket, args_->dim); + input.uniform(1.0 / args_->dim); for (size_t i = 0; i < n; i++) { int32_t idx = dict_->getId(words[i]); @@ -721,10 +768,10 @@ std::shared_ptr FastText::getInputMatrixFromFile( continue; } for (size_t j = 0; j < dim; j++) { - input->at(idx, j) = mat->at(i, j); + input.at(idx, j) = mat.at(i, j); } } - return input; + return std::make_shared(std::move(input)); } void FastText::loadVectors(const std::string& filename) { @@ -781,13 +828,14 @@ void FastText::startThreads() { tokenCount_ = 0; loss_ = -1; std::vector threads; + threads.reserve(args_->thread); for (int32_t i = 0; i < args_->thread; i++) { threads.push_back(std::thread([=]() { trainThread(i); })); } const int64_t ntokens = dict_->ntokens(); // Same condition as trainThread while (tokenCount_ < args_->epoch * ntokens) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + std::this_thread::sleep_for(std::chrono::milliseconds(10000)); if (loss_ >= 0 && args_->verbose > 1) { real progress = real(tokenCount_) / (args_->epoch * ntokens); std::cerr << "\r"; @@ -812,9 +860,8 @@ bool FastText::isQuant() const { return quant_; } -bool comparePairs( - const std::pair& l, - const std::pair& r) { +bool comparePairs(const std::pair &l, + const std::pair &r) { return l.first > r.first; } diff --git a/src/fasttext.h b/src/fasttext.h index b63e4cd..010ef45 100644 --- a/src/fasttext.h +++ b/src/fasttext.h @@ -122,12 +122,16 @@ class FastText { Predictions& predictions, real threshold = 0.0) const; + void predict(const std::vector& words, Predictions& predictions) const; + bool predictLine( std::istream& in, std::vector>& predictions, int32_t k, real threshold) const; + bool predictLine(std::istream& in, std::vector>& predictions) const; + std::vector> getNgramVectors( const std::string& word) const; @@ -188,4 +192,11 @@ class FastText { const std::set& banSet, std::vector>& results); }; + +template +struct greater_first { + bool operator()(const T& x, const T& y) { + return y.first < x.first; + } +}; } // namespace fasttext diff --git a/src/loss.cc b/src/loss.cc index 285eb9f..9610679 100644 --- a/src/loss.cc +++ b/src/loss.cc @@ -71,6 +71,15 @@ void Loss::predict( std::sort_heap(heap.begin(), heap.end(), comparePairs); } +void Loss::predict(Predictions& predictions, Model::State& state) const { + computeOutput(state); + const Vector& output = state.output; + predictions.reserve(output.size()); + for (int32_t i = 0; i < output.size(); i++) { + predictions.emplace_back(std_log(output[i]), i); + } +} + void Loss::findKBest( int32_t k, real threshold, @@ -83,7 +92,7 @@ void Loss::findKBest( if (heap.size() == k && std_log(output[i]) < heap.front().first) { continue; } - heap.push_back(std::make_pair(std_log(output[i]), i)); + heap.emplace_back(std_log(output[i]), i); std::push_heap(heap.begin(), heap.end(), comparePairs); if (heap.size() > k) { std::pop_heap(heap.begin(), heap.end(), comparePairs); @@ -239,8 +248,8 @@ void HierarchicalSoftmaxLoss::buildTree(const std::vector& counts) { code.push_back(tree_[j].binary); j = tree_[j].parent; } - paths_.push_back(path); - codes_.push_back(code); + paths_.push_back(std::move(path)); + codes_.push_back(std::move(code)); } } @@ -269,6 +278,12 @@ void HierarchicalSoftmaxLoss::predict( std::sort_heap(heap.begin(), heap.end(), comparePairs); } +void HierarchicalSoftmaxLoss::predict(Predictions& predictions, Model::State& state) const { + dfs(2 * osz_ - 2, 0.0, predictions, state.hidden); +} + + + void HierarchicalSoftmaxLoss::dfs( int32_t k, real threshold, @@ -284,7 +299,7 @@ void HierarchicalSoftmaxLoss::dfs( } if (tree_[node].left == -1 && tree_[node].right == -1) { - heap.push_back(std::make_pair(score, node)); + heap.emplace_back(score, node); std::push_heap(heap.begin(), heap.end(), comparePairs); if (heap.size() > k) { std::pop_heap(heap.begin(), heap.end(), comparePairs); @@ -300,6 +315,24 @@ void HierarchicalSoftmaxLoss::dfs( dfs(k, threshold, tree_[node].right, score + std_log(f), heap, hidden); } +void HierarchicalSoftmaxLoss::dfs( + int32_t node, + real score, + Predictions& predictions, + const Vector& hidden) const { + + if (tree_[node].left == -1 && tree_[node].right == -1) { + predictions.emplace_back(score, node); + return; + } + + real f = wo_->dotRow(hidden, node - osz_); + f = 1. / (1 + std::exp(-f)); + + dfs(tree_[node].left, score + std_log(1.0 - f), predictions, hidden); + dfs(tree_[node].right, score + std_log(f), predictions, hidden); +} + SoftmaxLoss::SoftmaxLoss(std::shared_ptr& wo) : Loss(wo) {} void SoftmaxLoss::computeOutput(Model::State& state) const { diff --git a/src/loss.h b/src/loss.h index 3aea72f..d8e0d0e 100644 --- a/src/loss.h +++ b/src/loss.h @@ -53,6 +53,8 @@ class Loss { real /*threshold*/, Predictions& /*heap*/, Model::State& /*state*/) const; + + virtual void predict(Predictions& predictions, Model::State& state) const; }; class BinaryLogisticLoss : public Loss { @@ -73,7 +75,7 @@ class BinaryLogisticLoss : public Loss { class OneVsAllLoss : public BinaryLogisticLoss { public: explicit OneVsAllLoss(std::shared_ptr& wo); - ~OneVsAllLoss() noexcept override = default; + ~OneVsAllLoss() noexcept override final = default; real forward( const std::vector& targets, int32_t targetIndex, @@ -96,14 +98,14 @@ class NegativeSamplingLoss : public BinaryLogisticLoss { std::shared_ptr& wo, int neg, const std::vector& targetCounts); - ~NegativeSamplingLoss() noexcept override = default; + ~NegativeSamplingLoss() noexcept override final = default; real forward( const std::vector& targets, int32_t targetIndex, Model::State& state, real lr, - bool backprop) override; + bool backprop) override final; }; class HierarchicalSoftmaxLoss : public BinaryLogisticLoss { @@ -129,35 +131,39 @@ class HierarchicalSoftmaxLoss : public BinaryLogisticLoss { Predictions& heap, const Vector& hidden) const; + void dfs(int32_t node, real score, Predictions& predictions, const Vector& hidden) const; + public: explicit HierarchicalSoftmaxLoss( std::shared_ptr& wo, const std::vector& counts); - ~HierarchicalSoftmaxLoss() noexcept override = default; + ~HierarchicalSoftmaxLoss() noexcept override final = default; real forward( const std::vector& targets, int32_t targetIndex, Model::State& state, real lr, - bool backprop) override; + bool backprop) override final; void predict( int32_t k, real threshold, Predictions& heap, - Model::State& state) const override; + Model::State& state) const override final; + + void predict(Predictions& predictions, Model::State& state) const override final; }; class SoftmaxLoss : public Loss { public: explicit SoftmaxLoss(std::shared_ptr& wo); - ~SoftmaxLoss() noexcept override = default; + ~SoftmaxLoss() noexcept override final = default; real forward( const std::vector& targets, int32_t targetIndex, Model::State& state, real lr, bool backprop) override; - void computeOutput(Model::State& state) const override; + void computeOutput(Model::State& state) const override final; }; } // namespace fasttext diff --git a/src/model.cc b/src/model.cc index b4fcaed..a38c6f4 100644 --- a/src/model.cc +++ b/src/model.cc @@ -67,6 +67,11 @@ void Model::predict( loss_->predict(k, threshold, heap, state); } +void Model::predict(const std::vector& input, Predictions& predictions, State& state) const { + computeHidden(input, state); + loss_->predict(predictions, state); +} + void Model::update( const std::vector& input, const std::vector& targets, diff --git a/src/model.h b/src/model.h index 65987d6..35089df 100644 --- a/src/model.h +++ b/src/model.h @@ -62,6 +62,9 @@ class Model { real threshold, Predictions& heap, State& state) const; + + void predict(const std::vector& input, Predictions& predictions, State& state) const; + void update( const std::vector& input, const std::vector& targets, diff --git a/src/quantmatrix.h b/src/quantmatrix.h index 9e1b2f2..ffe8eba 100644 --- a/src/quantmatrix.h +++ b/src/quantmatrix.h @@ -43,7 +43,7 @@ class QuantMatrix : public Matrix { QuantMatrix(QuantMatrix&&) = delete; QuantMatrix& operator=(const QuantMatrix&) = delete; QuantMatrix& operator=(QuantMatrix&&) = delete; - virtual ~QuantMatrix() noexcept override = default; + virtual ~QuantMatrix() noexcept override final = default; void quantizeNorm(const Vector&); void quantize(DenseMatrix&& mat);