diff --git a/src/main/java/com/github/jfasttext/JFastText.java b/src/main/java/com/github/jfasttext/JFastText.java index 1f2d333..43fddd7 100644 --- a/src/main/java/com/github/jfasttext/JFastText.java +++ b/src/main/java/com/github/jfasttext/JFastText.java @@ -3,15 +3,40 @@ import org.bytedeco.javacpp.PointerPointer; import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URL; +import java.nio.file.CopyOption; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; import java.util.ArrayList; import java.util.List; public class JFastText { - private FastTextWrapper.FastTextApi fta; + private FastTextWrapper.FastTextApi fta = new FastTextWrapper.FastTextApi();; public JFastText() { - fta = new FastTextWrapper.FastTextApi(); + } + + public JFastText(final String modelFile) { + loadModel(modelFile); + } + + public JFastText(final URI modelUri) throws IOException { + loadModel(modelUri); + } + + public JFastText(final URL modelUrl) throws IOException { + loadModel(modelUrl); + } + + public JFastText(final InputStream modelStream) throws IOException { + loadModel(modelStream); } public void runCmd(String[] args) { @@ -22,9 +47,9 @@ public void runCmd(String[] args) { fta.runCmd(cArgs.length, new PointerPointer(cArgs)); } - public void loadModel(String modelFile) { + public void loadModel(String modelFile) throws IllegalArgumentException { if (!new File(modelFile).exists()) { - throw new IllegalArgumentException("Model file doesn't exist!"); + throw new IllegalArgumentException("Model file " + modelFile + " doesn't exist!"); } if (!fta.checkModel(modelFile)) { throw new IllegalArgumentException( @@ -33,6 +58,39 @@ public void loadModel(String modelFile) { fta.loadModel(modelFile); } + /** + * Loads model from location specified by URI, by copying its content into local file & then loading it. + * + * @param modelUri location of the model + */ + public void loadModel(URI modelUri) throws IOException { + loadModel(modelUri.toURL()); + } + + /** + * Loads model from location specified by URL, by copying its content into local file & then loading it. + * + * @param modelUrl location of the model + */ + public void loadModel(URL modelUrl) throws IOException { + loadModel(modelUrl.openStream()); + } + + /** + * Loads model given InputStream, by copying its content into local file & then loading it. + * + * @param modelStream stream for model + */ + public void loadModel(InputStream modelStream) throws IOException { + Path tmpFile = Files.createTempFile("jft-", ".model"); + try { + Files.copy(modelStream, tmpFile, StandardCopyOption.REPLACE_EXISTING); + loadModel(tmpFile.toString()); + } finally { + Files.deleteIfExists(tmpFile); + } + } + public void unloadModel() { fta.unloadModel(); } diff --git a/src/test/java/com/github/jfasttext/JFastTextTest.java b/src/test/java/com/github/jfasttext/JFastTextTest.java index b6f1521..3af9d1d 100644 --- a/src/test/java/com/github/jfasttext/JFastTextTest.java +++ b/src/test/java/com/github/jfasttext/JFastTextTest.java @@ -4,6 +4,9 @@ import org.junit.Test; import org.junit.runners.MethodSorters; +import java.io.FileInputStream; +import java.io.InputStream; +import java.net.URL; import java.util.List; @FixMethodOrder(MethodSorters.NAME_ASCENDING) @@ -57,8 +60,7 @@ public void test04Predict() throws Exception { @Test public void test05PredictProba() throws Exception { - JFastText jft = new JFastText(); - jft.loadModel("src/test/resources/models/supervised.model.bin"); + JFastText jft = new JFastText("src/test/resources/models/supervised.model.bin"); String text = "What is the most popular sport in the US ?"; JFastText.ProbLabel predictedProbLabel = jft.predictProba(text); System.out.printf("\nText: '%s', label: '%s', probability: %f\n", @@ -67,8 +69,7 @@ public void test05PredictProba() throws Exception { @Test public void test06MultiPredictProba() throws Exception { - JFastText jft = new JFastText(); - jft.loadModel("src/test/resources/models/supervised.model.bin"); + JFastText jft = new JFastText("src/test/resources/models/supervised.model.bin"); String text = "Do you like soccer ?"; System.out.printf("Text: '%s'\n", text); for (JFastText.ProbLabel predictedProbLabel: jft.predictProba(text, 2)) { @@ -79,11 +80,12 @@ public void test06MultiPredictProba() throws Exception { @Test public void test07GetVector() throws Exception { - JFastText jft = new JFastText(); - jft.loadModel("src/test/resources/models/supervised.model.bin"); - String word = "soccer"; - List vec = jft.getVector(word); - System.out.printf("\nWord embedding vector of '%s': %s\n", word, vec); + try (InputStream is = new FileInputStream("src/test/resources/models/supervised.model.bin")) { + JFastText jft = new JFastText(is); + String word = "soccer"; + List vec = jft.getVector(word); + System.out.printf("\nWord embedding vector of '%s': %s\n", word, vec); + } } /** @@ -92,8 +94,7 @@ public void test07GetVector() throws Exception { @Test public void test08ModelInfo() throws Exception { System.out.printf("\nSupervised model information:\n"); - JFastText jft = new JFastText(); - jft.loadModel("src/test/resources/models/supervised.model.bin"); + JFastText jft = new JFastText("src/test/resources/models/supervised.model.bin"); System.out.printf("\tnumber of words = %d\n", jft.getNWords()); System.out.printf("\twords = %s\n", jft.getWords()); System.out.printf("\tlearning rate = %g\n", jft.getLr()); @@ -120,4 +121,15 @@ public void test09ModelUnloading() throws Exception { System.out.println("Unloading model ..."); jft.unloadModel(); } + + /** + * Loads model from specified URL (resource, web, etc.) + * + */ + @Test + public void test10ModelFromURL() throws Exception { + URL modelUrl = this.getClass().getClassLoader().getResource("models/supervised.model.bin"); + JFastText jft = new JFastText(modelUrl); + System.out.printf("\tnumber of words = %d\n", jft.getNWords()); + } }