Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/main/java/org/apache/sysds/api/DMLScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
import org.apache.sysds.utils.NativeHelper;
import org.apache.sysds.utils.SettingsChecker;
import org.apache.sysds.utils.Statistics;
import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.IOTraceGenerator;

import com.fasterxml.jackson.databind.ObjectMapper;

Expand Down Expand Up @@ -494,6 +495,13 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map<Stri
ExecutionContext ec = null;
try {
ec = ExecutionContextFactory.createContext(rtprog);

// Generate IO trace after ec is created but before execution
if (OptimizerUtils.isUMMEnabled()) {
// IOTraceGenerator.generateAndSetIOTrace(rtprog, ec);

}

ScriptExecutorUtils.executeRuntimeProgram(rtprog, ec, ConfigurationManager.getDMLConfig(), STATISTICS ? STATISTICS_COUNT : 0, null);
}
finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ private void executeSingleInstruction(Instruction currInst, ExecutionContext ec)
if(!LineageCache.reuse(tmp, ec)) {
long et0 = (!ReuseCacheType.isNone() || DMLScript.LINEAGE_ESTIMATE) ? System.nanoTime() : 0;

// record IO Access
ec.recordIOAccess(tmp);

// process actual instruction
tmp.processInstruction(ec);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,7 @@ protected boolean isEqualOutputFormat(String outputFormat) {

// ------------- IMPLEMENTED CACHE LOGIC METHODS --------------

protected String getCacheFilePathAndName () {
public String getCacheFilePathAndName() {
if( _cacheFileName==null ) {
StringBuilder sb = new StringBuilder();
sb.append(CacheableData.cacheEvictionLocalFilePath);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.controlprogram.caching;

import java.util.List;
import java.util.Set;

/**
 * Common interface for all buffer pool eviction policies, allowing
 * pluggable eviction strategies (e.g., LRU, FIFO, Prescient).
 * <p>
 * NOTE: the interface currently declares no methods and acts purely as a
 * marker type. The intended {@code evict} contract is kept below as a
 * commented-out sketch and is not enforced on implementations.
 */
public interface EvictionPolicy {

// Planned contract (not yet active):
//
// /**
//  * Finds the unpinned block that will not be used in the near future
//  * (or never again) and selects it for eviction from the cache.
//  *
//  * @param cache The set of all block IDs currently in the buffer
//  * @param pinned The list of all block IDs that are pinned
//  * @param currentTime The current logical time
//  * @return The block ID chosen for eviction, or null if all blocks are pinned.
//  */
// public String evict(Set<String> cache, List<String> pinned, long currentTime);

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.IOTrace;
import org.apache.sysds.runtime.controlprogram.caching.prescientbuffer.PrescientPolicy;
import org.apache.sysds.runtime.util.LocalFileUtils;

import java.io.IOException;
Expand Down Expand Up @@ -106,6 +108,12 @@ public class UnifiedMemoryManager
// Maintenance service for synchronous or asynchronous delete of evicted files
private static CacheMaintenanceService _fClean;

// Active eviction policy; init() installs a PrescientPolicy instance
private static EvictionPolicy _evictionPolicy;
// Prescient policy: typed alias of _evictionPolicy, assigned in setTrace()
// only when the active policy is a PrescientPolicy (null until then)
private static PrescientPolicy _prescientPolicy;
// I/O trace handed over via setTrace(); null until a trace is provided
private static IOTrace _ioTrace;
// Logical time last reported through updateTime(); passed to the prescient
// policy for prefetch and eviction decisions
private static long _currentTime = 0;

// Pinned size of physical memory. Starts from 0 for each operation. Max is 70% of heap
// This increases only if the input is not present in the cache and read from FS/rdd/fed/gpu
private static long _pinnedPhysicalMemSize = 0;
Expand Down Expand Up @@ -186,6 +194,7 @@ public static void init() {
_totCachedSize = 0;
_pinnedPhysicalMemSize = 0;
_pinnedVirtualMemSize = 0;
_evictionPolicy = new PrescientPolicy();
}

// Cleanup the unified memory manager
Expand All @@ -199,6 +208,66 @@ public static void cleanup() {
_pinnedVirtualMemSize = 0;
}

/**
 * Registers the I/O trace used by the prescient eviction policy.
 * <p>
 * The trace is always stored. It is forwarded to the policy only if the
 * active eviction policy is a {@link PrescientPolicy}; for any other policy
 * the trace is kept but remains unused.
 * Intended to be called once after trace generation (caller not visible
 * here -- confirm against the ExecutionContext integration).
 *
 * @param trace The generated IOTrace
 */
public static void setTrace(IOTrace trace) {
	_ioTrace = trace;

	// Guard: a non-prescient policy cannot consume the trace.
	// Optional: log a warning here once a logger is wired in, e.g.
	// LOG.warn("IOTrace was provided, but eviction policy is not prescient!");
	if (!(_evictionPolicy instanceof PrescientPolicy)) {
		return;
	}

	PrescientPolicy prescient = (PrescientPolicy) _evictionPolicy;
	_prescientPolicy = prescient;
	prescient.setTrace(_ioTrace);
}

/**
 * Updates the UMM's logical time and immediately triggers a prefetch pass
 * for the new time.
 * <p>
 * Intended to be called *before* each instruction (caller not visible
 * here -- confirm against the ExecutionContext integration). Note that
 * {@code prefetch()} is invoked synchronously on the calling thread.
 *
 * @param logicalTime The new logical time
 */
public static void updateTime(long logicalTime) {
// Store first: prefetch() reads _currentTime to query the policy
_currentTime = logicalTime;
prefetch();
}

/**
 * Plans prefetching of blocks that the prescient policy expects to be
 * needed soon, based on the I/O trace and the current logical time.
 * <p>
 * This is a placeholder: blocks that are neither cached nor pinned are only
 * reported on stdout, nothing is loaded yet. Without a trace or a prescient
 * policy the method is a no-op.
 */
private static void prefetch() {
	// Prefetching requires both a prescient policy and a trace
	if (_prescientPolicy == null || _ioTrace == null) {
		return;
	}

	// Ask the policy which blocks should be brought in for the current time
	final List<String> candidates = _prescientPolicy.getBlocksToPrefetch(_currentTime);

	// A real implementation MUST use an asynchronous thread pool
	// (e.g., from _fClean) to load these blocks without blocking the main thread.
	for (final String candidate : candidates) {
		final boolean alreadyResident;
		synchronized (_mQueue) {
			// Re-check under the lock whether the block is cached or pinned
			alreadyResident = _mQueue.containsKey(candidate) || _pinnedEntries.contains(candidate);
		}

		if (!alreadyResident) {
			// --- Simplified version for now ---
			// TODO: Submit an async prefetch task to _fClean's thread pool
			// The task should: 1. Get block size (from metadata)
			//                  2. Call makeSpace(blockSize)
			//                  3. Load block from disk
			//                  4. Add block to _mQueue (synchronized)
			System.out.println("UMM PREFETCH [T="+_currentTime+"]: Planning to prefetch " + candidate);
		}
	}
}

/**
* Print current status of UMM, including all entries.
* NOTE: use only for debugging or testing.
Expand Down Expand Up @@ -304,10 +373,31 @@ public static int makeSpace(long reqSpace) {
synchronized(_mQueue) {
// Evict blobs to make room (by default FIFO)
while (getUMMFree() < reqSpace && !_mQueue.isEmpty()) {
//remove first unpinned entry from eviction queue
var entry = _mQueue.removeFirstUnpinned(_pinnedEntries);
String ftmp = entry.getKey();
ByteBuffer bb = entry.getValue();
// --- NEW PRESCIENT LOGIC ---
String ftmp; // Block ID / filename to evict

if (_prescientPolicy != null && _ioTrace != null) {
// Use prescient policy to find the best block to evict
ftmp = _prescientPolicy.evict(_mQueue.keySet(), _pinnedEntries, _currentTime);
} else {
// Fallback to default LRU if prescient policy isn't set or has no trace
var entry = _mQueue.removeFirstUnpinned(_pinnedEntries);
ftmp = (entry != null) ? entry.getKey() : null;
}

if (ftmp == null) {
// Policy couldn't find a block to evict (e.g., all are pinned)
if(!_pinnedEntries.containsAll(_mQueue.keySet())) {
// This case should ideally not be reached if unpinned blocks exist
throw new DMLRuntimeException("UMM: Eviction policy failed to find a candidate.");
}
// If we are here, all blocks are pinned, and we cannot make space.
// The original exception will be thrown later.
break; // Exit the while loop
}

// Remove the chosen block from the queue
ByteBuffer bb = _mQueue.remove(ftmp);

if(bb != null) {
// Wait for pending serialization
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.controlprogram.caching.prescientbuffer;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * IOTrace holds the pre-computed I/O access trace for the OOC operations.
 * <p>
 * For every block ID it stores the logical times at which the block is
 * accessed, in the order they were recorded. Lookups assume that accesses
 * are recorded in non-decreasing time order and that queries are issued
 * with non-decreasing {@code currentTime}, because
 * {@link #getNextAccessTime(String, long)} prunes past accesses as a side
 * effect. The class is not thread-safe.
 */
public class IOTrace {

	// Block ID -> logical access times, in recording order
	private final Map<String, LinkedList<Long>> _trace;

	/** Creates an empty trace. */
	public IOTrace() {
		_trace = new HashMap<>();
	}

	/**
	 * Records an access to the given block at the given logical time.
	 *
	 * @param blockID the block identifier
	 * @param logicalTime logical time of the access
	 */
	public void recordAccess(String blockID, long logicalTime) {
		_trace.computeIfAbsent(blockID, k -> new LinkedList<>()).add(logicalTime);
	}

	/**
	 * Get the recorded access times for a given block. Accesses already pruned by
	 * {@link #getNextAccessTime(String, long)} or {@link #cleanup(long)} are
	 * no longer contained.
	 *
	 * @param blockID Block ID
	 * @return the live list of access times, or a new empty list for unknown blocks
	 */
	public LinkedList<Long> getAccessTime(String blockID) {
		return _trace.getOrDefault(blockID, new LinkedList<>());
	}

	/**
	 * Get the next access time for a block after currentTime.
	 * <p>
	 * NOTE: as a side effect, all leading accesses at or before
	 * {@code currentTime} are removed from the block's list.
	 *
	 * @param blockID the block identifier
	 * @param currentTime current logical time
	 * @return next access time or Long.MAX_VALUE if never accessed again
	 */
	public long getNextAccessTime(String blockID, long currentTime) {
		// direct lookup avoids allocating a throw-away list for unknown blocks
		LinkedList<Long> accesses = _trace.get(blockID);
		return (accesses == null) ? Long.MAX_VALUE : nextAccessAfter(accesses, currentTime);
	}

	/**
	 * Get all the blocks whose next access lies in the prefetch window
	 * {@code (currentTime, currentTime + windowSize]}. The window end
	 * saturates instead of overflowing, so a very large {@code windowSize}
	 * selects every block that is accessed again.
	 *
	 * @param currentTime current logical time
	 * @param windowSize prefetch lookahead window
	 * @return List of block IDs to prefetch
	 */
	public List<String> getBlocksInWindow(long currentTime, long windowSize) {
		List<String> blocks = new ArrayList<>();
		long windowEnd = saturatingAdd(currentTime, windowSize);

		for (Map.Entry<String, LinkedList<Long>> entry : _trace.entrySet()) {
			LinkedList<Long> accesses = entry.getValue();
			if (accesses == null) {
				continue; // defensive: externally inserted null list
			}
			// nextAccess is > currentTime by construction of nextAccessAfter
			long nextAccess = nextAccessAfter(accesses, currentTime);
			if (nextAccess != Long.MAX_VALUE && nextAccess <= windowEnd) {
				blocks.add(entry.getKey());
			}
		}

		return blocks;
	}

	/**
	 * Clean up the trace: drops every access at or before currentTime and
	 * removes blocks that have no remaining accesses.
	 * This may or may not be required.
	 *
	 * @param currentTime current logical time
	 */
	public void cleanup(long currentTime) {
		_trace.entrySet().removeIf(entry -> {
			LinkedList<Long> accesses = entry.getValue();
			accesses.removeIf(time -> time <= currentTime);
			return accesses.isEmpty();
		});
	}

	/**
	 * Get the complete trace for debugging.
	 *
	 * @return the internal (live, mutable) trace map
	 */
	public Map<String, LinkedList<Long>> getTrace() {
		return _trace;
	}

	/** Clears all trace data. */
	public void clear() {
		_trace.clear();
	}

	// Prunes leading accesses <= currentTime and returns the first remaining one
	private static long nextAccessAfter(LinkedList<Long> accesses, long currentTime) {
		while (!accesses.isEmpty() && accesses.peekFirst() <= currentTime) {
			accesses.removeFirst();
		}
		return accesses.isEmpty() ? Long.MAX_VALUE : accesses.peekFirst();
	}

	// Adds two longs, clamping to Long.MIN_VALUE / Long.MAX_VALUE on overflow
	private static long saturatingAdd(long a, long b) {
		long sum = a + b;
		// overflow iff both operands have the same sign and the sum's sign differs
		if (((a ^ sum) & (b ^ sum)) < 0) {
			return (a < 0) ? Long.MIN_VALUE : Long.MAX_VALUE;
		}
		return sum;
	}
}
Loading