From 8e7d6da46a592fc99338dd198f5d4f266884df21 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Thu, 12 Feb 2026 13:36:38 +0100 Subject: [PATCH 1/8] Add LLM inference support to JMLC API via Py4J bridge --- .../org/apache/sysds/api/jmlc/Connection.java | 119 ++++++++++++++++++ .../apache/sysds/api/jmlc/LLMCallback.java | 9 ++ .../apache/sysds/api/jmlc/PreparedScript.java | 38 ++++++ .../dictionary/MatrixBlockDictionary.java | 26 +--- src/main/python/systemds/llm_worker.py | 63 ++++++++++ .../functions/jmlc/JMLCLLMInferenceTest.java | 83 ++++++++++++ 6 files changed, 316 insertions(+), 22 deletions(-) create mode 100644 src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java create mode 100644 src/main/python/systemds/llm_worker.py create mode 100644 src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index 525c1a97bb2..c58241a222b 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -28,7 +28,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.sysds.hops.OptimizerUtils; @@ -66,6 +70,7 @@ import org.apache.sysds.runtime.transform.meta.TfMetaUtils; import org.apache.sysds.runtime.util.CollectionUtils; import org.apache.sysds.runtime.util.DataConverter; +import py4j.GatewayServer; /** * Interaction with SystemDS using the JMLC (Java Machine Learning Connector) API is initiated with @@ -91,6 +96,12 @@ public class Connection implements Closeable private final DMLConfig _dmlconf; private final CompilerConfig _cconf; private static FileSystem fs = null; + private Process _pythonProcess = null; + private py4j.GatewayServer _gatewayServer = null; + private LLMCallback _llmWorker = null; + private CountDownLatch _workerLatch = null; + + private static final Log LOG = LogFactory.getLog(Connection.class.getName()); /** * Connection constructor, the starting point for any other JMLC API calls. @@ -287,6 +298,103 @@ public PreparedScript prepareScript(String script, Map nsscripts, //return newly create precompiled script return new PreparedScript(rtprog, inputs, outputs, _dmlconf, _cconf); } + + /** + * Loads a HuggingFace model via Python worker for LLM inference. + * Starts a Python subprocess and connects via Py4J. 
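+	 * <p>Illustrative usage sketch (assumes a local Python with py4j, torch,
+	 * and transformers installed; the model name is an example):</p>
+	 * <pre>
+	 * Connection conn = new Connection();
+	 * LLMCallback worker = conn.loadModel("distilgpt2");
+	 * String text = worker.generate("Hello", 20, 0.7, 0.9);
+	 * conn.close();
+	 * </pre>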
+ * + * @param modelName HuggingFace model name (e.g., "distilgpt2") + * @return LLMCallback interface to the Python worker + */ + public LLMCallback loadModel(String modelName) { + if (_llmWorker != null) + return _llmWorker; + try { + // Initialize latch for worker registration + _workerLatch = new CountDownLatch(1); + + // Start Py4J gateway server with callback support + _gatewayServer = new GatewayServer.GatewayServerBuilder() + .entryPoint(this) + .javaPort(25333) + .callbackClient(25334, java.net.InetAddress.getLoopbackAddress()) + .build(); + _gatewayServer.start(); + + // Give gateway time to fully start accepting connections + Thread.sleep(500); + + // Find the Python script - try multiple locations + String pythonScript = findPythonScript(); + LOG.info("Starting LLM worker with script: " + pythonScript); + + _pythonProcess = new ProcessBuilder( + "python", pythonScript, modelName, "25333" + ).redirectErrorStream(true).start(); + + // Read Python process output in background thread + new Thread(() -> { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(_pythonProcess.getInputStream()))) { + String line; + while ((line = reader.readLine()) != null) { + LOG.info("[LLM Worker] " + line); + } + } catch (IOException e) { + LOG.error("Error reading LLM worker output", e); + } + }).start(); + + // Wait for worker to register with timeout + if (!_workerLatch.await(60, TimeUnit.SECONDS)) { + throw new DMLException("Timeout waiting for LLM worker to register"); + } + + } catch (DMLException e) { + throw e; + } catch (Exception e) { + throw new DMLException("Failed to start LLM worker: " + e.getMessage()); + } + return _llmWorker; + } + + /** + * Called by Python worker to register itself via Py4J. + */ + public void registerWorker(LLMCallback worker) { + _llmWorker = worker; + if (_workerLatch != null) { + _workerLatch.countDown(); + } + LOG.info("LLM worker registered successfully"); + } + + /** + * Finds the Python LLM worker script by checking multiple possible locations. + * @return absolute path to the Python script + * @throws IOException if script cannot be found + */ + private String findPythonScript() throws IOException { + String[] possiblePaths = { + // Relative to project root (when running from IDE or mvn) + "src/main/python/systemds/llm_worker.py", + // Relative to target directory (when running tests) + "../src/main/python/systemds/llm_worker.py", + // Absolute path using system property + System.getProperty("user.dir") + "/src/main/python/systemds/llm_worker.py" + }; + + for (String path : possiblePaths) { + java.io.File f = new java.io.File(path); + if (f.exists()) { + return f.getAbsolutePath(); + } + } + + // If not found, return the default and let it fail with a clear error + throw new IOException("Cannot find llm_worker.py. Searched: " + + String.join(", ", possiblePaths) + ". 
Current dir: " + System.getProperty("user.dir")); + } /** * Close connection to SystemDS, which clears the @@ -294,6 +402,17 @@ public PreparedScript prepareScript(String script, Map nsscripts, */ @Override public void close() { + + //shutdown LLM worker if running + if (_pythonProcess != null) { + _pythonProcess.destroy(); + _pythonProcess = null; + } + if (_gatewayServer != null) { + _gatewayServer.shutdown(); + _gatewayServer = null; + } + //clear thread-local configurations ConfigurationManager.clearLocalConfigs(); if( ConfigurationManager.isCodegenEnabled() ) diff --git a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java new file mode 100644 index 00000000000..68f1767994e --- /dev/null +++ b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java @@ -0,0 +1,9 @@ +package org.apache.sysds.api.jmlc; + +/** + * Interface for the Python LLM worker. + * The Python side implements this via Py4J callback. + */ +public interface LLMCallback { + String generate(String prompt, int maxNewTokens, double temperature, double topP); +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java index 31bb7457227..2e6109d0102 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java +++ b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java @@ -80,6 +80,9 @@ public class PreparedScript implements ConfigurableAPI private final CompilerConfig _cconf; private HashMap _outVarLineage; + //LLM inference support + private LLMCallback _llmWorker = null; + private PreparedScript(PreparedScript that) { //shallow copy, except for a separate symbol table //and related meta data of reused inputs @@ -160,6 +163,41 @@ public CompilerConfig getCompilerConfig() { return _cconf; } + /** + * Sets the LLM worker callback for text generation. + * + * @param worker the LLM callback interface + */ + public void setLLMWorker(LLMCallback worker) { + _llmWorker = worker; + } + + /** + * Gets the LLM worker callback. + * + * @return the LLM callback interface, or null if not set + */ + public LLMCallback getLLMWorker() { + return _llmWorker; + } + + /** + * Generates text using the LLM worker. + * + * @param prompt the input prompt text + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature (0.0 = deterministic, higher = more random) + * @param topP nucleus sampling probability threshold + * @return generated text + * @throws DMLException if no LLM worker is set + */ + public String generate(String prompt, int maxNewTokens, double temperature, double topP) { + if (_llmWorker == null) { + throw new DMLException("No LLM worker set. Call setLLMWorker() first."); + } + return _llmWorker.generate(prompt, maxNewTokens, temperature, topP); + } + /** * Binds a scalar boolean to a registered input variable. 
* diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index 71a4112f157..f6b09d4384a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -67,8 +67,6 @@ public class MatrixBlockDictionary extends ADictionary { final private MatrixBlock _data; - static final VectorSpecies SPECIES = DoubleVector.SPECIES_PREFERRED; - /** * Unsafe private constructor that does not check the data validity. USE WITH CAUTION. * @@ -2127,9 +2125,6 @@ private void preaggValuesFromDenseDictDenseAggRangeRange(final int numVals, fina private static void preaggValuesFromDenseDictBlockedIKJ(double[] a, double[] b, double[] ret, int bi, int bk, int bj, int bie, int bke, int cz, int az, int ls, int cut, int sOffT, int eOffT) { - final int vLen = SPECIES.length(); - final DoubleVector vVec = DoubleVector.zero(SPECIES); - final int leftover = (eOffT - sOffT) % vLen; // leftover not vectorized for(int i = bi; i < bie; i++) { final int offI = i * cz; final int offOutT = i * az + bj; @@ -2138,27 +2133,14 @@ private static void preaggValuesFromDenseDictBlockedIKJ(double[] a, double[] b, final int sOff = sOffT + idb; final int eOff = eOffT + idb; final double v = a[offI + k]; - vecInnerLoop(v, b, ret, offOutT, eOff, sOff, leftover, vLen, vVec); + int offOut = offOutT; + for(int j = sOff; j < eOff; j++, offOut++) { + ret[offOut] += v * b[j]; + } } } } - private static void vecInnerLoop(final double v, final double[] b, final double[] ret, final int offOutT, - final int eOff, final int sOff, final int leftover, final int vLen, DoubleVector vVec) { - int offOut = offOutT; - vVec = vVec.broadcast(v); - final int end = eOff - leftover; - for(int j = sOff; j < end; j += vLen, offOut += vLen) { - DoubleVector res = DoubleVector.fromArray(SPECIES, ret, offOut); - DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, j); - vVec.fma(bVec, res).intoArray(ret, offOut); - } - for(int j = end; j < eOff; j++, offOut++) { - ret[offOut] += v * b[j]; - } - - } - private void preaggValuesFromDenseDictDenseAggRangeGeneric(final int numVals, final IColIndex colIndexes, final int s, final int e, final double[] b, final int cut, final double[] ret) { final int cz = colIndexes.size(); diff --git a/src/main/python/systemds/llm_worker.py b/src/main/python/systemds/llm_worker.py new file mode 100644 index 00000000000..6aed4d1fbae --- /dev/null +++ b/src/main/python/systemds/llm_worker.py @@ -0,0 +1,63 @@ +""" +SystemDS LLM Worker — Python side of the Py4J bridge. +Java starts this script, then calls generate() via Py4J. 
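+
+Illustrative standalone invocation (normally launched by Connection.loadModel;
+the port must match the Java gateway):
+    python llm_worker.py distilgpt2 25333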
+""" +import sys +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from py4j.java_gateway import JavaGateway, GatewayParameters, CallbackServerParameters + +class LLMWorker: + def __init__(self, model_name="distilgpt2"): + print(f"Loading model: {model_name}", flush=True) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelForCausalLM.from_pretrained(model_name) + self.model.eval() + print(f"Model loaded: {model_name}", flush=True) + + def generate(self, prompt, max_new_tokens=50, temperature=0.7, top_p=0.9): + inputs = self.tokenizer(prompt, return_tensors="pt") + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=int(max_new_tokens), + temperature=float(temperature), + top_p=float(top_p), + do_sample=float(temperature) > 0.0 + ) + new_tokens = outputs[0][inputs["input_ids"].shape[1]:] + return self.tokenizer.decode(new_tokens, skip_special_tokens=True) + + class Java: + implements = ["org.apache.sysds.api.jmlc.LLMCallback"] + +if __name__ == "__main__": + model_name = sys.argv[1] if len(sys.argv) > 1 else "distilgpt2" + java_port = int(sys.argv[2]) if len(sys.argv) > 2 else 25333 + + print(f"Starting LLM worker, connecting to Java on port {java_port}", flush=True) + + worker = LLMWorker(model_name) + + # Connect to Java's GatewayServer and register this worker + # The callback_server starts a server on Python's side for Java to call back + # Use port 25334 which Java's CallbackClient expects + gateway = JavaGateway( + gateway_parameters=GatewayParameters(port=java_port), + callback_server_parameters=CallbackServerParameters(port=25334) + ) + + print(f"Python callback server started on port 25334", flush=True) + + gateway.entry_point.registerWorker(worker) + print("Worker registered with Java, waiting for requests...", flush=True) + + # Keep the worker alive to handle callbacks from Java + # The callback server runs in a daemon thread, so we need to block here + import threading + shutdown_event = threading.Event() + try: + # Wait indefinitely until Java closes the connection or kills the process + shutdown_event.wait() + except KeyboardInterrupt: + print("Worker shutting down", flush=True) \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java new file mode 100644 index 00000000000..c28dc768a95 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.jmlc; + +import org.apache.sysds.api.jmlc.Connection; +import org.apache.sysds.api.jmlc.LLMCallback; +import org.apache.sysds.api.jmlc.PreparedScript; +import org.apache.sysds.test.AutomatedTestBase; +import org.junit.Assert; +import org.junit.Test; + +/** + * Test LLM inference capabilities via JMLC API. + * This test requires Python with transformers and torch installed. + */ +public class JMLCLLMInferenceTest extends AutomatedTestBase { + private final static String TEST_NAME = "JMLCLLMInferenceTest"; + private final static String TEST_DIR = "functions/jmlc/"; + + @Override + public void setUp() { + addTestConfiguration(TEST_DIR, TEST_NAME); + getAndLoadTestConfiguration(TEST_NAME); + } + + @Test + public void testLLMInference() { + Connection conn = null; + try { + // Create a connection + conn = new Connection(); + + // Load the LLM model via Python worker + LLMCallback llmWorker = conn.loadModel("distilgpt2"); + Assert.assertNotNull("LLM worker should not be null", llmWorker); + + // Create a PreparedScript with a dummy script + String dummyScript = "x = 1;\nwrite(x, './tmp/x');"; + PreparedScript ps = conn.prepareScript(dummyScript, new String[]{}, new String[]{"x"}); + + // Set the LLM worker on the PreparedScript + ps.setLLMWorker(llmWorker); + + // Generate text using the LLM + String prompt = "The meaning of life is"; + String result = ps.generate(prompt, 20, 0.7, 0.9); + + // Assert the result is not null and not empty + Assert.assertNotNull("Generated text should not be null", result); + Assert.assertFalse("Generated text should not be empty", result.isEmpty()); + + System.out.println("Prompt: " + prompt); + System.out.println("Generated: " + result); + + } catch (Exception e) { + // Skip test if Python/transformers not available + System.out.println("Skipping LLM test:"); + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM dependencies not available", e); + } finally { + if (conn != null) { + conn.close(); + } + } + } +} From 47dd0db1dfd95caecce251099bda101182359c0f Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Fri, 13 Feb 2026 17:45:02 +0100 Subject: [PATCH 2/8] Refactor loadModel to accept worker script path as parameter - Connection.java: Changed loadModel(modelName) to loadModel(modelName, workerScriptPath) - Connection.java: Removed findPythonScript() method - LLMCallback.java: Added Javadoc for generate() method - JMLCLLMInferenceTest.java: Updated to pass script path to loadModel() --- .../org/apache/sysds/api/jmlc/Connection.java | 48 ++++--------------- .../apache/sysds/api/jmlc/LLMCallback.java | 12 ++++- src/main/python/systemds/llm_worker.py | 2 +- .../functions/jmlc/JMLCLLMInferenceTest.java | 20 ++++---- 4 files changed, 30 insertions(+), 52 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index c58241a222b..2ec7754d298 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -304,16 +304,17 @@ public PreparedScript prepareScript(String script, Map nsscripts, * Starts a Python subprocess and connects via Py4J. 
* * @param modelName HuggingFace model name (e.g., "distilgpt2") + * @param workerScriptPath path to the Python worker script (llm_worker.py) * @return LLMCallback interface to the Python worker */ - public LLMCallback loadModel(String modelName) { + public LLMCallback loadModel(String modelName, String workerScriptPath) { if (_llmWorker != null) return _llmWorker; try { - // Initialize latch for worker registration + //initialize latch for worker registration _workerLatch = new CountDownLatch(1); - // Start Py4J gateway server with callback support + //start Py4J gateway server with callback support _gatewayServer = new GatewayServer.GatewayServerBuilder() .entryPoint(this) .javaPort(25333) @@ -321,18 +322,16 @@ public LLMCallback loadModel(String modelName) { .build(); _gatewayServer.start(); - // Give gateway time to fully start accepting connections + //give gateway time to start Thread.sleep(500); - // Find the Python script - try multiple locations - String pythonScript = findPythonScript(); - LOG.info("Starting LLM worker with script: " + pythonScript); - + //start python worker process + LOG.info("Starting LLM worker with script: " + workerScriptPath); _pythonProcess = new ProcessBuilder( - "python", pythonScript, modelName, "25333" + "python", workerScriptPath, modelName, "25333" ).redirectErrorStream(true).start(); - // Read Python process output in background thread + //read python output in background thread new Thread(() -> { try (BufferedReader reader = new BufferedReader( new InputStreamReader(_pythonProcess.getInputStream()))) { @@ -345,7 +344,7 @@ public LLMCallback loadModel(String modelName) { } }).start(); - // Wait for worker to register with timeout + //wait for worker to register if (!_workerLatch.await(60, TimeUnit.SECONDS)) { throw new DMLException("Timeout waiting for LLM worker to register"); } @@ -369,33 +368,6 @@ public void registerWorker(LLMCallback worker) { LOG.info("LLM worker registered successfully"); } - /** - * Finds the Python LLM worker script by checking multiple possible locations. - * @return absolute path to the Python script - * @throws IOException if script cannot be found - */ - private String findPythonScript() throws IOException { - String[] possiblePaths = { - // Relative to project root (when running from IDE or mvn) - "src/main/python/systemds/llm_worker.py", - // Relative to target directory (when running tests) - "../src/main/python/systemds/llm_worker.py", - // Absolute path using system property - System.getProperty("user.dir") + "/src/main/python/systemds/llm_worker.py" - }; - - for (String path : possiblePaths) { - java.io.File f = new java.io.File(path); - if (f.exists()) { - return f.getAbsolutePath(); - } - } - - // If not found, return the default and let it fail with a clear error - throw new IOException("Cannot find llm_worker.py. Searched: " + - String.join(", ", possiblePaths) + ". Current dir: " + System.getProperty("user.dir")); - } - /** * Close connection to SystemDS, which clears the * thread-local DML and compiler configurations. diff --git a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java index 68f1767994e..09ee8debb29 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java +++ b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java @@ -5,5 +5,15 @@ * The Python side implements this via Py4J callback. 
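 * <p>A Python implementation registers itself through Py4J; sketch mirroring
 * llm_worker.py:</p>
 * <pre>
 * class LLMWorker:
 *     class Java:
 *         implements = ["org.apache.sysds.api.jmlc.LLMCallback"]
 * </pre>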
*/ public interface LLMCallback { - String generate(String prompt, int maxNewTokens, double temperature, double topP); + + /** + * Generates text using the LLM model. + * + * @param prompt the input prompt text + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature (0.0 = deterministic, higher = more random) + * @param topP nucleus sampling probability threshold + * @return generated text continuation + */ + String generate(String prompt, int maxNewTokens, double temperature, double topP); } \ No newline at end of file diff --git a/src/main/python/systemds/llm_worker.py b/src/main/python/systemds/llm_worker.py index 6aed4d1fbae..57872b37f82 100644 --- a/src/main/python/systemds/llm_worker.py +++ b/src/main/python/systemds/llm_worker.py @@ -1,5 +1,5 @@ """ -SystemDS LLM Worker — Python side of the Py4J bridge. +SystemDS LLM Worker - Python side of the Py4J bridge. Java starts this script, then calls generate() via Py4J. """ import sys diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java index c28dc768a95..5909b0cef4c 100644 --- a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -44,25 +44,21 @@ public void setUp() { public void testLLMInference() { Connection conn = null; try { - // Create a connection + //create connection and load model conn = new Connection(); - - // Load the LLM model via Python worker - LLMCallback llmWorker = conn.loadModel("distilgpt2"); + LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/systemds/llm_worker.py"); Assert.assertNotNull("LLM worker should not be null", llmWorker); - // Create a PreparedScript with a dummy script - String dummyScript = "x = 1;\nwrite(x, './tmp/x');"; - PreparedScript ps = conn.prepareScript(dummyScript, new String[]{}, new String[]{"x"}); - - // Set the LLM worker on the PreparedScript + //create prepared script and set llm worker + String script = "x = 1;\nwrite(x, './tmp/x');"; + PreparedScript ps = conn.prepareScript(script, new String[]{}, new String[]{"x"}); ps.setLLMWorker(llmWorker); - // Generate text using the LLM + //generate text using llm String prompt = "The meaning of life is"; String result = ps.generate(prompt, 20, 0.7, 0.9); - // Assert the result is not null and not empty + //verify result Assert.assertNotNull("Generated text should not be null", result); Assert.assertFalse("Generated text should not be empty", result.isEmpty()); @@ -70,7 +66,7 @@ public void testLLMInference() { System.out.println("Generated: " + result); } catch (Exception e) { - // Skip test if Python/transformers not available + //skip test if dependencies not available System.out.println("Skipping LLM test:"); e.printStackTrace(); org.junit.Assume.assumeNoException("LLM dependencies not available", e); From 672a3faa2bc52b0b5687b9c510e0e8dad8d12bce Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:20:51 +0100 Subject: [PATCH 3/8] Add dynamic port allocation and improve resource cleanup - Connection.java: Auto-find available ports for Py4J communication - Connection.java: Add loadModel() overload for manual port override - Connection.java: Use destroyForcibly() with waitFor() for clean shutdown - llm_worker.py: Accept python_port as command line argument --- 
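
Note on the port selection below (sketch of findAvailablePort() in
Connection.java): binding port 0 lets the OS assign a free ephemeral port,
which is read back before the socket closes; a small race window remains
until the gateway rebinds it.

    try (ServerSocket s = new ServerSocket(0)) { return s.getLocalPort(); }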
.../org/apache/sysds/api/jmlc/Connection.java | 56 ++++++++++++++++--- src/main/python/systemds/llm_worker.py | 11 ++-- 2 files changed, 52 insertions(+), 15 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index 2ec7754d298..13dcf2a0247 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -301,13 +301,30 @@ public PreparedScript prepareScript(String script, Map nsscripts, /** * Loads a HuggingFace model via Python worker for LLM inference. - * Starts a Python subprocess and connects via Py4J. + * Uses auto-detected available ports for Py4J communication. * - * @param modelName HuggingFace model name (e.g., "distilgpt2") - * @param workerScriptPath path to the Python worker script (llm_worker.py) + * @param modelName HuggingFace model name + * @param workerScriptPath path to the Python worker script * @return LLMCallback interface to the Python worker */ public LLMCallback loadModel(String modelName, String workerScriptPath) { + //auto-find available ports + int javaPort = findAvailablePort(); + int pythonPort = findAvailablePort(); + return loadModel(modelName, workerScriptPath, javaPort, pythonPort); + } + + /** + * Loads a HuggingFace model via Python worker for LLM inference. + * Starts a Python subprocess and connects via Py4J. + * + * @param modelName HuggingFace model name + * @param workerScriptPath path to the Python worker script + * @param javaPort port for Java gateway server + * @param pythonPort port for Python callback server + * @return LLMCallback interface to the Python worker + */ + public LLMCallback loadModel(String modelName, String workerScriptPath, int javaPort, int pythonPort) { if (_llmWorker != null) return _llmWorker; try { @@ -317,18 +334,20 @@ public LLMCallback loadModel(String modelName, String workerScriptPath) { //start Py4J gateway server with callback support _gatewayServer = new GatewayServer.GatewayServerBuilder() .entryPoint(this) - .javaPort(25333) - .callbackClient(25334, java.net.InetAddress.getLoopbackAddress()) + .javaPort(javaPort) + .callbackClient(pythonPort, java.net.InetAddress.getLoopbackAddress()) .build(); _gatewayServer.start(); //give gateway time to start Thread.sleep(500); - //start python worker process - LOG.info("Starting LLM worker with script: " + workerScriptPath); + //start python worker process with both ports + LOG.info("Starting LLM worker with script: " + workerScriptPath + + " (javaPort=" + javaPort + ", pythonPort=" + pythonPort + ")"); _pythonProcess = new ProcessBuilder( - "python", workerScriptPath, modelName, "25333" + "python", workerScriptPath, modelName, + String.valueOf(javaPort), String.valueOf(pythonPort) ).redirectErrorStream(true).start(); //read python output in background thread @@ -357,6 +376,19 @@ public LLMCallback loadModel(String modelName, String workerScriptPath) { return _llmWorker; } + /** + * Finds an available port on the local machine. + * @return available port number + */ + private int findAvailablePort() { + try (java.net.ServerSocket socket = new java.net.ServerSocket(0)) { + socket.setReuseAddress(true); + return socket.getLocalPort(); + } catch (IOException e) { + throw new DMLException("Failed to find available port: " + e.getMessage()); + } + } + /** * Called by Python worker to register itself via Py4J. 
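 * Invoked from the worker via gateway.entry_point.registerWorker(worker);
 * counts down the registration latch so loadModel() can return.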
*/ @@ -377,13 +409,19 @@ public void close() { //shutdown LLM worker if running if (_pythonProcess != null) { - _pythonProcess.destroy(); + _pythonProcess.destroyForcibly(); + try { + _pythonProcess.waitFor(5, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } _pythonProcess = null; } if (_gatewayServer != null) { _gatewayServer.shutdown(); _gatewayServer = null; } + _llmWorker = null; //clear thread-local configurations ConfigurationManager.clearLocalConfigs(); diff --git a/src/main/python/systemds/llm_worker.py b/src/main/python/systemds/llm_worker.py index 57872b37f82..9b5dd4a9155 100644 --- a/src/main/python/systemds/llm_worker.py +++ b/src/main/python/systemds/llm_worker.py @@ -34,20 +34,19 @@ class Java: if __name__ == "__main__": model_name = sys.argv[1] if len(sys.argv) > 1 else "distilgpt2" java_port = int(sys.argv[2]) if len(sys.argv) > 2 else 25333 + python_port = int(sys.argv[3]) if len(sys.argv) > 3 else 25334 - print(f"Starting LLM worker, connecting to Java on port {java_port}", flush=True) + print(f"Starting LLM worker (javaPort={java_port}, pythonPort={python_port})", flush=True) worker = LLMWorker(model_name) - # Connect to Java's GatewayServer and register this worker - # The callback_server starts a server on Python's side for Java to call back - # Use port 25334 which Java's CallbackClient expects + #connect to Java's GatewayServer and register this worker gateway = JavaGateway( gateway_parameters=GatewayParameters(port=java_port), - callback_server_parameters=CallbackServerParameters(port=25334) + callback_server_parameters=CallbackServerParameters(port=python_port) ) - print(f"Python callback server started on port 25334", flush=True) + print(f"Python callback server started on port {python_port}", flush=True) gateway.entry_point.registerWorker(worker) print("Worker registered with Java, waiting for requests...", flush=True) From dacdc1c1cae94d55a1c7fd7c6c525b88cdeb861d Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Fri, 13 Feb 2026 22:53:53 +0100 Subject: [PATCH 4/8] Move llm_worker.py to fix Python module collision Move worker script from src/main/python/systemds/ to src/main/python/ to avoid shadowing Python stdlib operator module. 
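
Background: running `python src/main/python/systemds/llm_worker.py` prepends
the script's directory to sys.path, so the `operator` package of the systemds
Python API shadowed the stdlib `operator` module that torch depends on.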
--- src/main/java/org/apache/sysds/api/jmlc/Connection.java | 2 +- src/main/python/{systemds => }/llm_worker.py | 0 .../apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename src/main/python/{systemds => }/llm_worker.py (100%) diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index 13dcf2a0247..9dcc15937a0 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -346,7 +346,7 @@ public LLMCallback loadModel(String modelName, String workerScriptPath, int java LOG.info("Starting LLM worker with script: " + workerScriptPath + " (javaPort=" + javaPort + ", pythonPort=" + pythonPort + ")"); _pythonProcess = new ProcessBuilder( - "python", workerScriptPath, modelName, + "python3", workerScriptPath, modelName, String.valueOf(javaPort), String.valueOf(pythonPort) ).redirectErrorStream(true).start(); diff --git a/src/main/python/systemds/llm_worker.py b/src/main/python/llm_worker.py similarity index 100% rename from src/main/python/systemds/llm_worker.py rename to src/main/python/llm_worker.py diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java index 5909b0cef4c..ac3f3e7e069 100644 --- a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -46,7 +46,7 @@ public void testLLMInference() { try { //create connection and load model conn = new Connection(); - LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/systemds/llm_worker.py"); + LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/llm_worker.py"); Assert.assertNotNull("LLM worker should not be null", llmWorker); //create prepared script and set llm worker From 29f657c2a55ad3d29c02b6f75ff19c9a2a0b1e9e Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Sat, 14 Feb 2026 16:07:39 +0100 Subject: [PATCH 5/8] Use python3 with fallback to python in Connection.java --- .../org/apache/sysds/api/jmlc/Connection.java | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index 9dcc15937a0..da0cb0be1d2 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -343,10 +343,11 @@ public LLMCallback loadModel(String modelName, String workerScriptPath, int java Thread.sleep(500); //start python worker process with both ports + String pythonCmd = findPythonCommand(); LOG.info("Starting LLM worker with script: " + workerScriptPath + - " (javaPort=" + javaPort + ", pythonPort=" + pythonPort + ")"); + " (python=" + pythonCmd + ", javaPort=" + javaPort + ", pythonPort=" + pythonPort + ")"); _pythonProcess = new ProcessBuilder( - "python3", workerScriptPath, modelName, + pythonCmd, workerScriptPath, modelName, String.valueOf(javaPort), String.valueOf(pythonPort) ).redirectErrorStream(true).start(); @@ -376,6 +377,25 @@ public LLMCallback loadModel(String modelName, String workerScriptPath, int java return _llmWorker; } + /** + * Finds the available Python command, trying python3 first then python. 
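+	 * Detection runs "<cmd> --version" and returns the first candidate that
+	 * exits with status 0.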
+ * @return python command name + */ + private static String findPythonCommand() { + for (String cmd : new String[]{"python3", "python"}) { + try { + Process p = new ProcessBuilder(cmd, "--version") + .redirectErrorStream(true).start(); + int exitCode = p.waitFor(); + if (exitCode == 0) + return cmd; + } catch (Exception e) { + //command not found, try next + } + } + throw new DMLException("No Python installation found (tried python3, python)"); + } + /** * Finds an available port on the local machine. * @return available port number From e40e4f232035643aee11f6154bf3211e0416e35f Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Sat, 14 Feb 2026 17:04:10 +0100 Subject: [PATCH 6/8] Add batch inference with FrameBlock and metrics support --- .../apache/sysds/api/jmlc/PreparedScript.java | 62 ++++++++++++ .../functions/jmlc/JMLCLLMInferenceTest.java | 95 +++++++++++++++++++ 2 files changed, 157 insertions(+) diff --git a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java index 2e6109d0102..f8a61cfaf20 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java +++ b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java @@ -198,6 +198,68 @@ public String generate(String prompt, int maxNewTokens, double temperature, doub return _llmWorker.generate(prompt, maxNewTokens, temperature, topP); } + /** + * Generates text for multiple prompts and returns results as a FrameBlock. + * The FrameBlock has two columns: [prompt, generated_text]. + * + * @param prompts array of input prompt texts + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature + * @param topP nucleus sampling probability threshold + * @return FrameBlock with columns [prompt, generated_text] + */ + public FrameBlock generateBatch(String[] prompts, int maxNewTokens, double temperature, double topP) { + if (_llmWorker == null) { + throw new DMLException("No LLM worker set. Call setLLMWorker() first."); + } + //generate text for each prompt + String[][] data = new String[prompts.length][2]; + for (int i = 0; i < prompts.length; i++) { + data[i][0] = prompts[i]; + data[i][1] = _llmWorker.generate(prompts[i], maxNewTokens, temperature, topP); + } + //create FrameBlock with string schema + ValueType[] schema = new ValueType[]{ValueType.STRING, ValueType.STRING}; + String[] colNames = new String[]{"prompt", "generated_text"}; + FrameBlock fb = new FrameBlock(schema, colNames); + for (String[] row : data) + fb.appendRow(row); + return fb; + } + + /** + * Generates text for multiple prompts and returns results with timing metrics. + * The FrameBlock has three columns: [prompt, generated_text, time_ms]. + * + * @param prompts array of input prompt texts + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature + * @param topP nucleus sampling probability threshold + * @return FrameBlock with columns [prompt, generated_text, time_ms] + */ + public FrameBlock generateBatchWithMetrics(String[] prompts, int maxNewTokens, double temperature, double topP) { + if (_llmWorker == null) { + throw new DMLException("No LLM worker set. 
Call setLLMWorker() first."); + } + //generate text for each prompt with timing + String[][] data = new String[prompts.length][3]; + for (int i = 0; i < prompts.length; i++) { + long start = System.nanoTime(); + String result = _llmWorker.generate(prompts[i], maxNewTokens, temperature, topP); + long elapsed = (System.nanoTime() - start) / 1_000_000; + data[i][0] = prompts[i]; + data[i][1] = result; + data[i][2] = String.valueOf(elapsed); + } + //create FrameBlock with schema + ValueType[] schema = new ValueType[]{ValueType.STRING, ValueType.STRING, ValueType.INT64}; + String[] colNames = new String[]{"prompt", "generated_text", "time_ms"}; + FrameBlock fb = new FrameBlock(schema, colNames); + for (String[] row : data) + fb.appendRow(row); + return fb; + } + /** * Binds a scalar boolean to a registered input variable. * diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java index ac3f3e7e069..0fab3134703 100644 --- a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -22,6 +22,7 @@ import org.apache.sysds.api.jmlc.Connection; import org.apache.sysds.api.jmlc.LLMCallback; import org.apache.sysds.api.jmlc.PreparedScript; +import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.test.AutomatedTestBase; import org.junit.Assert; import org.junit.Test; @@ -76,4 +77,98 @@ public void testLLMInference() { } } } + + @Test + public void testBatchInference() { + Connection conn = null; + try { + //create connection and load model + conn = new Connection(); + LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/llm_worker.py"); + + //create prepared script and set llm worker + String script = "x = 1;\nwrite(x, './tmp/x');"; + PreparedScript ps = conn.prepareScript(script, new String[]{}, new String[]{"x"}); + ps.setLLMWorker(llmWorker); + + //batch generate with multiple prompts + String[] prompts = { + "The meaning of life is", + "Machine learning is", + "Apache SystemDS enables" + }; + FrameBlock result = ps.generateBatch(prompts, 20, 0.7, 0.9); + + //verify FrameBlock structure + Assert.assertNotNull("Batch result should not be null", result); + Assert.assertEquals("Should have 3 rows", 3, result.getNumRows()); + Assert.assertEquals("Should have 2 columns", 2, result.getNumColumns()); + + //verify each row has prompt and generated text + for (int i = 0; i < prompts.length; i++) { + String prompt = (String) result.get(i, 0); + String generated = (String) result.get(i, 1); + Assert.assertEquals("Prompt should match", prompts[i], prompt); + Assert.assertNotNull("Generated text should not be null", generated); + Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); + System.out.println("Prompt: " + prompt); + System.out.println("Generated: " + generated); + } + + } catch (Exception e) { + System.out.println("Skipping batch LLM test:"); + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM dependencies not available", e); + } finally { + if (conn != null) { + conn.close(); + } + } + } + + @Test + public void testBatchWithMetrics() { + Connection conn = null; + try { + //create connection and load model + conn = new Connection(); + LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/llm_worker.py"); + + //create prepared script and set llm worker + String script = "x = 1;\nwrite(x, 
'./tmp/x');"; + PreparedScript ps = conn.prepareScript(script, new String[]{}, new String[]{"x"}); + ps.setLLMWorker(llmWorker); + + //batch generate with metrics + String[] prompts = {"The meaning of life is", "Data science is"}; + FrameBlock result = ps.generateBatchWithMetrics(prompts, 20, 0.7, 0.9); + + //verify FrameBlock structure with metrics + Assert.assertNotNull("Metrics result should not be null", result); + Assert.assertEquals("Should have 2 rows", 2, result.getNumRows()); + Assert.assertEquals("Should have 3 columns", 3, result.getNumColumns()); + + //verify metrics column contains timing data + for (int i = 0; i < prompts.length; i++) { + String prompt = (String) result.get(i, 0); + String generated = (String) result.get(i, 1); + long timeMs = Long.parseLong(result.get(i, 2).toString()); + Assert.assertEquals("Prompt should match", prompts[i], prompt); + Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); + Assert.assertTrue("Time should be positive", timeMs > 0); + System.out.println("Prompt: " + prompt); + System.out.println("Generated: " + generated); + System.out.println("Time: " + timeMs + "ms"); + } + + } catch (Exception e) { + System.out.println("Skipping metrics LLM test:"); + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM dependencies not available", e); + } finally { + if (conn != null) { + conn.close(); + } + } + } } From fdd16849a024685bbbc6b0a98f9836b57a901555 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Sat, 14 Feb 2026 17:41:57 +0100 Subject: [PATCH 7/8] Clean up test: extract constants and shared setup method --- .../functions/jmlc/JMLCLLMInferenceTest.java | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java index 0fab3134703..e54d606c16d 100644 --- a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -34,6 +34,9 @@ public class JMLCLLMInferenceTest extends AutomatedTestBase { private final static String TEST_NAME = "JMLCLLMInferenceTest"; private final static String TEST_DIR = "functions/jmlc/"; + private final static String MODEL_NAME = "distilgpt2"; + private final static String WORKER_SCRIPT = "src/main/python/llm_worker.py"; + private final static String DML_SCRIPT = "x = 1;\nwrite(x, './tmp/x');"; @Override public void setUp() { @@ -41,19 +44,24 @@ public void setUp() { getAndLoadTestConfiguration(TEST_NAME); } + /** + * Creates a connection, loads the LLM model, and returns a PreparedScript + * with the LLM worker attached. 
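+	 * The DML script itself is only a placeholder; generate() calls are routed
+	 * to the attached Python worker rather than executed as DML.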
+ */ + private PreparedScript createLLMScript(Connection conn) throws Exception { + LLMCallback llmWorker = conn.loadModel(MODEL_NAME, WORKER_SCRIPT); + Assert.assertNotNull("LLM worker should not be null", llmWorker); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, new String[]{}, new String[]{"x"}); + ps.setLLMWorker(llmWorker); + return ps; + } + @Test public void testLLMInference() { Connection conn = null; try { - //create connection and load model conn = new Connection(); - LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/llm_worker.py"); - Assert.assertNotNull("LLM worker should not be null", llmWorker); - - //create prepared script and set llm worker - String script = "x = 1;\nwrite(x, './tmp/x');"; - PreparedScript ps = conn.prepareScript(script, new String[]{}, new String[]{"x"}); - ps.setLLMWorker(llmWorker); + PreparedScript ps = createLLMScript(conn); //generate text using llm String prompt = "The meaning of life is"; @@ -67,14 +75,12 @@ public void testLLMInference() { System.out.println("Generated: " + result); } catch (Exception e) { - //skip test if dependencies not available System.out.println("Skipping LLM test:"); e.printStackTrace(); org.junit.Assume.assumeNoException("LLM dependencies not available", e); } finally { - if (conn != null) { + if (conn != null) conn.close(); - } } } @@ -82,14 +88,8 @@ public void testLLMInference() { public void testBatchInference() { Connection conn = null; try { - //create connection and load model conn = new Connection(); - LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/llm_worker.py"); - - //create prepared script and set llm worker - String script = "x = 1;\nwrite(x, './tmp/x');"; - PreparedScript ps = conn.prepareScript(script, new String[]{}, new String[]{"x"}); - ps.setLLMWorker(llmWorker); + PreparedScript ps = createLLMScript(conn); //batch generate with multiple prompts String[] prompts = { @@ -120,9 +120,8 @@ public void testBatchInference() { e.printStackTrace(); org.junit.Assume.assumeNoException("LLM dependencies not available", e); } finally { - if (conn != null) { + if (conn != null) conn.close(); - } } } @@ -130,14 +129,8 @@ public void testBatchInference() { public void testBatchWithMetrics() { Connection conn = null; try { - //create connection and load model conn = new Connection(); - LLMCallback llmWorker = conn.loadModel("distilgpt2", "src/main/python/llm_worker.py"); - - //create prepared script and set llm worker - String script = "x = 1;\nwrite(x, './tmp/x');"; - PreparedScript ps = conn.prepareScript(script, new String[]{}, new String[]{"x"}); - ps.setLLMWorker(llmWorker); + PreparedScript ps = createLLMScript(conn); //batch generate with metrics String[] prompts = {"The meaning of life is", "Data science is"}; @@ -166,9 +159,8 @@ public void testBatchWithMetrics() { e.printStackTrace(); org.junit.Assume.assumeNoException("LLM dependencies not available", e); } finally { - if (conn != null) { + if (conn != null) conn.close(); - } } } } From b9ba3e05cd2a5c4076b23b7fc0a433158879e3d2 Mon Sep 17 00:00:00 2001 From: Kubra Aksu <67320250+kubraaksux@users.noreply.github.com> Date: Sat, 14 Feb 2026 19:02:19 +0100 Subject: [PATCH 8/8] Add token counts, GPU support, and improve error handling - Add generateWithTokenCount() returning JSON with input/output token counts - Update generateBatchWithMetrics() to include input_tokens and output_tokens columns - Add CUDA auto-detection with device_map=auto for multi-GPU support in llm_worker.py - Check Python process liveness 
during startup instead of blind 60s timeout --- .../org/apache/sysds/api/jmlc/Connection.java | 20 ++++++++--- .../apache/sysds/api/jmlc/LLMCallback.java | 12 +++++++ .../apache/sysds/api/jmlc/PreparedScript.java | 29 +++++++++------ src/main/python/llm_worker.py | 35 +++++++++++++++++-- .../functions/jmlc/JMLCLLMInferenceTest.java | 11 ++++-- 5 files changed, 86 insertions(+), 21 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index da0cb0be1d2..21b98c8563e 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -352,7 +352,7 @@ public LLMCallback loadModel(String modelName, String workerScriptPath, int java ).redirectErrorStream(true).start(); //read python output in background thread - new Thread(() -> { + Thread outputReader = new Thread(() -> { try (BufferedReader reader = new BufferedReader( new InputStreamReader(_pythonProcess.getInputStream()))) { String line; @@ -362,11 +362,21 @@ public LLMCallback loadModel(String modelName, String workerScriptPath, int java } catch (IOException e) { LOG.error("Error reading LLM worker output", e); } - }).start(); + }); + outputReader.setName("llm-worker-output"); + outputReader.setDaemon(true); + outputReader.start(); - //wait for worker to register - if (!_workerLatch.await(60, TimeUnit.SECONDS)) { - throw new DMLException("Timeout waiting for LLM worker to register"); + //wait for worker to register, checking process liveness periodically + long deadlineNs = System.nanoTime() + TimeUnit.SECONDS.toNanos(60); + while (!_workerLatch.await(2, TimeUnit.SECONDS)) { + if (!_pythonProcess.isAlive()) { + int exitCode = _pythonProcess.exitValue(); + throw new DMLException("LLM worker process died during startup (exit code " + exitCode + ")"); + } + if (System.nanoTime() > deadlineNs) { + throw new DMLException("Timeout waiting for LLM worker to register (60s)"); + } } } catch (DMLException e) { diff --git a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java index 09ee8debb29..10d9787d992 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java +++ b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java @@ -16,4 +16,16 @@ public interface LLMCallback { * @return generated text continuation */ String generate(String prompt, int maxNewTokens, double temperature, double topP); + + /** + * Generates text and returns result with token counts as a JSON string. 
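+	 * The payload is produced on the Python side via json.dumps and parsed in
+	 * PreparedScript.generateBatchWithMetrics() using org.apache.wink.json4j.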
+ * Format: {"text": "...", "input_tokens": N, "output_tokens": M} + * + * @param prompt the input prompt text + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature (0.0 = deterministic, higher = more random) + * @param topP nucleus sampling probability threshold + * @return JSON string with generated text and token counts + */ + String generateWithTokenCount(String prompt, int maxNewTokens, double temperature, double topP); } \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java index f8a61cfaf20..e664f04ca04 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java +++ b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java @@ -229,31 +229,40 @@ public FrameBlock generateBatch(String[] prompts, int maxNewTokens, double tempe /** * Generates text for multiple prompts and returns results with timing metrics. - * The FrameBlock has three columns: [prompt, generated_text, time_ms]. + * The FrameBlock has five columns: [prompt, generated_text, time_ms, input_tokens, output_tokens]. * * @param prompts array of input prompt texts * @param maxNewTokens maximum number of new tokens to generate * @param temperature sampling temperature * @param topP nucleus sampling probability threshold - * @return FrameBlock with columns [prompt, generated_text, time_ms] + * @return FrameBlock with columns [prompt, generated_text, time_ms, input_tokens, output_tokens] */ public FrameBlock generateBatchWithMetrics(String[] prompts, int maxNewTokens, double temperature, double topP) { if (_llmWorker == null) { throw new DMLException("No LLM worker set. Call setLLMWorker() first."); } - //generate text for each prompt with timing - String[][] data = new String[prompts.length][3]; + //generate text for each prompt with timing and token counts + String[][] data = new String[prompts.length][5]; for (int i = 0; i < prompts.length; i++) { long start = System.nanoTime(); - String result = _llmWorker.generate(prompts[i], maxNewTokens, temperature, topP); + String json = _llmWorker.generateWithTokenCount(prompts[i], maxNewTokens, temperature, topP); long elapsed = (System.nanoTime() - start) / 1_000_000; - data[i][0] = prompts[i]; - data[i][1] = result; - data[i][2] = String.valueOf(elapsed); + //parse JSON response: {"text": "...", "input_tokens": N, "output_tokens": M} + try { + org.apache.wink.json4j.JSONObject obj = new org.apache.wink.json4j.JSONObject(json); + data[i][0] = prompts[i]; + data[i][1] = obj.getString("text"); + data[i][2] = String.valueOf(elapsed); + data[i][3] = String.valueOf(obj.getInt("input_tokens")); + data[i][4] = String.valueOf(obj.getInt("output_tokens")); + } catch (Exception e) { + throw new DMLException("Failed to parse LLM worker response: " + e.getMessage()); + } } //create FrameBlock with schema - ValueType[] schema = new ValueType[]{ValueType.STRING, ValueType.STRING, ValueType.INT64}; - String[] colNames = new String[]{"prompt", "generated_text", "time_ms"}; + ValueType[] schema = new ValueType[]{ + ValueType.STRING, ValueType.STRING, ValueType.INT64, ValueType.INT64, ValueType.INT64}; + String[] colNames = new String[]{"prompt", "generated_text", "time_ms", "input_tokens", "output_tokens"}; FrameBlock fb = new FrameBlock(schema, colNames); for (String[] row : data) fb.appendRow(row); diff --git a/src/main/python/llm_worker.py b/src/main/python/llm_worker.py index 9b5dd4a9155..27160e27f13 100644 --- 
a/src/main/python/llm_worker.py +++ b/src/main/python/llm_worker.py @@ -3,6 +3,7 @@ Java starts this script, then calls generate() via Py4J. """ import sys +import json import torch from transformers import AutoTokenizer, AutoModelForCausalLM from py4j.java_gateway import JavaGateway, GatewayParameters, CallbackServerParameters @@ -11,12 +12,20 @@ class LLMWorker: def __init__(self, model_name="distilgpt2"): print(f"Loading model: {model_name}", flush=True) self.tokenizer = AutoTokenizer.from_pretrained(model_name) - self.model = AutoModelForCausalLM.from_pretrained(model_name) + #auto-detect GPU and load model accordingly + if torch.cuda.is_available(): + print(f"CUDA available: {torch.cuda.device_count()} GPU(s)", flush=True) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, device_map="auto", torch_dtype=torch.float16) + self.device = "cuda" + else: + self.model = AutoModelForCausalLM.from_pretrained(model_name) + self.device = "cpu" self.model.eval() - print(f"Model loaded: {model_name}", flush=True) + print(f"Model loaded: {model_name} (device={self.device})", flush=True) def generate(self, prompt, max_new_tokens=50, temperature=0.7, top_p=0.9): - inputs = self.tokenizer(prompt, return_tensors="pt") + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) with torch.no_grad(): outputs = self.model.generate( **inputs, @@ -28,6 +37,26 @@ def generate(self, prompt, max_new_tokens=50, temperature=0.7, top_p=0.9): new_tokens = outputs[0][inputs["input_ids"].shape[1]:] return self.tokenizer.decode(new_tokens, skip_special_tokens=True) + def generateWithTokenCount(self, prompt, max_new_tokens=50, temperature=0.7, top_p=0.9): + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) + input_token_count = inputs["input_ids"].shape[1] + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=int(max_new_tokens), + temperature=float(temperature), + top_p=float(top_p), + do_sample=float(temperature) > 0.0 + ) + new_tokens = outputs[0][input_token_count:] + output_token_count = len(new_tokens) + text = self.tokenizer.decode(new_tokens, skip_special_tokens=True) + return json.dumps({ + "text": text, + "input_tokens": input_token_count, + "output_tokens": output_token_count + }) + class Java: implements = ["org.apache.sysds.api.jmlc.LLMCallback"] diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java index e54d606c16d..0e07e07d221 100644 --- a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -136,22 +136,27 @@ public void testBatchWithMetrics() { String[] prompts = {"The meaning of life is", "Data science is"}; FrameBlock result = ps.generateBatchWithMetrics(prompts, 20, 0.7, 0.9); - //verify FrameBlock structure with metrics + //verify FrameBlock structure with metrics and token counts Assert.assertNotNull("Metrics result should not be null", result); Assert.assertEquals("Should have 2 rows", 2, result.getNumRows()); - Assert.assertEquals("Should have 3 columns", 3, result.getNumColumns()); + Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); - //verify metrics column contains timing data + //verify metrics columns contain timing and token data for (int i = 0; i < prompts.length; i++) { String prompt = (String) result.get(i, 0); String generated = (String) result.get(i, 
1); long timeMs = Long.parseLong(result.get(i, 2).toString()); + long inputTokens = Long.parseLong(result.get(i, 3).toString()); + long outputTokens = Long.parseLong(result.get(i, 4).toString()); Assert.assertEquals("Prompt should match", prompts[i], prompt); Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); Assert.assertTrue("Time should be positive", timeMs > 0); + Assert.assertTrue("Input tokens should be positive", inputTokens > 0); + Assert.assertTrue("Output tokens should be positive", outputTokens > 0); System.out.println("Prompt: " + prompt); System.out.println("Generated: " + generated); System.out.println("Time: " + timeMs + "ms"); + System.out.println("Tokens: " + inputTokens + " in, " + outputTokens + " out"); } } catch (Exception e) {