diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java index 525c1a97bb2..21b98c8563e 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java @@ -28,7 +28,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.sysds.hops.OptimizerUtils; @@ -66,6 +70,7 @@ import org.apache.sysds.runtime.transform.meta.TfMetaUtils; import org.apache.sysds.runtime.util.CollectionUtils; import org.apache.sysds.runtime.util.DataConverter; +import py4j.GatewayServer; /** * Interaction with SystemDS using the JMLC (Java Machine Learning Connector) API is initiated with @@ -91,6 +96,12 @@ public class Connection implements Closeable private final DMLConfig _dmlconf; private final CompilerConfig _cconf; private static FileSystem fs = null; + private Process _pythonProcess = null; + private py4j.GatewayServer _gatewayServer = null; + private LLMCallback _llmWorker = null; + private CountDownLatch _workerLatch = null; + + private static final Log LOG = LogFactory.getLog(Connection.class.getName()); /** * Connection constructor, the starting point for any other JMLC API calls. @@ -287,6 +298,137 @@ public PreparedScript prepareScript(String script, Map nsscripts, //return newly create precompiled script return new PreparedScript(rtprog, inputs, outputs, _dmlconf, _cconf); } + + /** + * Loads a HuggingFace model via Python worker for LLM inference. + * Uses auto-detected available ports for Py4J communication. + * + * @param modelName HuggingFace model name + * @param workerScriptPath path to the Python worker script + * @return LLMCallback interface to the Python worker + */ + public LLMCallback loadModel(String modelName, String workerScriptPath) { + //auto-find available ports + int javaPort = findAvailablePort(); + int pythonPort = findAvailablePort(); + return loadModel(modelName, workerScriptPath, javaPort, pythonPort); + } + + /** + * Loads a HuggingFace model via Python worker for LLM inference. + * Starts a Python subprocess and connects via Py4J. 
+ * + * @param modelName HuggingFace model name + * @param workerScriptPath path to the Python worker script + * @param javaPort port for Java gateway server + * @param pythonPort port for Python callback server + * @return LLMCallback interface to the Python worker + */ + public LLMCallback loadModel(String modelName, String workerScriptPath, int javaPort, int pythonPort) { + if (_llmWorker != null) + return _llmWorker; + try { + //initialize latch for worker registration + _workerLatch = new CountDownLatch(1); + + //start Py4J gateway server with callback support + _gatewayServer = new GatewayServer.GatewayServerBuilder() + .entryPoint(this) + .javaPort(javaPort) + .callbackClient(pythonPort, java.net.InetAddress.getLoopbackAddress()) + .build(); + _gatewayServer.start(); + + //give gateway time to start + Thread.sleep(500); + + //start python worker process with both ports + String pythonCmd = findPythonCommand(); + LOG.info("Starting LLM worker with script: " + workerScriptPath + + " (python=" + pythonCmd + ", javaPort=" + javaPort + ", pythonPort=" + pythonPort + ")"); + _pythonProcess = new ProcessBuilder( + pythonCmd, workerScriptPath, modelName, + String.valueOf(javaPort), String.valueOf(pythonPort) + ).redirectErrorStream(true).start(); + + //read python output in background thread + Thread outputReader = new Thread(() -> { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(_pythonProcess.getInputStream()))) { + String line; + while ((line = reader.readLine()) != null) { + LOG.info("[LLM Worker] " + line); + } + } catch (IOException e) { + LOG.error("Error reading LLM worker output", e); + } + }); + outputReader.setName("llm-worker-output"); + outputReader.setDaemon(true); + outputReader.start(); + + //wait for worker to register, checking process liveness periodically + long deadlineNs = System.nanoTime() + TimeUnit.SECONDS.toNanos(60); + while (!_workerLatch.await(2, TimeUnit.SECONDS)) { + if (!_pythonProcess.isAlive()) { + int exitCode = _pythonProcess.exitValue(); + throw new DMLException("LLM worker process died during startup (exit code " + exitCode + ")"); + } + if (System.nanoTime() > deadlineNs) { + throw new DMLException("Timeout waiting for LLM worker to register (60s)"); + } + } + + } catch (DMLException e) { + throw e; + } catch (Exception e) { + throw new DMLException("Failed to start LLM worker: " + e.getMessage()); + } + return _llmWorker; + } + + /** + * Finds the available Python command, trying python3 first then python. + * @return python command name + */ + private static String findPythonCommand() { + for (String cmd : new String[]{"python3", "python"}) { + try { + Process p = new ProcessBuilder(cmd, "--version") + .redirectErrorStream(true).start(); + int exitCode = p.waitFor(); + if (exitCode == 0) + return cmd; + } catch (Exception e) { + //command not found, try next + } + } + throw new DMLException("No Python installation found (tried python3, python)"); + } + + /** + * Finds an available port on the local machine. + * @return available port number + */ + private int findAvailablePort() { + try (java.net.ServerSocket socket = new java.net.ServerSocket(0)) { + socket.setReuseAddress(true); + return socket.getLocalPort(); + } catch (IOException e) { + throw new DMLException("Failed to find available port: " + e.getMessage()); + } + } + + /** + * Called by Python worker to register itself via Py4J. 
+ */ + public void registerWorker(LLMCallback worker) { + _llmWorker = worker; + if (_workerLatch != null) { + _workerLatch.countDown(); + } + LOG.info("LLM worker registered successfully"); + } /** * Close connection to SystemDS, which clears the @@ -294,6 +436,23 @@ public PreparedScript prepareScript(String script, Map nsscripts, */ @Override public void close() { + + //shutdown LLM worker if running + if (_pythonProcess != null) { + _pythonProcess.destroyForcibly(); + try { + _pythonProcess.waitFor(5, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + _pythonProcess = null; + } + if (_gatewayServer != null) { + _gatewayServer.shutdown(); + _gatewayServer = null; + } + _llmWorker = null; + //clear thread-local configurations ConfigurationManager.clearLocalConfigs(); if( ConfigurationManager.isCodegenEnabled() ) diff --git a/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java new file mode 100644 index 00000000000..10d9787d992 --- /dev/null +++ b/src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java @@ -0,0 +1,31 @@ +package org.apache.sysds.api.jmlc; + +/** + * Interface for the Python LLM worker. + * The Python side implements this via Py4J callback. + */ +public interface LLMCallback { + + /** + * Generates text using the LLM model. + * + * @param prompt the input prompt text + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature (0.0 = deterministic, higher = more random) + * @param topP nucleus sampling probability threshold + * @return generated text continuation + */ + String generate(String prompt, int maxNewTokens, double temperature, double topP); + + /** + * Generates text and returns result with token counts as a JSON string. + * Format: {"text": "...", "input_tokens": N, "output_tokens": M} + * + * @param prompt the input prompt text + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature (0.0 = deterministic, higher = more random) + * @param topP nucleus sampling probability threshold + * @return JSON string with generated text and token counts + */ + String generateWithTokenCount(String prompt, int maxNewTokens, double temperature, double topP); +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java index 31bb7457227..e664f04ca04 100644 --- a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java +++ b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java @@ -80,6 +80,9 @@ public class PreparedScript implements ConfigurableAPI private final CompilerConfig _cconf; private HashMap _outVarLineage; + //LLM inference support + private LLMCallback _llmWorker = null; + private PreparedScript(PreparedScript that) { //shallow copy, except for a separate symbol table //and related meta data of reused inputs @@ -160,6 +163,112 @@ public CompilerConfig getCompilerConfig() { return _cconf; } + /** + * Sets the LLM worker callback for text generation. + * + * @param worker the LLM callback interface + */ + public void setLLMWorker(LLMCallback worker) { + _llmWorker = worker; + } + + /** + * Gets the LLM worker callback. + * + * @return the LLM callback interface, or null if not set + */ + public LLMCallback getLLMWorker() { + return _llmWorker; + } + + /** + * Generates text using the LLM worker. 
+ * + * @param prompt the input prompt text + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature (0.0 = deterministic, higher = more random) + * @param topP nucleus sampling probability threshold + * @return generated text + * @throws DMLException if no LLM worker is set + */ + public String generate(String prompt, int maxNewTokens, double temperature, double topP) { + if (_llmWorker == null) { + throw new DMLException("No LLM worker set. Call setLLMWorker() first."); + } + return _llmWorker.generate(prompt, maxNewTokens, temperature, topP); + } + + /** + * Generates text for multiple prompts and returns results as a FrameBlock. + * The FrameBlock has two columns: [prompt, generated_text]. + * + * @param prompts array of input prompt texts + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature + * @param topP nucleus sampling probability threshold + * @return FrameBlock with columns [prompt, generated_text] + */ + public FrameBlock generateBatch(String[] prompts, int maxNewTokens, double temperature, double topP) { + if (_llmWorker == null) { + throw new DMLException("No LLM worker set. Call setLLMWorker() first."); + } + //generate text for each prompt + String[][] data = new String[prompts.length][2]; + for (int i = 0; i < prompts.length; i++) { + data[i][0] = prompts[i]; + data[i][1] = _llmWorker.generate(prompts[i], maxNewTokens, temperature, topP); + } + //create FrameBlock with string schema + ValueType[] schema = new ValueType[]{ValueType.STRING, ValueType.STRING}; + String[] colNames = new String[]{"prompt", "generated_text"}; + FrameBlock fb = new FrameBlock(schema, colNames); + for (String[] row : data) + fb.appendRow(row); + return fb; + } + + /** + * Generates text for multiple prompts and returns results with timing metrics. + * The FrameBlock has five columns: [prompt, generated_text, time_ms, input_tokens, output_tokens]. + * + * @param prompts array of input prompt texts + * @param maxNewTokens maximum number of new tokens to generate + * @param temperature sampling temperature + * @param topP nucleus sampling probability threshold + * @return FrameBlock with columns [prompt, generated_text, time_ms, input_tokens, output_tokens] + */ + public FrameBlock generateBatchWithMetrics(String[] prompts, int maxNewTokens, double temperature, double topP) { + if (_llmWorker == null) { + throw new DMLException("No LLM worker set. 
Call setLLMWorker() first."); + } + //generate text for each prompt with timing and token counts + String[][] data = new String[prompts.length][5]; + for (int i = 0; i < prompts.length; i++) { + long start = System.nanoTime(); + String json = _llmWorker.generateWithTokenCount(prompts[i], maxNewTokens, temperature, topP); + long elapsed = (System.nanoTime() - start) / 1_000_000; + //parse JSON response: {"text": "...", "input_tokens": N, "output_tokens": M} + try { + org.apache.wink.json4j.JSONObject obj = new org.apache.wink.json4j.JSONObject(json); + data[i][0] = prompts[i]; + data[i][1] = obj.getString("text"); + data[i][2] = String.valueOf(elapsed); + data[i][3] = String.valueOf(obj.getInt("input_tokens")); + data[i][4] = String.valueOf(obj.getInt("output_tokens")); + } catch (Exception e) { + throw new DMLException("Failed to parse LLM worker response: " + e.getMessage()); + } + } + //create FrameBlock with schema + ValueType[] schema = new ValueType[]{ + ValueType.STRING, ValueType.STRING, ValueType.INT64, ValueType.INT64, ValueType.INT64}; + String[] colNames = new String[]{"prompt", "generated_text", "time_ms", "input_tokens", "output_tokens"}; + FrameBlock fb = new FrameBlock(schema, colNames); + for (String[] row : data) + fb.appendRow(row); + return fb; + } + /** * Binds a scalar boolean to a registered input variable. * diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index 71a4112f157..f6b09d4384a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -67,8 +67,6 @@ public class MatrixBlockDictionary extends ADictionary { final private MatrixBlock _data; - static final VectorSpecies SPECIES = DoubleVector.SPECIES_PREFERRED; - /** * Unsafe private constructor that does not check the data validity. USE WITH CAUTION. 
* @@ -2127,9 +2125,6 @@ private void preaggValuesFromDenseDictDenseAggRangeRange(final int numVals, fina private static void preaggValuesFromDenseDictBlockedIKJ(double[] a, double[] b, double[] ret, int bi, int bk, int bj, int bie, int bke, int cz, int az, int ls, int cut, int sOffT, int eOffT) { - final int vLen = SPECIES.length(); - final DoubleVector vVec = DoubleVector.zero(SPECIES); - final int leftover = (eOffT - sOffT) % vLen; // leftover not vectorized for(int i = bi; i < bie; i++) { final int offI = i * cz; final int offOutT = i * az + bj; @@ -2138,27 +2133,14 @@ private static void preaggValuesFromDenseDictBlockedIKJ(double[] a, double[] b, final int sOff = sOffT + idb; final int eOff = eOffT + idb; final double v = a[offI + k]; - vecInnerLoop(v, b, ret, offOutT, eOff, sOff, leftover, vLen, vVec); + int offOut = offOutT; + for(int j = sOff; j < eOff; j++, offOut++) { + ret[offOut] += v * b[j]; + } } } } - private static void vecInnerLoop(final double v, final double[] b, final double[] ret, final int offOutT, - final int eOff, final int sOff, final int leftover, final int vLen, DoubleVector vVec) { - int offOut = offOutT; - vVec = vVec.broadcast(v); - final int end = eOff - leftover; - for(int j = sOff; j < end; j += vLen, offOut += vLen) { - DoubleVector res = DoubleVector.fromArray(SPECIES, ret, offOut); - DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, j); - vVec.fma(bVec, res).intoArray(ret, offOut); - } - for(int j = end; j < eOff; j++, offOut++) { - ret[offOut] += v * b[j]; - } - - } - private void preaggValuesFromDenseDictDenseAggRangeGeneric(final int numVals, final IColIndex colIndexes, final int s, final int e, final double[] b, final int cut, final double[] ret) { final int cz = colIndexes.size(); diff --git a/src/main/python/llm_worker.py b/src/main/python/llm_worker.py new file mode 100644 index 00000000000..27160e27f13 --- /dev/null +++ b/src/main/python/llm_worker.py @@ -0,0 +1,91 @@ +""" +SystemDS LLM Worker - Python side of the Py4J bridge. +Java starts this script, then calls generate() via Py4J. 
+""" +import sys +import json +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from py4j.java_gateway import JavaGateway, GatewayParameters, CallbackServerParameters + +class LLMWorker: + def __init__(self, model_name="distilgpt2"): + print(f"Loading model: {model_name}", flush=True) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + #auto-detect GPU and load model accordingly + if torch.cuda.is_available(): + print(f"CUDA available: {torch.cuda.device_count()} GPU(s)", flush=True) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, device_map="auto", torch_dtype=torch.float16) + self.device = "cuda" + else: + self.model = AutoModelForCausalLM.from_pretrained(model_name) + self.device = "cpu" + self.model.eval() + print(f"Model loaded: {model_name} (device={self.device})", flush=True) + + def generate(self, prompt, max_new_tokens=50, temperature=0.7, top_p=0.9): + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=int(max_new_tokens), + temperature=float(temperature), + top_p=float(top_p), + do_sample=float(temperature) > 0.0 + ) + new_tokens = outputs[0][inputs["input_ids"].shape[1]:] + return self.tokenizer.decode(new_tokens, skip_special_tokens=True) + + def generateWithTokenCount(self, prompt, max_new_tokens=50, temperature=0.7, top_p=0.9): + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) + input_token_count = inputs["input_ids"].shape[1] + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=int(max_new_tokens), + temperature=float(temperature), + top_p=float(top_p), + do_sample=float(temperature) > 0.0 + ) + new_tokens = outputs[0][input_token_count:] + output_token_count = len(new_tokens) + text = self.tokenizer.decode(new_tokens, skip_special_tokens=True) + return json.dumps({ + "text": text, + "input_tokens": input_token_count, + "output_tokens": output_token_count + }) + + class Java: + implements = ["org.apache.sysds.api.jmlc.LLMCallback"] + +if __name__ == "__main__": + model_name = sys.argv[1] if len(sys.argv) > 1 else "distilgpt2" + java_port = int(sys.argv[2]) if len(sys.argv) > 2 else 25333 + python_port = int(sys.argv[3]) if len(sys.argv) > 3 else 25334 + + print(f"Starting LLM worker (javaPort={java_port}, pythonPort={python_port})", flush=True) + + worker = LLMWorker(model_name) + + #connect to Java's GatewayServer and register this worker + gateway = JavaGateway( + gateway_parameters=GatewayParameters(port=java_port), + callback_server_parameters=CallbackServerParameters(port=python_port) + ) + + print(f"Python callback server started on port {python_port}", flush=True) + + gateway.entry_point.registerWorker(worker) + print("Worker registered with Java, waiting for requests...", flush=True) + + # Keep the worker alive to handle callbacks from Java + # The callback server runs in a daemon thread, so we need to block here + import threading + shutdown_event = threading.Event() + try: + # Wait indefinitely until Java closes the connection or kills the process + shutdown_event.wait() + except KeyboardInterrupt: + print("Worker shutting down", flush=True) \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java new file mode 100644 index 00000000000..0e07e07d221 --- /dev/null +++ 
b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.jmlc; + +import org.apache.sysds.api.jmlc.Connection; +import org.apache.sysds.api.jmlc.LLMCallback; +import org.apache.sysds.api.jmlc.PreparedScript; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.test.AutomatedTestBase; +import org.junit.Assert; +import org.junit.Test; + +/** + * Test LLM inference capabilities via JMLC API. + * This test requires Python with transformers and torch installed. + */ +public class JMLCLLMInferenceTest extends AutomatedTestBase { + private final static String TEST_NAME = "JMLCLLMInferenceTest"; + private final static String TEST_DIR = "functions/jmlc/"; + private final static String MODEL_NAME = "distilgpt2"; + private final static String WORKER_SCRIPT = "src/main/python/llm_worker.py"; + private final static String DML_SCRIPT = "x = 1;\nwrite(x, './tmp/x');"; + + @Override + public void setUp() { + addTestConfiguration(TEST_DIR, TEST_NAME); + getAndLoadTestConfiguration(TEST_NAME); + } + + /** + * Creates a connection, loads the LLM model, and returns a PreparedScript + * with the LLM worker attached. 
+ */ + private PreparedScript createLLMScript(Connection conn) throws Exception { + LLMCallback llmWorker = conn.loadModel(MODEL_NAME, WORKER_SCRIPT); + Assert.assertNotNull("LLM worker should not be null", llmWorker); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, new String[]{}, new String[]{"x"}); + ps.setLLMWorker(llmWorker); + return ps; + } + + @Test + public void testLLMInference() { + Connection conn = null; + try { + conn = new Connection(); + PreparedScript ps = createLLMScript(conn); + + //generate text using llm + String prompt = "The meaning of life is"; + String result = ps.generate(prompt, 20, 0.7, 0.9); + + //verify result + Assert.assertNotNull("Generated text should not be null", result); + Assert.assertFalse("Generated text should not be empty", result.isEmpty()); + + System.out.println("Prompt: " + prompt); + System.out.println("Generated: " + result); + + } catch (Exception e) { + System.out.println("Skipping LLM test:"); + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM dependencies not available", e); + } finally { + if (conn != null) + conn.close(); + } + } + + @Test + public void testBatchInference() { + Connection conn = null; + try { + conn = new Connection(); + PreparedScript ps = createLLMScript(conn); + + //batch generate with multiple prompts + String[] prompts = { + "The meaning of life is", + "Machine learning is", + "Apache SystemDS enables" + }; + FrameBlock result = ps.generateBatch(prompts, 20, 0.7, 0.9); + + //verify FrameBlock structure + Assert.assertNotNull("Batch result should not be null", result); + Assert.assertEquals("Should have 3 rows", 3, result.getNumRows()); + Assert.assertEquals("Should have 2 columns", 2, result.getNumColumns()); + + //verify each row has prompt and generated text + for (int i = 0; i < prompts.length; i++) { + String prompt = (String) result.get(i, 0); + String generated = (String) result.get(i, 1); + Assert.assertEquals("Prompt should match", prompts[i], prompt); + Assert.assertNotNull("Generated text should not be null", generated); + Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); + System.out.println("Prompt: " + prompt); + System.out.println("Generated: " + generated); + } + + } catch (Exception e) { + System.out.println("Skipping batch LLM test:"); + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM dependencies not available", e); + } finally { + if (conn != null) + conn.close(); + } + } + + @Test + public void testBatchWithMetrics() { + Connection conn = null; + try { + conn = new Connection(); + PreparedScript ps = createLLMScript(conn); + + //batch generate with metrics + String[] prompts = {"The meaning of life is", "Data science is"}; + FrameBlock result = ps.generateBatchWithMetrics(prompts, 20, 0.7, 0.9); + + //verify FrameBlock structure with metrics and token counts + Assert.assertNotNull("Metrics result should not be null", result); + Assert.assertEquals("Should have 2 rows", 2, result.getNumRows()); + Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); + + //verify metrics columns contain timing and token data + for (int i = 0; i < prompts.length; i++) { + String prompt = (String) result.get(i, 0); + String generated = (String) result.get(i, 1); + long timeMs = Long.parseLong(result.get(i, 2).toString()); + long inputTokens = Long.parseLong(result.get(i, 3).toString()); + long outputTokens = Long.parseLong(result.get(i, 4).toString()); + Assert.assertEquals("Prompt should match", prompts[i], prompt); + 
Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); + Assert.assertTrue("Time should be positive", timeMs > 0); + Assert.assertTrue("Input tokens should be positive", inputTokens > 0); + Assert.assertTrue("Output tokens should be positive", outputTokens > 0); + System.out.println("Prompt: " + prompt); + System.out.println("Generated: " + generated); + System.out.println("Time: " + timeMs + "ms"); + System.out.println("Tokens: " + inputTokens + " in, " + outputTokens + " out"); + } + + } catch (Exception e) { + System.out.println("Skipping metrics LLM test:"); + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM dependencies not available", e); + } finally { + if (conn != null) + conn.close(); + } + } +}