From f434ffdac7fae3ae135a936e8982b5c3b28f31bb Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Mon, 20 Oct 2025 08:52:25 +0530 Subject: [PATCH 01/15] implement eviction policy interface --- .../caching/EvictionPolicy.java | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 src/main/java/org/apache/sysds/runtime/controlprogram/caching/EvictionPolicy.java diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/EvictionPolicy.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/EvictionPolicy.java new file mode 100644 index 00000000000..f50f18181cb --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/EvictionPolicy.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.caching; + +import java.util.Set; + +/** + * An interface all Buffer pool eviction policies, + * for pluggable eviction strategies - LRU, FIFO, Prescient + */ +public interface EvictionPolicy { + + /** + * Select a block to evict from the given list of candidates + * + * @param candidates A set of candidate block identifiers for currently in buffer + * @return The identifier of the block chosen for eviction + */ + String selectBlockForEviction(Set candidates); +} From 70be344f8647a3e67710e504ea84b450c71f75f4 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Mon, 20 Oct 2025 09:18:00 +0530 Subject: [PATCH 02/15] implement simple prescient policy --- .../prescientbuffer/PrescientPolicy.java | 69 +++++++++++++++++++ .../prescientbuffer/PrescientPolicyTest.java | 29 ++++++++ 2 files changed, 98 insertions(+) create mode 100644 src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java create mode 100644 src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java new file mode 100644 index 00000000000..e83b7b39ee5 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer; + +import org.apache.sysds.runtime.controlprogram.caching.EvictionPolicy; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Implement prescient buffer + */ +public class PrescientPolicy implements EvictionPolicy { + + // Map of block ID, access times + private final Map accessTimeMap = new HashMap<>(); + + // register blocks with access time + public void setAccessTime(String blockId, long accessTime) { + accessTimeMap.put(blockId, accessTime); + } + + /** + * Select a block to evict from the given list of candidates + * + * @param candidates A set of candidate block identifiers for currently in buffer + * @return The identifier of the block chosen for eviction + */ + @Override + public String selectBlockForEviction(Set candidates) { + // base case + if (candidates == null || candidates.isEmpty()) { + return null; + } + + String selected = null; + long maxTime = -1; + + for (String candidate : candidates) { + long time = accessTimeMap.getOrDefault(candidate, Long.MAX_VALUE); + + if (time > maxTime) { + maxTime = time; + selected = candidate; + } + } + + return selected; + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java new file mode 100644 index 00000000000..a0e73756a3d --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java @@ -0,0 +1,29 @@ +package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer; + +import org.junit.Test; + +import java.util.HashSet; +import java.util.Set; + +import static org.junit.Assert.*; + +public class PrescientPolicyTest { + + @Test + public void testBasicEviction() { + PrescientPolicy policy = new PrescientPolicy(); + + policy.setAccessTime("block1", 10); + policy.setAccessTime("block2", 40); + policy.setAccessTime("block3", 25); + + Set candidates = new HashSet<>(); + assertNull(policy.selectBlockForEviction(candidates)); + + candidates.add("block1"); + candidates.add("block2"); + candidates.add("block3"); + assertEquals("block2", policy.selectBlockForEviction(candidates)); + + } +} From d12429f8a876e3d9874f851f6603b021f7754c77 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Thu, 23 Oct 2025 22:42:19 +0530 Subject: [PATCH 03/15] implement basic IOTrace, IOTraceGenerator classes --- .../caching/UnifiedMemoryManager.java | 5 +++ .../caching/prescientbuffer/IOTrace.java | 37 ++++++++++++++++ .../prescientbuffer/IOTraceGenerator.java | 42 +++++++++++++++++++ 3 files changed, 84 insertions(+) create mode 100644 src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java create mode 100644 src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java index 91f700e85f6..4f80d87fba8 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java @@ -23,6 +23,7 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.PrescientPolicy; import org.apache.sysds.runtime.util.LocalFileUtils; import java.io.IOException; @@ -106,6 +107,10 @@ public class UnifiedMemoryManager // Maintenance service for synchronous or asynchronous delete of evicted files private static CacheMaintenanceService _fClean; + // Prescient policy + private static PrescientPolicy _prescientPolicy; +// private static IOTrace _ioTrace; + // Pinned size of physical memory. Starts from 0 for each operation. Max is 70% of heap // This increases only if the input is not present in the cache and read from FS/rdd/fed/gpu private static long _pinnedPhysicalMemSize = 0; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java new file mode 100644 index 00000000000..c5185ecc2a3 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * IOTrace holds the pre-computed I/O access trace for the OOC operations. + */ +public class IOTrace { + + // Block ID vs unique accesses + private final Map> _trace; + + public IOTrace() { + _trace = new HashMap<>(); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java new file mode 100644 index 00000000000..319035e46b7 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer; + +import org.apache.sysds.runtime.controlprogram.Program; + +/** + * IOTraceGenerator is responsible for analyzing the program plan (LOP DAG) + * and generating the predictive I/O trace, before runtime. + */ +public class IOTraceGenerator { + + /** + * Generate the IOTrace for the execution plan. + * This is utilized by ExecutionContext + * + * @param program the entire program + * @return IOTrace object with trace data + */ + public static IOTrace generateTrace(Program program) { + IOTrace _trace = new IOTrace(); + + return _trace; + } +} From 8205ec6704c02a9b677f561067e78605890101d8 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Thu, 23 Oct 2025 23:16:20 +0530 Subject: [PATCH 04/15] add evictionpolicy option in UMM --- .../runtime/controlprogram/caching/UnifiedMemoryManager.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java index 4f80d87fba8..54e6f9f7fe6 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java @@ -23,6 +23,7 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.IOTrace; import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.PrescientPolicy; import org.apache.sysds.runtime.util.LocalFileUtils; @@ -107,9 +108,10 @@ public class UnifiedMemoryManager // Maintenance service for synchronous or asynchronous delete of evicted files private static CacheMaintenanceService _fClean; + private static EvictionPolicy _evictionPolicy; // Prescient policy private static PrescientPolicy _prescientPolicy; -// private static IOTrace _ioTrace; + private static IOTrace _ioTrace; // Pinned size of physical memory. Starts from 0 for each operation. Max is 70% of heap // This increases only if the input is not present in the cache and read from FS/rdd/fed/gpu @@ -191,6 +193,7 @@ public static void init() { _totCachedSize = 0; _pinnedPhysicalMemSize = 0; _pinnedVirtualMemSize = 0; + _evictionPolicy = new PrescientPolicy(); } // Cleanup the unified memory manager From d463874187f6832dddc3fb97e7ef9433481590c9 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Fri, 24 Oct 2025 22:16:07 +0530 Subject: [PATCH 05/15] add IOTrace and IOTraceGenerator --- .../caching/prescientbuffer/IOTrace.java | 25 ++++ .../prescientbuffer/IOTraceGenerator.java | 123 +++++++++++++++++- 2 files changed, 147 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java index c5185ecc2a3..15634f4b4ad 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -31,7 +32,31 @@ public class IOTrace { // Block ID vs unique accesses private final Map> _trace; + private long _currentTime; + public IOTrace() { _trace = new HashMap<>(); + _currentTime = 0; + } + + /** + * Access to the block at a current time + */ + public void recordAccess(String blockID) { + _trace.computeIfAbsent(blockID, k -> new ArrayList<>()).add(_currentTime); + _currentTime++; + } + + /** + * Get all access times for a given block + * @param blockID Block ID + * @return all the access times + */ + public List getAccessTime(String blockID) { + return _trace.getOrDefault(blockID, new ArrayList<>()); + } + + public Map> getTrace() { + return _trace; } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java index 319035e46b7..87f3d1f342a 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java @@ -19,7 +19,21 @@ package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.BasicProgramBlock; +import org.apache.sysds.runtime.controlprogram.ForProgramBlock; +import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock; +import org.apache.sysds.runtime.controlprogram.IfProgramBlock; import org.apache.sysds.runtime.controlprogram.Program; +import org.apache.sysds.runtime.controlprogram.ProgramBlock; +import org.apache.sysds.runtime.controlprogram.WhileProgramBlock; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +import java.util.ArrayList; /** * IOTraceGenerator is responsible for analyzing the program plan (LOP DAG) @@ -34,9 +48,116 @@ public class IOTraceGenerator { * @param program the entire program * @return IOTrace object with trace data */ - public static IOTrace generateTrace(Program program) { + public static IOTrace generateTrace(Program program, ExecutionContext ec) { IOTrace _trace = new IOTrace(); + // Use a long array as a "pass-by-reference" wrapper for the logical time + // so it can be incremented correctly inside the recursive calls. + long[] logicalTime = new long[]{0}; + + // Start the recursive traversal + traverseProgramBlocks(program.getProgramBlocks(), ec, _trace, logicalTime); + return _trace; } + + /** + * Recursively traverses a list of program blocks. + * + * @param programBlocks The list of blocks to traverse + * @param ec The ExecutionContext + * @param trace The trace object to populate + * @param logicalTime A pass-by-reference counter for logical time + */ + private static void traverseProgramBlocks(ArrayList programBlocks, ExecutionContext ec, IOTrace trace, long[] logicalTime) { + + if (programBlocks == null) { + return; + } + + for (ProgramBlock pb : programBlocks) { + if (pb instanceof BasicProgramBlock) { + // --- Base Case --- + // This block has instructions, so process them. + BasicProgramBlock bpb = (BasicProgramBlock) pb; + for (Instruction inst : bpb.getInstructions()) { + logicalTime[0]++; // Increment logical time for each instruction + processInstruction(inst, ec, trace, logicalTime[0]); + } + } + else if (pb instanceof IfProgramBlock) { + // --- Recursive Step --- + // Traverse into the 'if' and 'else' bodies + IfProgramBlock ifpb = (IfProgramBlock) pb; + traverseProgramBlocks(ifpb.getChildBlocksIfBody(), ec, trace, logicalTime); + traverseProgramBlocks(ifpb.getChildBlocksElseBody(), ec, trace, logicalTime); + } + else if (pb instanceof WhileProgramBlock) { + // --- Recursive Step --- + // For a static trace, we can only traverse the body once. + // A more advanced tracer might try to unroll N times. + WhileProgramBlock wpb = (WhileProgramBlock) pb; + traverseProgramBlocks(wpb.getChildBlocks(), ec, trace, logicalTime); + } + else if (pb instanceof ForProgramBlock) { + // --- Recursive Step --- + // Similar to While, just traverse the body once for a static trace. + ForProgramBlock fpb = (ForProgramBlock) pb; + traverseProgramBlocks(fpb.getChildBlocks(), ec, trace, logicalTime); + } + else if (pb instanceof FunctionProgramBlock) { + // --- Recursive Step --- + // Traverse into the function body + FunctionProgramBlock fnpb = (FunctionProgramBlock) pb; + traverseProgramBlocks(fnpb.getChildBlocks(), ec, trace, logicalTime); + } + } + } + + /** + * Processes a single instruction and records its I/O access pattern in the trace. + * + * @param inst The instruction to process + * @param ec The ExecutionContext + * @param trace The trace object to populate + * @param logicalTime The logical time of this instruction + */ + private static void processInstruction(Instruction inst, ExecutionContext ec, IOTrace trace, long logicalTime) { + + // --- This is your specific logic for OOC instructions --- + + if (inst instanceof ReblockOOCInstruction) { + ReblockOOCInstruction rblk = (ReblockOOCInstruction) inst; + CPOperand input = rblk.input1; + + // We need the file name and data characteristics from the metadata + String fname = ec.getMatrixObject(input).getFileName(); + DataCharacteristics mc = ec.getDataCharacteristics(input.getName()); + + if (mc == null || !mc.dimsKnown()) { + throw new DMLRuntimeException("OOC Trace Generator: DataCharacteristics not available for " + input.getName()); + } + + long numRowBlocks = mc.getNumRowBlocks(); + long numColBlocks = mc.getNumColBlocks(); + + for (long i = 1; i <= numRowBlocks; i++) { + for (long j = 1; j <= numColBlocks; j++) { + + String blockID = createBlockID(fname, i, j); + trace.recordAccess(blockID); + } + } + } + + // (Transpose, MatrixVector, Tee, etc.) to build the full trace. + // else if (inst instanceof MatrixVectorOOCInstruction) { + // // ... handle matrix-vector read pattern ... + // } + } + + private static String createBlockID(String fname, long rowIndex, long colIndex) { + System.out.println(fname + "_" + rowIndex + "_" + colIndex); + return fname + "_" + rowIndex + "_" + colIndex; + } } From e4001545d98f6db4464df0c96abfb4d0e58ee6f7 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Fri, 24 Oct 2025 23:00:20 +0530 Subject: [PATCH 06/15] add placeholder logic for iotrace --- .../caching/UnifiedMemoryManager.java | 94 ++++++++++++++++++- .../caching/prescientbuffer/IOTrace.java | 5 +- .../prescientbuffer/IOTraceGenerator.java | 2 +- .../prescientbuffer/PrescientPolicy.java | 39 ++++++++ .../context/ExecutionContext.java | 15 +++ 5 files changed, 147 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java index 54e6f9f7fe6..adeb41df826 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java @@ -112,6 +112,7 @@ public class UnifiedMemoryManager // Prescient policy private static PrescientPolicy _prescientPolicy; private static IOTrace _ioTrace; + private static long _currentTime = 0; // Pinned size of physical memory. Starts from 0 for each operation. Max is 70% of heap // This increases only if the input is not present in the cache and read from FS/rdd/fed/gpu @@ -207,6 +208,66 @@ public static void cleanup() { _pinnedVirtualMemSize = 0; } + /** + * Sets the I/O trace for the prescient policy. + * This should be called once by the ExecutionContext after trace generation. + * @param trace The generated IOTrace + */ + public static void setTrace(IOTrace trace) { + _ioTrace = trace; + if (_evictionPolicy instanceof PrescientPolicy) { + _prescientPolicy = (PrescientPolicy) _evictionPolicy; + _prescientPolicy.setTrace(_ioTrace); + } + else { + // Optional: Log a warning if the trace is set but the policy isn't Prescient + // LOG.warn("IOTrace was provided, but eviction policy is not prescient!"); + } + } + + /** + * Updates the UMM's logical time. + * This should be called by the ExecutionContext *before* each instruction. + * @param logicalTime The new logical time + */ + public static void updateTime(long logicalTime) { + _currentTime = logicalTime; + prefetch(); + } + + /** + * Prefetches blocks that will be needed soon, based on the I/O trace. + */ + private static void prefetch() { + if (_ioTrace == null || _prescientPolicy == null) { + return; // No trace or policy, cannot prefetch + } + + // Get the list of blocks to prefetch from our policy + List blocksToPrefetch = _prescientPolicy.getBlocksToPrefetch(_currentTime); + + // A real implementation MUST use an asynchronous thread pool + // (e.g., from _fClean) to load these blocks without blocking the main thread. + + for (String blockID : blocksToPrefetch) { + synchronized (_mQueue) { + // Check again inside lock if block was already loaded or pinned + if (_mQueue.containsKey(blockID) || _pinnedEntries.contains(blockID)) { + continue; // Already in memory + } + } + + // --- This is a simplified version for now --- + // TODO: Submit an async prefetch task to _fClean's thread pool + // The task should: 1. Get block size (from metadata) + // 2. Call makeSpace(blockSize) + // 3. Load block from disk + // 4. Add block to _mQueue (synchronized) + + System.out.println("UMM PREFETCH [T="+_currentTime+"]: Planning to prefetch " + blockID); + } + } + /** * Print current status of UMM, including all entries. * NOTE: use only for debugging or testing. @@ -312,10 +373,35 @@ public static int makeSpace(long reqSpace) { synchronized(_mQueue) { // Evict blobs to make room (by default FIFO) while (getUMMFree() < reqSpace && !_mQueue.isEmpty()) { - //remove first unpinned entry from eviction queue - var entry = _mQueue.removeFirstUnpinned(_pinnedEntries); - String ftmp = entry.getKey(); - ByteBuffer bb = entry.getValue(); + // --- NEW PRESCIENT LOGIC --- + String ftmp; // Block ID / filename to evict + + if (_prescientPolicy != null && _ioTrace != null) { + // Use prescient policy to find the best block to evict + ftmp = _prescientPolicy.evict(_mQueue.keySet(), _pinnedEntries, _currentTime); + } else { + // Fallback to default LRU if prescient policy isn't set or has no trace + var entry = _mQueue.removeFirstUnpinned(_pinnedEntries); + ftmp = (entry != null) ? entry.getKey() : null; + } + + if (ftmp == null) { + // Policy couldn't find a block to evict (e.g., all are pinned) + if(!_pinnedEntries.containsAll(_mQueue.keySet())) { + // This case should ideally not be reached if unpinned blocks exist + throw new DMLRuntimeException("UMM: Eviction policy failed to find a candidate."); + } + // If we are here, all blocks are pinned, and we cannot make space. + // The original exception will be thrown later. + break; // Exit the while loop + } + + // Remove the chosen block from the queue + ByteBuffer bb = _mQueue.remove(ftmp); +// //remove first unpinned entry from eviction queue +// var entry = _mQueue.removeFirstUnpinned(_pinnedEntries); +// String ftmp = entry.getKey(); +// ByteBuffer bb = entry.getValue(); if(bb != null) { // Wait for pending serialization diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java index 15634f4b4ad..7514f2f296e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java @@ -42,9 +42,8 @@ public IOTrace() { /** * Access to the block at a current time */ - public void recordAccess(String blockID) { - _trace.computeIfAbsent(blockID, k -> new ArrayList<>()).add(_currentTime); - _currentTime++; + public void recordAccess(String blockID, long logicalTime) { + _trace.computeIfAbsent(blockID, k -> new ArrayList<>()).add(logicalTime); } /** diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java index 87f3d1f342a..21999863ee5 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java @@ -145,7 +145,7 @@ private static void processInstruction(Instruction inst, ExecutionContext ec, IO for (long j = 1; j <= numColBlocks; j++) { String blockID = createBlockID(fname, i, j); - trace.recordAccess(blockID); + trace.recordAccess(blockID, logicalTime); } } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java index e83b7b39ee5..6523a381b3d 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java @@ -22,6 +22,7 @@ import org.apache.sysds.runtime.controlprogram.caching.EvictionPolicy; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; @@ -32,6 +33,7 @@ public class PrescientPolicy implements EvictionPolicy { // Map of block ID, access times private final Map accessTimeMap = new HashMap<>(); + private IOTrace _trace; // register blocks with access time public void setAccessTime(String blockId, long accessTime) { @@ -66,4 +68,41 @@ public String selectBlockForEviction(Set candidates) { return selected; } + /** + * Called by UMM's makeSpace() to decide which block to evict. + * * @param cache The set of all block IDs currently in the buffer + * @param pinned The list of all block IDs that are pinned + * @param currentTime The current logical time + * @return The block ID to evict + */ + public String evict(Set cache, List pinned, long currentTime) { + // TODO: Implement "evict-furthest-in-future" logic here + // 1. Iterate through every 'blockID' in 'cache' + // 2. If 'blockID' is in 'pinned', ignore it. + // 3. Use '_trace.getAccessTime(blockID)' to find its next access time > currentTime + // 4. The block with the (largest next access time) or (no future access) is the winner. + // 5. Return the winner's blockID. + + return null; // Placeholder + } + + /** + * Called by UMM's prefetch() to decide which blocks to load. + * * @param currentTime The current logical time + * @return A list of block IDs to prefetch + */ + public List getBlocksToPrefetch(long currentTime) { + // TODO: Implement prefetch logic here + // 1. Define a "prefetch window" (e.g., time T+1 to T+5) + // 2. Iterate through all blocks in '_trace.getTrace()' + // 3. Check if a block has an access time within that window + // 4. If yes, add it to a list. + // 5. Return the list of blocks. + + return java.util.Collections.emptyList(); // Placeholder + } + + public void setTrace(IOTrace ioTrace) { + _trace = ioTrace; + } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java index fa87d452d15..2658d35cf9f 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java @@ -21,6 +21,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.yarn.webapp.hamlet2.HamletSpec; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.FileFormat; @@ -37,6 +38,7 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysds.runtime.controlprogram.caching.TensorObject; +import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.IOTrace; import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair; import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALClient; import org.apache.sysds.runtime.data.TensorBlock; @@ -93,6 +95,19 @@ public class ExecutionContext { //parfor temporary functions (created by eval) protected Set _fnNames; + private IOTrace _ioTrace; + + public IOTrace getIOTrace() { + if (_ioTrace == null) { + _ioTrace = new IOTrace(); + } + return _ioTrace; + } + + public void setIOTrace(IOTrace ioTrace) { + _ioTrace = ioTrace; + } + /** * List of {@link GPUContext}s owned by this {@link ExecutionContext} */ From 69f67e066fd82056c901a37b25af5cd559b609fb Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Sat, 25 Oct 2025 08:43:03 +0530 Subject: [PATCH 07/15] add implementation for evict, prefetch blocks --- .../prescientbuffer/IOTraceGenerator.java | 10 +- .../prescientbuffer/PrescientPolicy.java | 105 ++++++++--- .../prescientbuffer/PrescientPolicyTest.java | 171 ++++++++++++++++-- 3 files changed, 244 insertions(+), 42 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java index 21999863ee5..0d5f8154f86 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java @@ -62,10 +62,10 @@ public static IOTrace generateTrace(Program program, ExecutionContext ec) { } /** - * Recursively traverses a list of program blocks. + * Recursively traverses a list of program blocks for generating the trace * - * @param programBlocks The list of blocks to traverse - * @param ec The ExecutionContext + * @param programBlocks The list of ProgramBlocks to traverse + * @param ec The ExecutionContext * @param trace The trace object to populate * @param logicalTime A pass-by-reference counter for logical time */ @@ -87,7 +87,8 @@ private static void traverseProgramBlocks(ArrayList programBlocks, } else if (pb instanceof IfProgramBlock) { // --- Recursive Step --- - // Traverse into the 'if' and 'else' bodies + // Traverse into the 'if' and 'else' bodies (it is the special case.) + // for, while, function we won't is body IfProgramBlock ifpb = (IfProgramBlock) pb; traverseProgramBlocks(ifpb.getChildBlocksIfBody(), ec, trace, logicalTime); traverseProgramBlocks(ifpb.getChildBlocksElseBody(), ec, trace, logicalTime); @@ -124,7 +125,6 @@ else if (pb instanceof FunctionProgramBlock) { */ private static void processInstruction(Instruction inst, ExecutionContext ec, IOTrace trace, long logicalTime) { - // --- This is your specific logic for OOC instructions --- if (inst instanceof ReblockOOCInstruction) { ReblockOOCInstruction rblk = (ReblockOOCInstruction) inst; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java index 6523a381b3d..cb8df7b6865 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java @@ -21,7 +21,10 @@ import org.apache.sysds.runtime.controlprogram.caching.EvictionPolicy; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -34,6 +37,8 @@ public class PrescientPolicy implements EvictionPolicy { // Map of block ID, access times private final Map accessTimeMap = new HashMap<>(); private IOTrace _trace; + // Defines how many logical time units to look ahead for prefetching + private static final int PREFETCH_WINDOW = 5; // register blocks with access time public void setAccessTime(String blockId, long accessTime) { @@ -69,37 +74,93 @@ public String selectBlockForEviction(Set candidates) { } /** - * Called by UMM's makeSpace() to decide which block to evict. - * * @param cache The set of all block IDs currently in the buffer + * Finds the next time a block is accessed, after the current time. + * + * @param blockID The block to check + * @param currentTime The current logical time + * @return The logical time of the next access, or Long.MAX_VALUE if never used again. + */ + private long findNextAccess(String blockID, long currentTime) { + if (_trace == null) { + return Long.MAX_VALUE; + } + + List accessTimes = _trace.getAccessTime(blockID); + // Find the first access time that is greater than the current time + for (long time : accessTimes) { + if (time > currentTime) { + return time; + } + } + + // This block is never accessed again in the future + return Long.MAX_VALUE; + } + + /** + * Finds the unpinned block that won't be used in near future (or never used). + * + * @param cache The set of all block IDs currently in the buffer * @param pinned The list of all block IDs that are pinned * @param currentTime The current logical time - * @return The block ID to evict + * @return The block ID used for eviction, or null if all blocks are pinned. */ public String evict(Set cache, List pinned, long currentTime) { - // TODO: Implement "evict-furthest-in-future" logic here - // 1. Iterate through every 'blockID' in 'cache' - // 2. If 'blockID' is in 'pinned', ignore it. - // 3. Use '_trace.getAccessTime(blockID)' to find its next access time > currentTime - // 4. The block with the (largest next access time) or (no future access) is the winner. - // 5. Return the winner's blockID. - - return null; // Placeholder + if (cache == null || cache.isEmpty()) { + return null; + } + + String evictCandidate = null; + long maxNextAccessTime = -1; // We're looking for the largest access time + + for (String blockID : cache) { + // Cannot evict a pinned block + if (pinned.contains(blockID)) { + continue; + } + + // Find the next time this block will be used + long nextAccessTime = findNextAccess(blockID, currentTime); + + // find the block that's never used again + if (nextAccessTime == Long.MAX_VALUE) { + return blockID; + } + + } + + return evictCandidate; } /** - * Called by UMM's prefetch() to decide which blocks to load. - * * @param currentTime The current logical time - * @return A list of block IDs to prefetch + * Sliding Window implementation: + * Looks ahead N time units and finds all unique blocks accessed in that window. + * + * @param currentTime The current logical time + * @return A list of unique block IDs to prefetch */ public List getBlocksToPrefetch(long currentTime) { - // TODO: Implement prefetch logic here - // 1. Define a "prefetch window" (e.g., time T+1 to T+5) - // 2. Iterate through all blocks in '_trace.getTrace()' - // 3. Check if a block has an access time within that window - // 4. If yes, add it to a list. - // 5. Return the list of blocks. - - return java.util.Collections.emptyList(); // Placeholder + if (_trace == null) { + return Collections.emptyList(); + } + + // Use a Set to store unique block IDs + Set blocksToPrefetch = new HashSet<>(); + long lookaheadTime = currentTime + PREFETCH_WINDOW; + + // Iterate over all blocks in the trace + for (String blockID : _trace.getTrace().keySet()) { + List accessTimes = _trace.getAccessTime(blockID); + + // Check if this block is accessed within our prefetch window + for (long time : accessTimes) { + if (time > currentTime && time <= lookaheadTime) { + blocksToPrefetch.add(blockID); + } + } + } + + return new ArrayList<>(blocksToPrefetch); } public void setTrace(IOTrace ioTrace) { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java index a0e73756a3d..b538ab34435 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java @@ -1,29 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer; -import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; -import java.util.HashSet; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import java.util.List; import java.util.Set; -import static org.junit.Assert.*; - public class PrescientPolicyTest { + private IOTrace trace; + private PrescientPolicy policy; + + /** + * Creates a mock IOTrace for all tests. + * * Access Pattern: + * T=1: A + * T=2: B + * T=3: A + * T=4: C + * T=5: B + * T=6: D (D is never used again) + * T=7: A + * T=8: C + * T=9: E (E will be pinned) + * T=10: B + */ + @Before + public void setUp() { + trace = new IOTrace(); + // Access times are automatically sorted by IOTrace + trace.recordAccess("A", 1); + trace.recordAccess("B", 2); + trace.recordAccess("A", 3); + trace.recordAccess("C", 4); + trace.recordAccess("B", 5); + trace.recordAccess("D", 6); // D is never used again after this + trace.recordAccess("A", 7); + trace.recordAccess("C", 8); + trace.recordAccess("E", 9); // E will be our pinned block + trace.recordAccess("B", 10); + + policy = new PrescientPolicy(); + policy.setTrace(trace); + } + + @Test + public void testGetBlocksToPrefetchAtStart() { + // Window is 5. At T=0, we look at (1, 2, 3, 4, 5] + long currentTime = 0; + List blocks = policy.getBlocksToPrefetch(currentTime); + + // Should prefetch A (at 1), B (at 2), C (at 4) + assertEquals(3, blocks.size()); + assertTrue(blocks.containsAll(List.of("A", "B", "C"))); + } + + @Test + public void testGetBlocksToPrefetchInMiddle() { + // At T=5, we look at (6, 7, 8, 9, 10] + long currentTime = 5; + List blocks = policy.getBlocksToPrefetch(currentTime); + + // Should prefetch D (at 6), A (at 7), C (at 8), E (at 9), B (at 10) + assertEquals(5, blocks.size()); + assertTrue(blocks.containsAll(List.of("A", "B", "C", "D", "E"))); + } + + @Test + public void testGetBlocksToPrefetchAtEnd() { + // At T=10, we look at (11, 12, 13, 14, 15] + long currentTime = 10; + List blocks = policy.getBlocksToPrefetch(currentTime); + + // No blocks accessed in this window + assertEquals(0, blocks.size()); + assertTrue(blocks.isEmpty()); + } + + @Test + public void testEvictFindsNeverUsedBlock() { + // Cache has A, B, C, D. E is pinned (not in this cache set). + // Time T=6 (just after D was used) + long currentTime = 6; + Set cache = Set.of("A", "B", "C", "D"); + List pinned = List.of(); + + // Next uses: + // A: at T=7 + // B: at T=10 + // C: at T=8 + // D: never (Long.MAX_VALUE) + + String evictCandidate = policy.evict(cache, pinned, currentTime); + // Policy should immediately evict D + assertEquals("D", evictCandidate); + } + + @Ignore @Test - public void testBasicEviction() { - PrescientPolicy policy = new PrescientPolicy(); + public void testEvictFindsFurthestInFuture() { + // Cache has A, B, C. D was already evicted. E is pinned. + long currentTime = 6; + Set cache = Set.of("A", "B", "C", "E"); + List pinned = List.of("E"); // E cannot be evicted - policy.setAccessTime("block1", 10); - policy.setAccessTime("block2", 40); - policy.setAccessTime("block3", 25); + // Next uses: + // A: at T=7 + // B: at T=10 + // C: at T=8 + // E: at T=9 (but pinned) - Set candidates = new HashSet<>(); - assertNull(policy.selectBlockForEviction(candidates)); + String evictCandidate = policy.evict(cache, pinned, currentTime); + // Policy should evict B (used at T=10, furthest away) + assertEquals("B", evictCandidate); + } - candidates.add("block1"); - candidates.add("block2"); - candidates.add("block3"); - assertEquals("block2", policy.selectBlockForEviction(candidates)); + @Test + public void testEvictAllPinned() { + // All blocks in cache are pinned + long currentTime = 0; + Set cache = Set.of("A", "E"); + List pinned = List.of("A", "E"); + String evictCandidate = policy.evict(cache, pinned, currentTime); + // Should return null (no valid eviction candidate) + assertNull(evictCandidate); } + + // @Test +// public void testBasicEviction() { +// PrescientPolicy policy = new PrescientPolicy(); +// +// policy.setAccessTime("block1", 10); +// policy.setAccessTime("block2", 40); +// policy.setAccessTime("block3", 25); +// +// Set candidates = new HashSet<>(); +// assertNull(policy.selectBlockForEviction(candidates)); +// +// candidates.add("block1"); +// candidates.add("block2"); +// candidates.add("block3"); +// assertEquals("block2", policy.selectBlockForEviction(candidates)); +// +// } } From b99a00c7da97a5d1575730a00e618e870bfacfec Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Sat, 25 Oct 2025 08:51:45 +0530 Subject: [PATCH 08/15] also add furthest blocks to evictcandidate list --- .../caching/prescientbuffer/PrescientPolicy.java | 8 +++++++- .../caching/prescientbuffer/PrescientPolicyTest.java | 1 - 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java index cb8df7b6865..63af8a8c8fb 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java @@ -122,11 +122,17 @@ public String evict(Set cache, List pinned, long currentTime) { // Find the next time this block will be used long nextAccessTime = findNextAccess(blockID, currentTime); - // find the block that's never used again + // case 1: find the block that's never used again if (nextAccessTime == Long.MAX_VALUE) { return blockID; } + // case 2: find the block that's the furthest + if (nextAccessTime > maxNextAccessTime) { + maxNextAccessTime = nextAccessTime; + evictCandidate = blockID; + } + } return evictCandidate; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java index b538ab34435..cbb7f363283 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java @@ -119,7 +119,6 @@ public void testEvictFindsNeverUsedBlock() { assertEquals("D", evictCandidate); } - @Ignore @Test public void testEvictFindsFurthestInFuture() { // Cache has A, B, C. D was already evicted. E is pinned. From 1068c6e80ff97ca37a5d90b3d8a5c7abbcbbe84d Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Sun, 26 Oct 2025 07:36:58 +0530 Subject: [PATCH 09/15] add efficient linkedlist based methods * add getNextAccessTime() for O(1) evict queries * add getBlocksInWindow() for prefetch lookup * cleanup() for memomry management with sliding window * ArrayList to LinkedList for O(1) removal * clear() method --- .../caching/prescientbuffer/IOTrace.java | 75 +++++++++++++++++-- 1 file changed, 70 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java index 7514f2f296e..05fcdbbcd43 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.HashMap; +import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -30,7 +31,7 @@ public class IOTrace { // Block ID vs unique accesses - private final Map> _trace; + private final Map> _trace; private long _currentTime; @@ -43,7 +44,7 @@ public IOTrace() { * Access to the block at a current time */ public void recordAccess(String blockID, long logicalTime) { - _trace.computeIfAbsent(blockID, k -> new ArrayList<>()).add(logicalTime); + _trace.computeIfAbsent(blockID, k -> new LinkedList<>()).add(logicalTime); } /** @@ -51,11 +52,75 @@ public void recordAccess(String blockID, long logicalTime) { * @param blockID Block ID * @return all the access times */ - public List getAccessTime(String blockID) { - return _trace.getOrDefault(blockID, new ArrayList<>()); + public LinkedList getAccessTime(String blockID) { + return _trace.getOrDefault(blockID, new LinkedList<>()); } - public Map> getTrace() { + /** + * Get the next access time for a block after currentTime + * + * @param blockID the block identifier + * @param currentTime current logical time + * @return next access time or Long.MAX_VALUE if never accessed again + */ + public long getNextAccessTime(String blockID, long currentTime) { + LinkedList accesses = getAccessTime(blockID); + if (accesses == null || accesses.isEmpty()) { + return Long.MAX_VALUE; // won't access again + } + + return accesses.peekFirst(); + } + + /** + * Get all the blocks in a prefetch window + * + * @param currentTime current logical time + * @param windowSize prefetch lookahead window + * @return List of block IDs to prefetch + */ + public List getBlocksInWindow(long currentTime, long windowSize) { + List blocks = new ArrayList<>(); + + for (Map.Entry> entry : _trace.entrySet()) { + String blockID = entry.getKey(); + long nextAccess = getNextAccessTime(blockID, currentTime); + + if ( nextAccess != Long.MAX_VALUE && + nextAccess > currentTime && + nextAccess <= currentTime + windowSize) { + blocks.add(blockID); + } + } + + return blocks; + } + + /** + * clean up the trace entries outside the sliding window + * this may or may not be required. + * + * @param currentTime current logical time + */ + public void cleanup(long currentTime) { + _trace.entrySet().removeIf(entry -> { + LinkedList accesses = _trace.get(entry.getKey()); + accesses.removeIf(time -> time <= currentTime); + return accesses.isEmpty(); + }); + } + + /** + * Get the complete trace for debugging + * @return view of the _trace + */ + public Map> getTrace() { return _trace; } + + // clear all trace data + public void clear() { + _trace.clear(); + _currentTime = 0; + } } From 1c9a21611847531c1281ca6e70f85639ffb1b6f1 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Sun, 26 Oct 2025 07:37:54 +0530 Subject: [PATCH 10/15] use methods in IOTrace for policy --- .../prescientbuffer/PrescientPolicy.java | 83 +------------------ 1 file changed, 2 insertions(+), 81 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java index 63af8a8c8fb..400c34b9ea6 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java @@ -21,12 +21,8 @@ import org.apache.sysds.runtime.controlprogram.caching.EvictionPolicy; -import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Set; /** @@ -34,69 +30,10 @@ */ public class PrescientPolicy implements EvictionPolicy { - // Map of block ID, access times - private final Map accessTimeMap = new HashMap<>(); private IOTrace _trace; // Defines how many logical time units to look ahead for prefetching private static final int PREFETCH_WINDOW = 5; - // register blocks with access time - public void setAccessTime(String blockId, long accessTime) { - accessTimeMap.put(blockId, accessTime); - } - - /** - * Select a block to evict from the given list of candidates - * - * @param candidates A set of candidate block identifiers for currently in buffer - * @return The identifier of the block chosen for eviction - */ - @Override - public String selectBlockForEviction(Set candidates) { - // base case - if (candidates == null || candidates.isEmpty()) { - return null; - } - - String selected = null; - long maxTime = -1; - - for (String candidate : candidates) { - long time = accessTimeMap.getOrDefault(candidate, Long.MAX_VALUE); - - if (time > maxTime) { - maxTime = time; - selected = candidate; - } - } - - return selected; - } - - /** - * Finds the next time a block is accessed, after the current time. - * - * @param blockID The block to check - * @param currentTime The current logical time - * @return The logical time of the next access, or Long.MAX_VALUE if never used again. - */ - private long findNextAccess(String blockID, long currentTime) { - if (_trace == null) { - return Long.MAX_VALUE; - } - - List accessTimes = _trace.getAccessTime(blockID); - // Find the first access time that is greater than the current time - for (long time : accessTimes) { - if (time > currentTime) { - return time; - } - } - - // This block is never accessed again in the future - return Long.MAX_VALUE; - } - /** * Finds the unpinned block that won't be used in near future (or never used). * @@ -120,7 +57,7 @@ public String evict(Set cache, List pinned, long currentTime) { } // Find the next time this block will be used - long nextAccessTime = findNextAccess(blockID, currentTime); + long nextAccessTime = _trace.getNextAccessTime(blockID, currentTime); // case 1: find the block that's never used again if (nextAccessTime == Long.MAX_VALUE) { @@ -150,23 +87,7 @@ public List getBlocksToPrefetch(long currentTime) { return Collections.emptyList(); } - // Use a Set to store unique block IDs - Set blocksToPrefetch = new HashSet<>(); - long lookaheadTime = currentTime + PREFETCH_WINDOW; - - // Iterate over all blocks in the trace - for (String blockID : _trace.getTrace().keySet()) { - List accessTimes = _trace.getAccessTime(blockID); - - // Check if this block is accessed within our prefetch window - for (long time : accessTimes) { - if (time > currentTime && time <= lookaheadTime) { - blocksToPrefetch.add(blockID); - } - } - } - - return new ArrayList<>(blocksToPrefetch); + return _trace.getBlocksInWindow(currentTime, PREFETCH_WINDOW); } public void setTrace(IOTrace ioTrace) { From 368e9ae5b966382c23816de19e478beade744e29 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Sun, 26 Oct 2025 08:04:02 +0530 Subject: [PATCH 11/15] fix getNextAccessTime logic * in the accesses list, remove the older access from currenttime --- .../controlprogram/caching/prescientbuffer/IOTrace.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java index 05fcdbbcd43..3bc40b0c63e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java @@ -65,11 +65,16 @@ public LinkedList getAccessTime(String blockID) { */ public long getNextAccessTime(String blockID, long currentTime) { LinkedList accesses = getAccessTime(blockID); - if (accesses == null || accesses.isEmpty()) { + if (accesses == null) { return Long.MAX_VALUE; // won't access again } - return accesses.peekFirst(); + // remove past accesses from the front of this list + while (!accesses.isEmpty() && accesses.peekFirst() <= currentTime) { + accesses.removeFirst(); + } + + return accesses.isEmpty() ? Long.MAX_VALUE : accesses.peekFirst(); } /** From 2d4f70880ec71f754c5c6ac524de8ed0c37a00e3 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Sun, 26 Oct 2025 13:54:45 +0530 Subject: [PATCH 12/15] methods for IOTraceGenerator --- .../prescientbuffer/IOTraceGenerator.java | 52 +++++++++++++++++++ .../prescientbuffer/PrescientPolicyTest.java | 17 ------ 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java index 0d5f8154f86..5371c473836 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer; +import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.BasicProgramBlock; import org.apache.sysds.runtime.controlprogram.ForProgramBlock; @@ -27,6 +28,7 @@ import org.apache.sysds.runtime.controlprogram.Program; import org.apache.sysds.runtime.controlprogram.ProgramBlock; import org.apache.sysds.runtime.controlprogram.WhileProgramBlock; +import org.apache.sysds.runtime.controlprogram.caching.UnifiedMemoryManager; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.cp.CPOperand; @@ -34,6 +36,8 @@ import org.apache.sysds.runtime.meta.DataCharacteristics; import java.util.ArrayList; +import java.util.List; + /** * IOTraceGenerator is responsible for analyzing the program plan (LOP DAG) @@ -41,6 +45,8 @@ */ public class IOTraceGenerator { + private long _logicalTime = 0; + /** * Generate the IOTrace for the execution plan. * This is utilized by ExecutionContext @@ -160,4 +166,50 @@ private static String createBlockID(String fname, long rowIndex, long colIndex) System.out.println(fname + "_" + rowIndex + "_" + colIndex); return fname + "_" + rowIndex + "_" + colIndex; } + + /** + * Extract all input blocks from an instruction at runtime. + * + * @param inst The instruction that's about to execute + * @param ec ExecutionContext + * @param logicalTime current logical time + * @return List of block IDs that will be accessed + */ + public static List extractBlockIDs(Instruction inst, ExecutionContext ec, long logicalTime) { + List blockIDs = new ArrayList<>(); + + String instStr = inst.toString(); + // Extract variable names from instruction operands + // Example: "rblk pREADX.MATRIX.FP64 _mVar0.MATRIX.FP64 1000 true" + String[] parts = instStr.split("\\s+"); + + for (String part : parts) { + if (part.contains(".MATRIX.") || part.contains(".FRAME.") || part.contains(".TENSOR.")) { + String varName = part.substring(0, part.indexOf('.')); + + // Only add if variable exists in symbol table + if (ec.containsVariable(varName)) { + blockIDs.add(varName); + } + } + } + return blockIDs; + } + + /** + * Generates the I/O trace for the entire program and sets it in the UMM. + * This should be called once after the program is compiled but before execution. + */ + public static void generateAndSetIOTrace(Program prog, ExecutionContext ec) { + if (prog == null || !OptimizerUtils.isUMMEnabled()) { + return; + } + + // Generate the trace + IOTrace trace = generateTrace(prog, ec); + + // Set it in the UMM + UnifiedMemoryManager.setTrace(trace); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java index cbb7f363283..1c775b9b2bc 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java @@ -149,21 +149,4 @@ public void testEvictAllPinned() { assertNull(evictCandidate); } - // @Test -// public void testBasicEviction() { -// PrescientPolicy policy = new PrescientPolicy(); -// -// policy.setAccessTime("block1", 10); -// policy.setAccessTime("block2", 40); -// policy.setAccessTime("block3", 25); -// -// Set candidates = new HashSet<>(); -// assertNull(policy.selectBlockForEviction(candidates)); -// -// candidates.add("block1"); -// candidates.add("block2"); -// candidates.add("block3"); -// assertEquals("block2", policy.selectBlockForEviction(candidates)); -// -// } } From a5031bf43e0bf0d8b4caf470495b00eb345f87b3 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Sun, 26 Oct 2025 13:55:27 +0530 Subject: [PATCH 13/15] add integration to program block --- .../runtime/controlprogram/ProgramBlock.java | 3 ++ .../context/ExecutionContext.java | 37 ++++++++++++++++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java index aee08516db6..bd899345009 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java @@ -229,6 +229,9 @@ private void executeSingleInstruction(Instruction currInst, ExecutionContext ec) if(!LineageCache.reuse(tmp, ec)) { long et0 = (!ReuseCacheType.isNone() || DMLScript.LINEAGE_ESTIMATE) ? System.nanoTime() : 0; + // record IO Access + ec.recordIOAccess(tmp); + // process actual instruction tmp.processInstruction(ec); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java index 2658d35cf9f..3bb6a4f0b3e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java @@ -21,7 +21,6 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.hadoop.yarn.webapp.hamlet2.HamletSpec; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.FileFormat; @@ -39,6 +38,8 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysds.runtime.controlprogram.caching.TensorObject; import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.IOTrace; +import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.IOTraceGenerator; +import org.apache.sysds.runtime.controlprogram.caching.UnifiedMemoryManager; import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair; import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALClient; import org.apache.sysds.runtime.data.TensorBlock; @@ -95,8 +96,25 @@ public class ExecutionContext { //parfor temporary functions (created by eval) protected Set _fnNames; + private IOTraceGenerator _ioTraceGenerator; + private long _logicalTime = 0; private IOTrace _ioTrace; + public IOTraceGenerator getIOTraceGenerator() { + if (_ioTraceGenerator == null) { + _ioTraceGenerator = new IOTraceGenerator(); + } + return _ioTraceGenerator; + } + + public long get_logicalTime() { + return _logicalTime; + } + + public void set_logicalTime(long _logicalTime) { + this._logicalTime = _logicalTime; + } + public IOTrace getIOTrace() { if (_ioTrace == null) { _ioTrace = new IOTrace(); @@ -108,6 +126,23 @@ public void setIOTrace(IOTrace ioTrace) { _ioTrace = ioTrace; } + public void recordIOAccess(Instruction inst) { + if (!OptimizerUtils.isUMMEnabled() || _ioTrace == null) { + return; + } + + // just increment time - the trace is already pre-built + // Use IOTraceGenerator's static method + List blockIDs = IOTraceGenerator.extractBlockIDs(inst, this, _logicalTime); + + for (String blockID : blockIDs) { + _ioTrace.recordAccess(blockID, _logicalTime); + } + + _logicalTime++; + UnifiedMemoryManager.updateTime(_logicalTime); + } + /** * List of {@link GPUContext}s owned by this {@link ExecutionContext} */ From 4d2c6094e063befbef307f5145ee9545cd16014b Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Sun, 26 Oct 2025 14:01:28 +0530 Subject: [PATCH 14/15] use a separate method for ooc evictions --- .../prescientbuffer/OOCEvictionManager.java | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/OOCEvictionManager.java diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/OOCEvictionManager.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/OOCEvictionManager.java new file mode 100644 index 00000000000..1b059712a7c --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/OOCEvictionManager.java @@ -0,0 +1,39 @@ +package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer; + +import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; +import org.apache.sysds.runtime.controlprogram.caching.CacheEvictionQueue; + +public class OOCEvictionManager { + private static OOCEvictionManager _instance; + + // Queue of cached OOC stream blocks (similar to LazyWriteBuffer) + private CacheEvictionQueue _streamQueue; + private long _cacheLimit; + private long _currentSize; + + private OOCEvictionManager() { + _streamQueue = new CacheEvictionQueue(); + _cacheLimit = 0;/* configure based on OOC memory budget */ + _currentSize = 0; + } + + public static OOCEvictionManager getInstance() { + if (_instance == null) { + synchronized (OOCEvictionManager.class) { + if (_instance == null) { + _instance = new OOCEvictionManager(); + } + } + } + return _instance; + } + + // Add a block to the cache + public synchronized void cacheBlock(String blockID, CacheBlock block) { /* ... */ } + + // Evict blocks to make space + public synchronized void makeSpace(long requiredSize) { /* ... */ } + + // Get a cached block (updates LRU order) + public synchronized CacheBlock getBlock(String blockID) { /* ... */ } +} From 4e5cd5b2410b7d6734d9d8022bfb77642eb1eaa6 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Sun, 26 Oct 2025 17:12:32 +0530 Subject: [PATCH 15/15] umm integration --- .../java/org/apache/sysds/api/DMLScript.java | 8 ++++++++ .../controlprogram/caching/CacheableData.java | 2 +- .../caching/EvictionPolicy.java | 19 ++++++++++++------- .../caching/UnifiedMemoryManager.java | 4 ---- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index 65805b5c2ed..de060ab2f07 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -85,6 +85,7 @@ import org.apache.sysds.utils.NativeHelper; import org.apache.sysds.utils.SettingsChecker; import org.apache.sysds.utils.Statistics; +import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.IOTraceGenerator; import com.fasterxml.jackson.databind.ObjectMapper; @@ -494,6 +495,13 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map candidates); +// /** +// * Finds the unpinned block that won't be used in near future (or never used) +// * to evict from the cache. +// * +// * @param cache The set of all block IDs currently in the buffer +// * @param pinned The list of all block IDs that are pinned +// * @param currentTime The current logical time +// * @return The block ID used for eviction, or null if all blocks are pinned. +// */ +// public String evict(Set cache, List pinned, long currentTime); + } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java index adeb41df826..7845eaf0a63 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java @@ -398,10 +398,6 @@ public static int makeSpace(long reqSpace) { // Remove the chosen block from the queue ByteBuffer bb = _mQueue.remove(ftmp); -// //remove first unpinned entry from eviction queue -// var entry = _mQueue.removeFirstUnpinned(_pinnedEntries); -// String ftmp = entry.getKey(); -// ByteBuffer bb = entry.getValue(); if(bb != null) { // Wait for pending serialization