diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index 65805b5c2ed..de060ab2f07 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -85,6 +85,7 @@ import org.apache.sysds.utils.NativeHelper; import org.apache.sysds.utils.SettingsChecker; import org.apache.sysds.utils.Statistics; +import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.IOTraceGenerator; import com.fasterxml.jackson.databind.ObjectMapper; @@ -494,6 +495,13 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map cache, List pinned, long currentTime); + +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java index 91f700e85f6..7845eaf0a63 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.java @@ -23,6 +23,8 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.IOTrace; +import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.PrescientPolicy; import org.apache.sysds.runtime.util.LocalFileUtils; import java.io.IOException; @@ -106,6 +108,12 @@ public class UnifiedMemoryManager // Maintenance service for synchronous or asynchronous delete of evicted files private static CacheMaintenanceService _fClean; + private static EvictionPolicy _evictionPolicy; + // Prescient policy + private static PrescientPolicy _prescientPolicy; + private static IOTrace _ioTrace; + private static long _currentTime = 0; + // Pinned size of physical memory. Starts from 0 for each operation. Max is 70% of heap // This increases only if the input is not present in the cache and read from FS/rdd/fed/gpu private static long _pinnedPhysicalMemSize = 0; @@ -186,6 +194,7 @@ public static void init() { _totCachedSize = 0; _pinnedPhysicalMemSize = 0; _pinnedVirtualMemSize = 0; + _evictionPolicy = new PrescientPolicy(); } // Cleanup the unified memory manager @@ -199,6 +208,66 @@ public static void cleanup() { _pinnedVirtualMemSize = 0; } + /** + * Sets the I/O trace for the prescient policy. + * This should be called once by the ExecutionContext after trace generation. + * @param trace The generated IOTrace + */ + public static void setTrace(IOTrace trace) { + _ioTrace = trace; + if (_evictionPolicy instanceof PrescientPolicy) { + _prescientPolicy = (PrescientPolicy) _evictionPolicy; + _prescientPolicy.setTrace(_ioTrace); + } + else { + // Optional: Log a warning if the trace is set but the policy isn't Prescient + // LOG.warn("IOTrace was provided, but eviction policy is not prescient!"); + } + } + + /** + * Updates the UMM's logical time. + * This should be called by the ExecutionContext *before* each instruction. + * @param logicalTime The new logical time + */ + public static void updateTime(long logicalTime) { + _currentTime = logicalTime; + prefetch(); + } + + /** + * Prefetches blocks that will be needed soon, based on the I/O trace. + */ + private static void prefetch() { + if (_ioTrace == null || _prescientPolicy == null) { + return; // No trace or policy, cannot prefetch + } + + // Get the list of blocks to prefetch from our policy + List blocksToPrefetch = _prescientPolicy.getBlocksToPrefetch(_currentTime); + + // A real implementation MUST use an asynchronous thread pool + // (e.g., from _fClean) to load these blocks without blocking the main thread. + + for (String blockID : blocksToPrefetch) { + synchronized (_mQueue) { + // Check again inside lock if block was already loaded or pinned + if (_mQueue.containsKey(blockID) || _pinnedEntries.contains(blockID)) { + continue; // Already in memory + } + } + + // --- This is a simplified version for now --- + // TODO: Submit an async prefetch task to _fClean's thread pool + // The task should: 1. Get block size (from metadata) + // 2. Call makeSpace(blockSize) + // 3. Load block from disk + // 4. Add block to _mQueue (synchronized) + + System.out.println("UMM PREFETCH [T="+_currentTime+"]: Planning to prefetch " + blockID); + } + } + /** * Print current status of UMM, including all entries. * NOTE: use only for debugging or testing. @@ -304,10 +373,31 @@ public static int makeSpace(long reqSpace) { synchronized(_mQueue) { // Evict blobs to make room (by default FIFO) while (getUMMFree() < reqSpace && !_mQueue.isEmpty()) { - //remove first unpinned entry from eviction queue - var entry = _mQueue.removeFirstUnpinned(_pinnedEntries); - String ftmp = entry.getKey(); - ByteBuffer bb = entry.getValue(); + // --- NEW PRESCIENT LOGIC --- + String ftmp; // Block ID / filename to evict + + if (_prescientPolicy != null && _ioTrace != null) { + // Use prescient policy to find the best block to evict + ftmp = _prescientPolicy.evict(_mQueue.keySet(), _pinnedEntries, _currentTime); + } else { + // Fallback to default LRU if prescient policy isn't set or has no trace + var entry = _mQueue.removeFirstUnpinned(_pinnedEntries); + ftmp = (entry != null) ? entry.getKey() : null; + } + + if (ftmp == null) { + // Policy couldn't find a block to evict (e.g., all are pinned) + if(!_pinnedEntries.containsAll(_mQueue.keySet())) { + // This case should ideally not be reached if unpinned blocks exist + throw new DMLRuntimeException("UMM: Eviction policy failed to find a candidate."); + } + // If we are here, all blocks are pinned, and we cannot make space. + // The original exception will be thrown later. + break; // Exit the while loop + } + + // Remove the chosen block from the queue + ByteBuffer bb = _mQueue.remove(ftmp); if(bb != null) { // Wait for pending serialization diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java new file mode 100644 index 00000000000..3bc40b0c63e --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTrace.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +/** + * IOTrace holds the pre-computed I/O access trace for the OOC operations. + */ +public class IOTrace { + + // Block ID vs unique accesses + private final Map> _trace; + + private long _currentTime; + + public IOTrace() { + _trace = new HashMap<>(); + _currentTime = 0; + } + + /** + * Access to the block at a current time + */ + public void recordAccess(String blockID, long logicalTime) { + _trace.computeIfAbsent(blockID, k -> new LinkedList<>()).add(logicalTime); + } + + /** + * Get all access times for a given block + * @param blockID Block ID + * @return all the access times + */ + public LinkedList getAccessTime(String blockID) { + return _trace.getOrDefault(blockID, new LinkedList<>()); + } + + /** + * Get the next access time for a block after currentTime + * + * @param blockID the block identifier + * @param currentTime current logical time + * @return next access time or Long.MAX_VALUE if never accessed again + */ + public long getNextAccessTime(String blockID, long currentTime) { + LinkedList accesses = getAccessTime(blockID); + if (accesses == null) { + return Long.MAX_VALUE; // won't access again + } + + // remove past accesses from the front of this list + while (!accesses.isEmpty() && accesses.peekFirst() <= currentTime) { + accesses.removeFirst(); + } + + return accesses.isEmpty() ? Long.MAX_VALUE : accesses.peekFirst(); + } + + /** + * Get all the blocks in a prefetch window + * + * @param currentTime current logical time + * @param windowSize prefetch lookahead window + * @return List of block IDs to prefetch + */ + public List getBlocksInWindow(long currentTime, long windowSize) { + List blocks = new ArrayList<>(); + + for (Map.Entry> entry : _trace.entrySet()) { + String blockID = entry.getKey(); + long nextAccess = getNextAccessTime(blockID, currentTime); + + if ( nextAccess != Long.MAX_VALUE && + nextAccess > currentTime && + nextAccess <= currentTime + windowSize) { + blocks.add(blockID); + } + } + + return blocks; + } + + /** + * clean up the trace entries outside the sliding window + * this may or may not be required. + * + * @param currentTime current logical time + */ + public void cleanup(long currentTime) { + _trace.entrySet().removeIf(entry -> { + LinkedList accesses = _trace.get(entry.getKey()); + accesses.removeIf(time -> time <= currentTime); + return accesses.isEmpty(); + }); + } + + /** + * Get the complete trace for debugging + * @return view of the _trace + */ + public Map> getTrace() { + return _trace; + } + + // clear all trace data + public void clear() { + _trace.clear(); + _currentTime = 0; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java new file mode 100644 index 00000000000..5371c473836 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/IOTraceGenerator.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer; + +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.BasicProgramBlock; +import org.apache.sysds.runtime.controlprogram.ForProgramBlock; +import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock; +import org.apache.sysds.runtime.controlprogram.IfProgramBlock; +import org.apache.sysds.runtime.controlprogram.Program; +import org.apache.sysds.runtime.controlprogram.ProgramBlock; +import org.apache.sysds.runtime.controlprogram.WhileProgramBlock; +import org.apache.sysds.runtime.controlprogram.caching.UnifiedMemoryManager; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +import java.util.ArrayList; +import java.util.List; + + +/** + * IOTraceGenerator is responsible for analyzing the program plan (LOP DAG) + * and generating the predictive I/O trace, before runtime. + */ +public class IOTraceGenerator { + + private long _logicalTime = 0; + + /** + * Generate the IOTrace for the execution plan. + * This is utilized by ExecutionContext + * + * @param program the entire program + * @return IOTrace object with trace data + */ + public static IOTrace generateTrace(Program program, ExecutionContext ec) { + IOTrace _trace = new IOTrace(); + + // Use a long array as a "pass-by-reference" wrapper for the logical time + // so it can be incremented correctly inside the recursive calls. + long[] logicalTime = new long[]{0}; + + // Start the recursive traversal + traverseProgramBlocks(program.getProgramBlocks(), ec, _trace, logicalTime); + + return _trace; + } + + /** + * Recursively traverses a list of program blocks for generating the trace + * + * @param programBlocks The list of ProgramBlocks to traverse + * @param ec The ExecutionContext + * @param trace The trace object to populate + * @param logicalTime A pass-by-reference counter for logical time + */ + private static void traverseProgramBlocks(ArrayList programBlocks, ExecutionContext ec, IOTrace trace, long[] logicalTime) { + + if (programBlocks == null) { + return; + } + + for (ProgramBlock pb : programBlocks) { + if (pb instanceof BasicProgramBlock) { + // --- Base Case --- + // This block has instructions, so process them. + BasicProgramBlock bpb = (BasicProgramBlock) pb; + for (Instruction inst : bpb.getInstructions()) { + logicalTime[0]++; // Increment logical time for each instruction + processInstruction(inst, ec, trace, logicalTime[0]); + } + } + else if (pb instanceof IfProgramBlock) { + // --- Recursive Step --- + // Traverse into the 'if' and 'else' bodies (it is the special case.) + // for, while, function we won't is body + IfProgramBlock ifpb = (IfProgramBlock) pb; + traverseProgramBlocks(ifpb.getChildBlocksIfBody(), ec, trace, logicalTime); + traverseProgramBlocks(ifpb.getChildBlocksElseBody(), ec, trace, logicalTime); + } + else if (pb instanceof WhileProgramBlock) { + // --- Recursive Step --- + // For a static trace, we can only traverse the body once. + // A more advanced tracer might try to unroll N times. + WhileProgramBlock wpb = (WhileProgramBlock) pb; + traverseProgramBlocks(wpb.getChildBlocks(), ec, trace, logicalTime); + } + else if (pb instanceof ForProgramBlock) { + // --- Recursive Step --- + // Similar to While, just traverse the body once for a static trace. + ForProgramBlock fpb = (ForProgramBlock) pb; + traverseProgramBlocks(fpb.getChildBlocks(), ec, trace, logicalTime); + } + else if (pb instanceof FunctionProgramBlock) { + // --- Recursive Step --- + // Traverse into the function body + FunctionProgramBlock fnpb = (FunctionProgramBlock) pb; + traverseProgramBlocks(fnpb.getChildBlocks(), ec, trace, logicalTime); + } + } + } + + /** + * Processes a single instruction and records its I/O access pattern in the trace. + * + * @param inst The instruction to process + * @param ec The ExecutionContext + * @param trace The trace object to populate + * @param logicalTime The logical time of this instruction + */ + private static void processInstruction(Instruction inst, ExecutionContext ec, IOTrace trace, long logicalTime) { + + + if (inst instanceof ReblockOOCInstruction) { + ReblockOOCInstruction rblk = (ReblockOOCInstruction) inst; + CPOperand input = rblk.input1; + + // We need the file name and data characteristics from the metadata + String fname = ec.getMatrixObject(input).getFileName(); + DataCharacteristics mc = ec.getDataCharacteristics(input.getName()); + + if (mc == null || !mc.dimsKnown()) { + throw new DMLRuntimeException("OOC Trace Generator: DataCharacteristics not available for " + input.getName()); + } + + long numRowBlocks = mc.getNumRowBlocks(); + long numColBlocks = mc.getNumColBlocks(); + + for (long i = 1; i <= numRowBlocks; i++) { + for (long j = 1; j <= numColBlocks; j++) { + + String blockID = createBlockID(fname, i, j); + trace.recordAccess(blockID, logicalTime); + } + } + } + + // (Transpose, MatrixVector, Tee, etc.) to build the full trace. + // else if (inst instanceof MatrixVectorOOCInstruction) { + // // ... handle matrix-vector read pattern ... + // } + } + + private static String createBlockID(String fname, long rowIndex, long colIndex) { + System.out.println(fname + "_" + rowIndex + "_" + colIndex); + return fname + "_" + rowIndex + "_" + colIndex; + } + + /** + * Extract all input blocks from an instruction at runtime. + * + * @param inst The instruction that's about to execute + * @param ec ExecutionContext + * @param logicalTime current logical time + * @return List of block IDs that will be accessed + */ + public static List extractBlockIDs(Instruction inst, ExecutionContext ec, long logicalTime) { + List blockIDs = new ArrayList<>(); + + String instStr = inst.toString(); + // Extract variable names from instruction operands + // Example: "rblk pREADX.MATRIX.FP64 _mVar0.MATRIX.FP64 1000 true" + String[] parts = instStr.split("\\s+"); + + for (String part : parts) { + if (part.contains(".MATRIX.") || part.contains(".FRAME.") || part.contains(".TENSOR.")) { + String varName = part.substring(0, part.indexOf('.')); + + // Only add if variable exists in symbol table + if (ec.containsVariable(varName)) { + blockIDs.add(varName); + } + } + } + return blockIDs; + } + + /** + * Generates the I/O trace for the entire program and sets it in the UMM. + * This should be called once after the program is compiled but before execution. + */ + public static void generateAndSetIOTrace(Program prog, ExecutionContext ec) { + if (prog == null || !OptimizerUtils.isUMMEnabled()) { + return; + } + + // Generate the trace + IOTrace trace = generateTrace(prog, ec); + + // Set it in the UMM + UnifiedMemoryManager.setTrace(trace); + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/OOCEvictionManager.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/OOCEvictionManager.java new file mode 100644 index 00000000000..1b059712a7c --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/OOCEvictionManager.java @@ -0,0 +1,39 @@ +package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer; + +import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; +import org.apache.sysds.runtime.controlprogram.caching.CacheEvictionQueue; + +public class OOCEvictionManager { + private static OOCEvictionManager _instance; + + // Queue of cached OOC stream blocks (similar to LazyWriteBuffer) + private CacheEvictionQueue _streamQueue; + private long _cacheLimit; + private long _currentSize; + + private OOCEvictionManager() { + _streamQueue = new CacheEvictionQueue(); + _cacheLimit = 0;/* configure based on OOC memory budget */ + _currentSize = 0; + } + + public static OOCEvictionManager getInstance() { + if (_instance == null) { + synchronized (OOCEvictionManager.class) { + if (_instance == null) { + _instance = new OOCEvictionManager(); + } + } + } + return _instance; + } + + // Add a block to the cache + public synchronized void cacheBlock(String blockID, CacheBlock block) { /* ... */ } + + // Evict blocks to make space + public synchronized void makeSpace(long requiredSize) { /* ... */ } + + // Get a cached block (updates LRU order) + public synchronized CacheBlock getBlock(String blockID) { /* ... */ } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java new file mode 100644 index 00000000000..400c34b9ea6 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicy.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer; + +import org.apache.sysds.runtime.controlprogram.caching.EvictionPolicy; + +import java.util.Collections; +import java.util.List; +import java.util.Set; + +/** + * Implement prescient buffer + */ +public class PrescientPolicy implements EvictionPolicy { + + private IOTrace _trace; + // Defines how many logical time units to look ahead for prefetching + private static final int PREFETCH_WINDOW = 5; + + /** + * Finds the unpinned block that won't be used in near future (or never used). + * + * @param cache The set of all block IDs currently in the buffer + * @param pinned The list of all block IDs that are pinned + * @param currentTime The current logical time + * @return The block ID used for eviction, or null if all blocks are pinned. + */ + public String evict(Set cache, List pinned, long currentTime) { + if (cache == null || cache.isEmpty()) { + return null; + } + + String evictCandidate = null; + long maxNextAccessTime = -1; // We're looking for the largest access time + + for (String blockID : cache) { + // Cannot evict a pinned block + if (pinned.contains(blockID)) { + continue; + } + + // Find the next time this block will be used + long nextAccessTime = _trace.getNextAccessTime(blockID, currentTime); + + // case 1: find the block that's never used again + if (nextAccessTime == Long.MAX_VALUE) { + return blockID; + } + + // case 2: find the block that's the furthest + if (nextAccessTime > maxNextAccessTime) { + maxNextAccessTime = nextAccessTime; + evictCandidate = blockID; + } + + } + + return evictCandidate; + } + + /** + * Sliding Window implementation: + * Looks ahead N time units and finds all unique blocks accessed in that window. + * + * @param currentTime The current logical time + * @return A list of unique block IDs to prefetch + */ + public List getBlocksToPrefetch(long currentTime) { + if (_trace == null) { + return Collections.emptyList(); + } + + return _trace.getBlocksInWindow(currentTime, PREFETCH_WINDOW); + } + + public void setTrace(IOTrace ioTrace) { + _trace = ioTrace; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java new file mode 100644 index 00000000000..1c775b9b2bc --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/prescientbuffer/PrescientPolicyTest.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import java.util.List; +import java.util.Set; + +public class PrescientPolicyTest { + + private IOTrace trace; + private PrescientPolicy policy; + + /** + * Creates a mock IOTrace for all tests. + * * Access Pattern: + * T=1: A + * T=2: B + * T=3: A + * T=4: C + * T=5: B + * T=6: D (D is never used again) + * T=7: A + * T=8: C + * T=9: E (E will be pinned) + * T=10: B + */ + @Before + public void setUp() { + trace = new IOTrace(); + // Access times are automatically sorted by IOTrace + trace.recordAccess("A", 1); + trace.recordAccess("B", 2); + trace.recordAccess("A", 3); + trace.recordAccess("C", 4); + trace.recordAccess("B", 5); + trace.recordAccess("D", 6); // D is never used again after this + trace.recordAccess("A", 7); + trace.recordAccess("C", 8); + trace.recordAccess("E", 9); // E will be our pinned block + trace.recordAccess("B", 10); + + policy = new PrescientPolicy(); + policy.setTrace(trace); + } + + @Test + public void testGetBlocksToPrefetchAtStart() { + // Window is 5. At T=0, we look at (1, 2, 3, 4, 5] + long currentTime = 0; + List blocks = policy.getBlocksToPrefetch(currentTime); + + // Should prefetch A (at 1), B (at 2), C (at 4) + assertEquals(3, blocks.size()); + assertTrue(blocks.containsAll(List.of("A", "B", "C"))); + } + + @Test + public void testGetBlocksToPrefetchInMiddle() { + // At T=5, we look at (6, 7, 8, 9, 10] + long currentTime = 5; + List blocks = policy.getBlocksToPrefetch(currentTime); + + // Should prefetch D (at 6), A (at 7), C (at 8), E (at 9), B (at 10) + assertEquals(5, blocks.size()); + assertTrue(blocks.containsAll(List.of("A", "B", "C", "D", "E"))); + } + + @Test + public void testGetBlocksToPrefetchAtEnd() { + // At T=10, we look at (11, 12, 13, 14, 15] + long currentTime = 10; + List blocks = policy.getBlocksToPrefetch(currentTime); + + // No blocks accessed in this window + assertEquals(0, blocks.size()); + assertTrue(blocks.isEmpty()); + } + + @Test + public void testEvictFindsNeverUsedBlock() { + // Cache has A, B, C, D. E is pinned (not in this cache set). + // Time T=6 (just after D was used) + long currentTime = 6; + Set cache = Set.of("A", "B", "C", "D"); + List pinned = List.of(); + + // Next uses: + // A: at T=7 + // B: at T=10 + // C: at T=8 + // D: never (Long.MAX_VALUE) + + String evictCandidate = policy.evict(cache, pinned, currentTime); + // Policy should immediately evict D + assertEquals("D", evictCandidate); + } + + @Test + public void testEvictFindsFurthestInFuture() { + // Cache has A, B, C. D was already evicted. E is pinned. + long currentTime = 6; + Set cache = Set.of("A", "B", "C", "E"); + List pinned = List.of("E"); // E cannot be evicted + + // Next uses: + // A: at T=7 + // B: at T=10 + // C: at T=8 + // E: at T=9 (but pinned) + + String evictCandidate = policy.evict(cache, pinned, currentTime); + // Policy should evict B (used at T=10, furthest away) + assertEquals("B", evictCandidate); + } + + @Test + public void testEvictAllPinned() { + // All blocks in cache are pinned + long currentTime = 0; + Set cache = Set.of("A", "E"); + List pinned = List.of("A", "E"); + + String evictCandidate = policy.evict(cache, pinned, currentTime); + // Should return null (no valid eviction candidate) + assertNull(evictCandidate); + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java index fa87d452d15..3bb6a4f0b3e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java @@ -37,6 +37,9 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysds.runtime.controlprogram.caching.TensorObject; +import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.IOTrace; +import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.IOTraceGenerator; +import org.apache.sysds.runtime.controlprogram.caching.UnifiedMemoryManager; import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair; import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALClient; import org.apache.sysds.runtime.data.TensorBlock; @@ -93,6 +96,53 @@ public class ExecutionContext { //parfor temporary functions (created by eval) protected Set _fnNames; + private IOTraceGenerator _ioTraceGenerator; + private long _logicalTime = 0; + private IOTrace _ioTrace; + + public IOTraceGenerator getIOTraceGenerator() { + if (_ioTraceGenerator == null) { + _ioTraceGenerator = new IOTraceGenerator(); + } + return _ioTraceGenerator; + } + + public long get_logicalTime() { + return _logicalTime; + } + + public void set_logicalTime(long _logicalTime) { + this._logicalTime = _logicalTime; + } + + public IOTrace getIOTrace() { + if (_ioTrace == null) { + _ioTrace = new IOTrace(); + } + return _ioTrace; + } + + public void setIOTrace(IOTrace ioTrace) { + _ioTrace = ioTrace; + } + + public void recordIOAccess(Instruction inst) { + if (!OptimizerUtils.isUMMEnabled() || _ioTrace == null) { + return; + } + + // just increment time - the trace is already pre-built + // Use IOTraceGenerator's static method + List blockIDs = IOTraceGenerator.extractBlockIDs(inst, this, _logicalTime); + + for (String blockID : blockIDs) { + _ioTrace.recordAccess(blockID, _logicalTime); + } + + _logicalTime++; + UnifiedMemoryManager.updateTime(_logicalTime); + } + /** * List of {@link GPUContext}s owned by this {@link ExecutionContext} */