Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 62 additions & 4 deletions src/main/java/com/github/jfasttext/JFastText.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,40 @@
import org.bytedeco.javacpp.PointerPointer;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.nio.file.CopyOption;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.List;

public class JFastText {

private FastTextWrapper.FastTextApi fta;
private FastTextWrapper.FastTextApi fta = new FastTextWrapper.FastTextApi();;

public JFastText() {
fta = new FastTextWrapper.FastTextApi();
}

public JFastText(final String modelFile) {
loadModel(modelFile);
}

public JFastText(final URI modelUri) throws IOException {
loadModel(modelUri);
}

public JFastText(final URL modelUrl) throws IOException {
loadModel(modelUrl);
}

public JFastText(final InputStream modelStream) throws IOException {
loadModel(modelStream);
}

public void runCmd(String[] args) {
Expand All @@ -22,9 +47,9 @@ public void runCmd(String[] args) {
fta.runCmd(cArgs.length, new PointerPointer(cArgs));
}

public void loadModel(String modelFile) {
public void loadModel(String modelFile) throws IllegalArgumentException {
if (!new File(modelFile).exists()) {
throw new IllegalArgumentException("Model file doesn't exist!");
throw new IllegalArgumentException("Model file " + modelFile + " doesn't exist!");
}
if (!fta.checkModel(modelFile)) {
throw new IllegalArgumentException(
Expand All @@ -33,6 +58,39 @@ public void loadModel(String modelFile) {
fta.loadModel(modelFile);
}

/**
* Loads model from location specified by URI, by copying its content into local file & then loading it.
*
* @param modelUri location of the model
*/
public void loadModel(URI modelUri) throws IOException {
loadModel(modelUri.toURL());
}

/**
* Loads model from location specified by URL, by copying its content into local file & then loading it.
*
* @param modelUrl location of the model
*/
public void loadModel(URL modelUrl) throws IOException {
loadModel(modelUrl.openStream());
}

/**
* Loads model given InputStream, by copying its content into local file & then loading it.
*
* @param modelStream stream for model
*/
public void loadModel(InputStream modelStream) throws IOException {
Path tmpFile = Files.createTempFile("jft-", ".model");
try {
Files.copy(modelStream, tmpFile, StandardCopyOption.REPLACE_EXISTING);
loadModel(tmpFile.toString());
} finally {
Files.deleteIfExists(tmpFile);
}
}

public void unloadModel() {
fta.unloadModel();
}
Expand Down
34 changes: 23 additions & 11 deletions src/test/java/com/github/jfasttext/JFastTextTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import org.junit.Test;
import org.junit.runners.MethodSorters;

import java.io.FileInputStream;
import java.io.InputStream;
import java.net.URL;
import java.util.List;

@FixMethodOrder(MethodSorters.NAME_ASCENDING)
Expand Down Expand Up @@ -57,8 +60,7 @@ public void test04Predict() throws Exception {

@Test
public void test05PredictProba() throws Exception {
JFastText jft = new JFastText();
jft.loadModel("src/test/resources/models/supervised.model.bin");
JFastText jft = new JFastText("src/test/resources/models/supervised.model.bin");
String text = "What is the most popular sport in the US ?";
JFastText.ProbLabel predictedProbLabel = jft.predictProba(text);
System.out.printf("\nText: '%s', label: '%s', probability: %f\n",
Expand All @@ -67,8 +69,7 @@ public void test05PredictProba() throws Exception {

@Test
public void test06MultiPredictProba() throws Exception {
JFastText jft = new JFastText();
jft.loadModel("src/test/resources/models/supervised.model.bin");
JFastText jft = new JFastText("src/test/resources/models/supervised.model.bin");
String text = "Do you like soccer ?";
System.out.printf("Text: '%s'\n", text);
for (JFastText.ProbLabel predictedProbLabel: jft.predictProba(text, 2)) {
Expand All @@ -79,11 +80,12 @@ public void test06MultiPredictProba() throws Exception {

@Test
public void test07GetVector() throws Exception {
JFastText jft = new JFastText();
jft.loadModel("src/test/resources/models/supervised.model.bin");
String word = "soccer";
List<Float> vec = jft.getVector(word);
System.out.printf("\nWord embedding vector of '%s': %s\n", word, vec);
try (InputStream is = new FileInputStream("src/test/resources/models/supervised.model.bin")) {
JFastText jft = new JFastText(is);
String word = "soccer";
List<Float> vec = jft.getVector(word);
System.out.printf("\nWord embedding vector of '%s': %s\n", word, vec);
}
}

/**
Expand All @@ -92,8 +94,7 @@ public void test07GetVector() throws Exception {
@Test
public void test08ModelInfo() throws Exception {
System.out.printf("\nSupervised model information:\n");
JFastText jft = new JFastText();
jft.loadModel("src/test/resources/models/supervised.model.bin");
JFastText jft = new JFastText("src/test/resources/models/supervised.model.bin");
System.out.printf("\tnumber of words = %d\n", jft.getNWords());
System.out.printf("\twords = %s\n", jft.getWords());
System.out.printf("\tlearning rate = %g\n", jft.getLr());
Expand All @@ -120,4 +121,15 @@ public void test09ModelUnloading() throws Exception {
System.out.println("Unloading model ...");
jft.unloadModel();
}

/**
* Loads model from specified URL (resource, web, etc.)
*
*/
@Test
public void test10ModelFromURL() throws Exception {
URL modelUrl = this.getClass().getClassLoader().getResource("models/supervised.model.bin");
JFastText jft = new JFastText(modelUrl);
System.out.printf("\tnumber of words = %d\n", jft.getNWords());
}
}