159 changes: 159 additions & 0 deletions src/main/java/org/apache/sysds/api/jmlc/Connection.java
@@ -28,7 +28,11 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.sysds.hops.OptimizerUtils;
@@ -66,6 +70,7 @@
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.DataConverter;
import py4j.GatewayServer;

/**
* Interaction with SystemDS using the JMLC (Java Machine Learning Connector) API is initiated with
@@ -91,6 +96,12 @@ public class Connection implements Closeable
private final DMLConfig _dmlconf;
private final CompilerConfig _cconf;
private static FileSystem fs = null;
private Process _pythonProcess = null;
private GatewayServer _gatewayServer = null;
private LLMCallback _llmWorker = null;
private CountDownLatch _workerLatch = null;

private static final Log LOG = LogFactory.getLog(Connection.class.getName());

/**
* Connection constructor, the starting point for any other JMLC API calls.
@@ -287,13 +298,161 @@ public PreparedScript prepareScript(String script, Map<String,String> nsscripts,
//return newly created precompiled script
return new PreparedScript(rtprog, inputs, outputs, _dmlconf, _cconf);
}

/**
* Loads a HuggingFace model via a Python worker for LLM inference.
* Automatically selects free ports for the Py4J communication.
*
* @param modelName HuggingFace model name
* @param workerScriptPath path to the Python worker script
* @return LLMCallback interface to the Python worker
*/
public LLMCallback loadModel(String modelName, String workerScriptPath) {
//auto-find available ports
int javaPort = findAvailablePort();
int pythonPort = findAvailablePort();
//guard against the OS handing back the same ephemeral port twice
while (pythonPort == javaPort)
pythonPort = findAvailablePort();
return loadModel(modelName, workerScriptPath, javaPort, pythonPort);
}

/**
* Loads a HuggingFace model via a Python worker for LLM inference.
* Starts a Python subprocess and connects to it via Py4J.
*
* @param modelName HuggingFace model name
* @param workerScriptPath path to the Python worker script
* @param javaPort port for Java gateway server
* @param pythonPort port for Python callback server
* @return LLMCallback interface to the Python worker
*/
public LLMCallback loadModel(String modelName, String workerScriptPath, int javaPort, int pythonPort) {
if (_llmWorker != null)
return _llmWorker; //reuse the already-registered worker (modelName is ignored)
try {
//initialize latch for worker registration
_workerLatch = new CountDownLatch(1);

//start Py4J gateway server with callback support
_gatewayServer = new GatewayServer.GatewayServerBuilder()
.entryPoint(this)
.javaPort(javaPort)
.callbackClient(pythonPort, java.net.InetAddress.getLoopbackAddress())
.build();
_gatewayServer.start();

//give gateway time to start
Thread.sleep(500);

//start python worker process with both ports
String pythonCmd = findPythonCommand();
LOG.info("Starting LLM worker with script: " + workerScriptPath +
" (python=" + pythonCmd + ", javaPort=" + javaPort + ", pythonPort=" + pythonPort + ")");
_pythonProcess = new ProcessBuilder(
pythonCmd, workerScriptPath, modelName,
String.valueOf(javaPort), String.valueOf(pythonPort)
).redirectErrorStream(true).start();

//read python output in background thread
Thread outputReader = new Thread(() -> {
try (BufferedReader reader = new BufferedReader(
new InputStreamReader(_pythonProcess.getInputStream()))) {
String line;
while ((line = reader.readLine()) != null) {
LOG.info("[LLM Worker] " + line);
}
} catch (IOException e) {
LOG.error("Error reading LLM worker output", e);
}
});
outputReader.setName("llm-worker-output");
outputReader.setDaemon(true);
outputReader.start();

//wait for worker to register, checking process liveness periodically
long deadlineNs = System.nanoTime() + TimeUnit.SECONDS.toNanos(60);
while (!_workerLatch.await(2, TimeUnit.SECONDS)) {
if (!_pythonProcess.isAlive()) {
int exitCode = _pythonProcess.exitValue();
throw new DMLException("LLM worker process died during startup (exit code " + exitCode + ")");
}
if (System.nanoTime() > deadlineNs) {
throw new DMLException("Timeout waiting for LLM worker to register (60s)");
}
}

} catch (DMLException e) {
throw e;
} catch (Exception e) {
throw new DMLException("Failed to start LLM worker: " + e.getMessage(), e);
}
return _llmWorker;
}
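
// Usage sketch (not part of this change): how a client would start a worker
// and generate text; the model name and worker script path below are
// hypothetical placeholders.
//
//   try (Connection conn = new Connection()) {
//     LLMCallback worker = conn.loadModel("gpt2", "scripts/llm_worker.py");
//     String text = worker.generate("Once upon a time", 64, 0.7, 0.9);
//   }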

/**
* Finds an available Python command, trying python3 first, then python.
* @return python command name
*/
private static String findPythonCommand() {
for (String cmd : new String[]{"python3", "python"}) {
try {
Process p = new ProcessBuilder(cmd, "--version")
.redirectErrorStream(true).start();
if (p.waitFor() == 0)
return cmd;
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new DMLException("Interrupted while probing for a Python installation", e);
}
catch (Exception e) {
//command not found or failed to run, try the next candidate
}
}
throw new DMLException("No Python installation found (tried python3, python)");
}

/**
* Finds an available port by binding an ephemeral server socket.
* Note: the port could in principle be taken again between close and reuse.
* @return available port number
*/
private int findAvailablePort() {
//new ServerSocket(0) binds an OS-chosen free port in the constructor;
//setReuseAddress would only take effect if set before binding, so it is omitted
try (java.net.ServerSocket socket = new java.net.ServerSocket(0)) {
return socket.getLocalPort();
} catch (IOException e) {
throw new DMLException("Failed to find available port: " + e.getMessage(), e);
}
}

/**
* Called by the Python worker (via the Py4J entry point) to register its
* callback once the model is loaded; releases the latch that
* {@code loadModel} blocks on.
*
* @param worker the worker's LLMCallback implementation
*/
public void registerWorker(LLMCallback worker) {
_llmWorker = worker;
if (_workerLatch != null) {
_workerLatch.countDown();
}
LOG.info("LLM worker registered successfully");
}

/**
* Close connection to SystemDS: shuts down the LLM worker and Py4J gateway
* (if running) and clears the thread-local DML and compiler configurations.
*/
@Override
public void close() {

//shutdown LLM worker if running
if (_pythonProcess != null) {
_pythonProcess.destroyForcibly();
try {
_pythonProcess.waitFor(5, TimeUnit.SECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
_pythonProcess = null;
}
if (_gatewayServer != null) {
_gatewayServer.shutdown();
_gatewayServer = null;
}
_llmWorker = null;

//clear thread-local configurations
ConfigurationManager.clearLocalConfigs();
if( ConfigurationManager.isCodegenEnabled() )
31 changes: 31 additions & 0 deletions src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java
@@ -0,0 +1,31 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysds.api.jmlc;

/**
* Interface for the Python LLM worker.
* The Python side implements this interface via a Py4J callback.
*/
public interface LLMCallback {

/**
* Generates text using the LLM model.
*
* @param prompt the input prompt text
* @param maxNewTokens maximum number of new tokens to generate
* @param temperature sampling temperature (0.0 = deterministic, higher = more random)
* @param topP nucleus sampling probability threshold
* @return generated text continuation
*/
String generate(String prompt, int maxNewTokens, double temperature, double topP);

/**
* Generates text and returns result with token counts as a JSON string.
* Format: {"text": "...", "input_tokens": N, "output_tokens": M}
*
* @param prompt the input prompt text
* @param maxNewTokens maximum number of new tokens to generate
* @param temperature sampling temperature (0.0 = deterministic, higher = more random)
* @param topP nucleus sampling probability threshold
* @return JSON string with generated text and token counts
*/
String generateWithTokenCount(String prompt, int maxNewTokens, double temperature, double topP);
}
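
For orientation, a minimal Java stub of this interface, handy for unit-testing the JMLC wiring without a Python process. The class name EchoLLMWorker is hypothetical and not part of this change; the JSON shape mirrors the contract documented above, while a real worker would run model inference instead.

package org.apache.sysds.api.jmlc;

//hypothetical test stub, assuming only the LLMCallback contract above
public class EchoLLMWorker implements LLMCallback {
	@Override
	public String generate(String prompt, int maxNewTokens, double temperature, double topP) {
		//echo the prompt back; a real worker would generate a continuation
		return prompt;
	}
	@Override
	public String generateWithTokenCount(String prompt, int maxNewTokens, double temperature, double topP) {
		//fixed counts for testing; a real worker reports tokenizer lengths
		String escaped = prompt.replace("\\", "\\\\").replace("\"", "\\\"");
		return "{\"text\": \"" + escaped + "\", \"input_tokens\": 1, \"output_tokens\": 1}";
	}
}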
109 changes: 109 additions & 0 deletions src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java
@@ -80,6 +80,9 @@ public class PreparedScript implements ConfigurableAPI
private final CompilerConfig _cconf;
private HashMap<String, String> _outVarLineage;

//LLM inference support
private LLMCallback _llmWorker = null;

private PreparedScript(PreparedScript that) {
//shallow copy, except for a separate symbol table
//and related meta data of reused inputs
@@ -160,6 +163,112 @@ public CompilerConfig getCompilerConfig() {
return _cconf;
}

/**
* Sets the LLM worker callback for text generation.
*
* @param worker the LLM callback interface
*/
public void setLLMWorker(LLMCallback worker) {
_llmWorker = worker;
}

/**
* Gets the LLM worker callback.
*
* @return the LLM callback interface, or null if not set
*/
public LLMCallback getLLMWorker() {
return _llmWorker;
}

/**
* Generates text using the LLM worker.
*
* @param prompt the input prompt text
* @param maxNewTokens maximum number of new tokens to generate
* @param temperature sampling temperature (0.0 = deterministic, higher = more random)
* @param topP nucleus sampling probability threshold
* @return generated text
* @throws DMLException if no LLM worker is set
*/
public String generate(String prompt, int maxNewTokens, double temperature, double topP) {
if (_llmWorker == null) {
throw new DMLException("No LLM worker set. Call setLLMWorker() first.");
}
return _llmWorker.generate(prompt, maxNewTokens, temperature, topP);
}

/**
* Generates text for multiple prompts and returns results as a FrameBlock.
* The FrameBlock has two columns: [prompt, generated_text].
*
* @param prompts array of input prompt texts
* @param maxNewTokens maximum number of new tokens to generate
* @param temperature sampling temperature
* @param topP nucleus sampling probability threshold
* @return FrameBlock with columns [prompt, generated_text]
*/
public FrameBlock generateBatch(String[] prompts, int maxNewTokens, double temperature, double topP) {
if (_llmWorker == null) {
throw new DMLException("No LLM worker set. Call setLLMWorker() first.");
}
//generate text for each prompt
String[][] data = new String[prompts.length][2];
for (int i = 0; i < prompts.length; i++) {
data[i][0] = prompts[i];
data[i][1] = _llmWorker.generate(prompts[i], maxNewTokens, temperature, topP);
}
//create FrameBlock with string schema
ValueType[] schema = new ValueType[]{ValueType.STRING, ValueType.STRING};
String[] colNames = new String[]{"prompt", "generated_text"};
FrameBlock fb = new FrameBlock(schema, colNames);
for (String[] row : data)
fb.appendRow(row);
return fb;
}
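
// Shape sketch: for prompts {"Hi", "Bye"} the returned frame has two rows,
// e.g. ["Hi", <continuation of "Hi">] and ["Bye", <continuation of "Bye">].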

/**
* Generates text for multiple prompts and returns results with timing metrics.
* The FrameBlock has five columns: [prompt, generated_text, time_ms, input_tokens, output_tokens].
*
* @param prompts array of input prompt texts
* @param maxNewTokens maximum number of new tokens to generate
* @param temperature sampling temperature
* @param topP nucleus sampling probability threshold
* @return FrameBlock with columns [prompt, generated_text, time_ms, input_tokens, output_tokens]
*/
public FrameBlock generateBatchWithMetrics(String[] prompts, int maxNewTokens, double temperature, double topP) {
if (_llmWorker == null) {
throw new DMLException("No LLM worker set. Call setLLMWorker() first.");
}
//generate text for each prompt with timing and token counts
String[][] data = new String[prompts.length][5];
for (int i = 0; i < prompts.length; i++) {
long start = System.nanoTime();
String json = _llmWorker.generateWithTokenCount(prompts[i], maxNewTokens, temperature, topP);
long elapsed = (System.nanoTime() - start) / 1_000_000;
//parse JSON response: {"text": "...", "input_tokens": N, "output_tokens": M}
try {
org.apache.wink.json4j.JSONObject obj = new org.apache.wink.json4j.JSONObject(json);
data[i][0] = prompts[i];
data[i][1] = obj.getString("text");
data[i][2] = String.valueOf(elapsed);
data[i][3] = String.valueOf(obj.getInt("input_tokens"));
data[i][4] = String.valueOf(obj.getInt("output_tokens"));
} catch (Exception e) {
throw new DMLException("Failed to parse LLM worker response: " + e.getMessage(), e);
}
}
//create FrameBlock with schema
ValueType[] schema = new ValueType[]{
ValueType.STRING, ValueType.STRING, ValueType.INT64, ValueType.INT64, ValueType.INT64};
String[] colNames = new String[]{"prompt", "generated_text", "time_ms", "input_tokens", "output_tokens"};
FrameBlock fb = new FrameBlock(schema, colNames);
for (String[] row : data)
fb.appendRow(row);
return fb;
}

/**
* Binds a scalar boolean to a registered input variable.
*
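
For orientation, a hedged end-to-end sketch of the new API surface. The DML script, model name, and worker script path are hypothetical placeholders, and the FrameBlock import path may differ across SystemDS versions; this is not part of the diff.

import org.apache.sysds.api.jmlc.Connection;
import org.apache.sysds.api.jmlc.LLMCallback;
import org.apache.sysds.api.jmlc.PreparedScript;
import org.apache.sysds.runtime.frame.data.FrameBlock; //path may vary by version

public class LLMBatchExample {
	public static void main(String[] args) throws Exception {
		try (Connection conn = new Connection()) {
			//hypothetical model and worker script
			LLMCallback worker = conn.loadModel("gpt2", "scripts/llm_worker.py");
			PreparedScript ps = conn.prepareScript(
				"print('ready');", new String[]{}, new String[]{});
			ps.setLLMWorker(worker);
			FrameBlock fb = ps.generateBatchWithMetrics(
				new String[]{"Hello,", "The answer is"}, 32, 0.0, 0.9);
			for (int i = 0; i < fb.getNumRows(); i++)
				System.out.println(fb.get(i, 1) + " (" + fb.get(i, 2) + " ms)");
		}
	}
}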