diff --git a/.travis.yml b/.travis.yml index 5c71a1e..ed8dd68 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,6 @@ language: java jdk: - - oraclejdk8 + - oraclejdk9 os: - linux - osx diff --git a/src/main/cpp/fasttext_wrapper.cc b/src/main/cpp/fasttext_wrapper.cc index ec5c90e..58ef344 100644 --- a/src/main/cpp/fasttext_wrapper.cc +++ b/src/main/cpp/fasttext_wrapper.cc @@ -80,7 +80,14 @@ namespace FastTextWrapper { std::vector FastTextApi::getVector(const std::string& word) { Vector vec(privateMembers->args_->dim); - fastText.getVector(vec, word); + fastText.getWordVector(vec, word); + return std::vector(vec.data(), vec.data() + vec.size()); + } + + std::vector FastTextApi::getSentenceVector(const std::string& sentence) { + Vector vec(privateMembers->args_->dim); + std::istringstream in(sentence); + fastText.getSentenceVector(in, vec); return std::vector(vec.data(), vec.data() + vec.size()); } diff --git a/src/main/cpp/fasttext_wrapper.h b/src/main/cpp/fasttext_wrapper.h index 10562ef..68758ec 100644 --- a/src/main/cpp/fasttext_wrapper.h +++ b/src/main/cpp/fasttext_wrapper.h @@ -29,6 +29,7 @@ namespace FastTextWrapper { std::vector predict(const std::string&, int32_t); std::vector> predictProba(const std::string&, int32_t); std::vector getVector(const std::string&); + std::vector getSentenceVector(const std::string&); std::vector getWords(); std::vector getLabels(); std::string getWord(int32_t); diff --git a/src/main/java/com/github/jfasttext/FastTextWrapper.java b/src/main/java/com/github/jfasttext/FastTextWrapper.java index c583d03..023b19a 100644 --- a/src/main/java/com/github/jfasttext/FastTextWrapper.java +++ b/src/main/java/com/github/jfasttext/FastTextWrapper.java @@ -201,6 +201,8 @@ public DoubleIntPair put(float firstValue, int secondValue) { public native @ByVal FloatStringPairVector predictProba(@StdString String arg0, int arg1); public native @ByVal RealVector getVector(@StdString BytePointer arg0); public native @ByVal RealVector getVector(@StdString String arg0); + public native @ByVal RealVector getSentenceVector(@StdString BytePointer arg0); + public native @ByVal RealVector getSentenceVector(@StdString String arg0); public native @ByVal StringVector getWords(); public native @ByVal StringVector getLabels(); public native @StdString BytePointer getWord(int arg0); diff --git a/src/main/java/com/github/jfasttext/JFastText.java b/src/main/java/com/github/jfasttext/JFastText.java index 1f2d333..c7ce7d1 100644 --- a/src/main/java/com/github/jfasttext/JFastText.java +++ b/src/main/java/com/github/jfasttext/JFastText.java @@ -93,6 +93,18 @@ public List getVector(String word) { return wordVec; } + public List getSentenceVector(String sentence) { + if (!sentence.endsWith("\n")) { + sentence += "\n"; + } + FastTextWrapper.RealVector rv = fta.getSentenceVector(sentence); + List wordVec = new ArrayList<>(); + for (int i = 0; i < rv.size(); i++) { + wordVec.add(rv.get(i)); + } + return wordVec; + } + public int getNWords() { return fta.getNWords(); } diff --git a/src/test/java/com/github/jfasttext/JFastTextTest.java b/src/test/java/com/github/jfasttext/JFastTextTest.java index b6f1521..d7eb3c5 100644 --- a/src/test/java/com/github/jfasttext/JFastTextTest.java +++ b/src/test/java/com/github/jfasttext/JFastTextTest.java @@ -16,7 +16,9 @@ public void test01TrainSupervisedCmd() { jft.runCmd(new String[] { "supervised", "-input", "src/test/resources/data/labeled_data.txt", - "-output", "src/test/resources/models/supervised.model" + "-output", "src/test/resources/models/supervised.model", + "-wordNgrams", "3", + "-bucket", "100" }); } @@ -86,11 +88,20 @@ public void test07GetVector() throws Exception { System.out.printf("\nWord embedding vector of '%s': %s\n", word, vec); } + @Test + public void test08GetSentenceVector() throws Exception { + JFastText jft = new JFastText(); + jft.loadModel("src/test/resources/models/supervised.model.bin"); + String word = "soccers"; + List vec = jft.getSentenceVector(word); + System.out.printf("\nSentence embedding vector of '%s': %s\n", word, vec); + } + /** * Test retrieving model's information: words, labels, learning rate, etc. */ @Test - public void test08ModelInfo() throws Exception { + public void test09ModelInfo() throws Exception { System.out.printf("\nSupervised model information:\n"); JFastText jft = new JFastText(); jft.loadModel("src/test/resources/models/supervised.model.bin"); @@ -113,7 +124,7 @@ public void test08ModelInfo() throws Exception { * allocated by native function calls). */ @Test - public void test09ModelUnloading() throws Exception { + public void test10ModelUnloading() throws Exception { JFastText jft = new JFastText(); System.out.println("\nLoading model ..."); jft.loadModel("src/test/resources/models/supervised.model.bin");