diff --git a/src/main/java/org/apache/sysds/common/InstructionType.java b/src/main/java/org/apache/sysds/common/InstructionType.java index 29148f03e92..e0e77c46c59 100644 --- a/src/main/java/org/apache/sysds/common/InstructionType.java +++ b/src/main/java/org/apache/sysds/common/InstructionType.java @@ -89,5 +89,5 @@ public enum InstructionType { PMM, MatrixReshape, Write, - Init, + Init, Tee, } diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 6cd15611283..251f773a18c 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -220,6 +220,7 @@ public enum Opcodes { READ("read", InstructionType.Variable), WRITE("write", InstructionType.Variable, InstructionType.Write), CREATEVAR("createvar", InstructionType.Variable), + TEE("tee", InstructionType.Tee), //Reorg instruction opcodes TRANSPOSE("r'", InstructionType.Reorg), diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java index eb0d1961cf5..0aac3ca70c8 100644 --- a/src/main/java/org/apache/sysds/hops/DataOp.java +++ b/src/main/java/org/apache/sysds/hops/DataOp.java @@ -38,6 +38,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.lops.LopsException; import org.apache.sysds.lops.Sql; +import org.apache.sysds.lops.Tee; import org.apache.sysds.parser.DataExpression; import static org.apache.sysds.parser.DataExpression.FED_RANGES; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; @@ -60,6 +61,8 @@ public class DataOp extends Hop { private boolean _recompileRead = true; + private boolean _isTeeOp = false; + /** * List of "named" input parameters. 
They are maintained as a hashmap: * parameter names (String) are mapped as indices (Integer) into getInput() @@ -73,6 +76,10 @@ public class DataOp extends Hop { private DataOp() { //default constructor for clone } + + public void setIsTeeOp(boolean isTeeOp) { + this._isTeeOp = isTeeOp; + } /** * READ operation for Matrix w/ dim1, dim2. @@ -251,57 +258,65 @@ public Lop constructLops() ExecType et = optFindExecType(); Lop l = null; - - // construct lops for all input parameters - HashMap inputLops = new HashMap<>(); - for (Entry cur : _paramIndexMap.entrySet()) { - inputLops.put(cur.getKey(), getInput().get(cur.getValue()).constructLops()); + + if (_isTeeOp) { + Tee teeLop = new Tee(getInput().get(0).constructLops(), + getDataType(), getValueType()); + setLineNumbers(teeLop); + setLops(teeLop); + setOutputDimensions(teeLop); } + else { - // Create the lop - switch(_op) - { - case TRANSIENTREAD: - l = new Data(_op, null, inputLops, getName(), null, - getDataType(), getValueType(), getFileFormat()); - setOutputDimensions(l); - break; - - case PERSISTENTREAD: - l = new Data(_op, null, inputLops, getName(), null, - getDataType(), getValueType(), getFileFormat()); - l.getOutputParameters().setDimensions(getDim1(), getDim2(), _inBlocksize, getNnz(), getUpdateType()); - break; - - case PERSISTENTWRITE: - case FUNCTIONOUTPUT: - l = new Data(_op, getInput().get(0).constructLops(), inputLops, getName(), null, - getDataType(), getValueType(), getFileFormat()); - ((Data)l).setExecType(et); - setOutputDimensions(l); - break; - - case TRANSIENTWRITE: - l = new Data(_op, getInput().get(0).constructLops(), inputLops, getName(), null, - getDataType(), getValueType(), getFileFormat()); - setOutputDimensions(l); - break; - - case SQLREAD: - l = new Sql(inputLops, getDataType(), getValueType()); - break; - - case FEDERATED: - l = new Federated(inputLops, getDataType(), getValueType()); - break; - - default: - throw new LopsException("Invalid operation type for Data LOP: " + _op); + // 
construct lops for all input parameters + HashMap inputLops = new HashMap<>(); + for (Entry cur : _paramIndexMap.entrySet()) { + inputLops.put(cur.getKey(), getInput().get(cur.getValue()).constructLops()); + } + + // Create the lop + switch (_op) { + case TRANSIENTREAD: + l = new Data(_op, null, inputLops, getName(), null, + getDataType(), getValueType(), getFileFormat()); + setOutputDimensions(l); + break; + + case PERSISTENTREAD: + l = new Data(_op, null, inputLops, getName(), null, + getDataType(), getValueType(), getFileFormat()); + l.getOutputParameters().setDimensions(getDim1(), getDim2(), _inBlocksize, getNnz(), getUpdateType()); + break; + + case PERSISTENTWRITE: + case FUNCTIONOUTPUT: + l = new Data(_op, getInput().get(0).constructLops(), inputLops, getName(), null, + getDataType(), getValueType(), getFileFormat()); + ((Data) l).setExecType(et); + setOutputDimensions(l); + break; + + case TRANSIENTWRITE: + l = new Data(_op, getInput().get(0).constructLops(), inputLops, getName(), null, + getDataType(), getValueType(), getFileFormat()); + setOutputDimensions(l); + break; + + case SQLREAD: + l = new Sql(inputLops, getDataType(), getValueType()); + break; + + case FEDERATED: + l = new Federated(inputLops, getDataType(), getValueType()); + break; + + default: + throw new LopsException("Invalid operation type for Data LOP: " + _op); + } + setLineNumbers(l); + setLops(l); } - setLineNumbers(l); - setLops(l); - //add reblock/checkpoint lops if necessary constructAndSetLopsDataFlowProperties(); @@ -346,6 +361,9 @@ public boolean isFederatedDataOp(){ public String getOpString() { String s = new String(""); s += _op.toString(); + if (_isTeeOp) { + s += " tee"; + } s += " "+getName(); return s; } @@ -536,6 +554,7 @@ public Object clone() throws CloneNotSupportedException ret._inFormat = _inFormat; ret._inBlocksize = _inBlocksize; ret._recompileRead = _recompileRead; + ret._isTeeOp = _isTeeOp; // copy the Tee flag ret._paramIndexMap = (HashMap) _paramIndexMap.clone(); 
//note: no deep cp of params since read-only diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java index 874ddae0347..c2602dba510 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java @@ -77,6 +77,7 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) //add static HOP DAG rewrite rules _dagRuleSet.add( new RewriteRemoveReadAfterWrite() ); //dependency: before blocksize _dagRuleSet.add( new RewriteBlockSizeAndReblock() ); + _dagRuleSet.add( new RewriteInjectOOCTee() ); if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java new file mode 100644 index 00000000000..6fa8f1849f0 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.hops.rewrite; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.AggBinaryOp; +import org.apache.sysds.hops.DataOp; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.ReorgOp; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * This rewrite rule injects a Tee operator for specific Out-Of-Core (OOC) patterns + * where a value or an intermediate result is consumed twice, because in OOC a data stream + * can only be consumed once. + * + * <p>
+ * Pattern identified: {@code t(X) %*% X}, where the data {@code X} is shared by + * {@code t(X)} and the {@code %*%} multiplication. + * </p>
+ *
+ * The rewrite uses a stable two-pass approach:
+ * 1. Find candidates (read-only): traverse the entire HOP DAG to identify candidates
+ * that fit the target pattern.
+ * 2. Apply rewrites (modification): iterate over the collected candidates, inject a
+ * {@code TeeOp}, and safely rewire the graph.
+ */
+public class RewriteInjectOOCTee extends HopRewriteRule {
+
+	private static final Set<Long> rewrittenHops = new HashSet<>(); // static: shared by all instances, not cleared in this class
+	private static final Map<Long, Hop> handledHop = new HashMap<>(); // hop ID of the shared input -> injected tee hop
+
+	// Maintain a list of candidates to rewrite in the second pass
+	private final List<Hop> rewriteCandidates = new ArrayList<>();
+
+	/**
+	 * Handle a generic (last-level) hop DAG with multiple roots.
+	 *
+	 * @param roots high-level operator roots
+	 * @param state program rewrite status
+	 * @return list of high-level operators
+	 */
+	@Override
+	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
+		if (roots == null) {
+			return null;
+		}
+
+		// Clear candidates for this pass
+		rewriteCandidates.clear();
+
+		// PASS 1: Identify candidates without modifying the graph
+		for (Hop root : roots) {
+			root.resetVisitStatus();
+			findRewriteCandidates(root);
+		}
+
+		// PASS 2: Apply rewrites to identified candidates
+		for (Hop candidate : rewriteCandidates) {
+			applyTopDownTeeRewrite(candidate);
+		}
+
+		return roots;
+	}
+
+	/**
+	 * Handle a predicate hop DAG with exactly one root.
+	 *
+	 * @param root high-level operator root
+	 * @param state program rewrite status
+	 * @return high-level operator
+	 */
+	@Override
+	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+		if (root == null) {
+			return null;
+		}
+
+		// Clear candidates for this pass
+		rewriteCandidates.clear();
+
+		// PASS 1: Identify candidates without modifying the graph
+		root.resetVisitStatus();
+		findRewriteCandidates(root);
+
+		// PASS 2: Apply rewrites to identified candidates
+		for (Hop candidate : rewriteCandidates) {
+			applyTopDownTeeRewrite(candidate);
+		}
+
+		return root;
+	}
+
+	/**
+	 * First pass: Find candidates for rewrite without modifying the graph.
+	 * This method traverses the graph and identifies nodes that need to be
+	 * rewritten based on the transpose-matrix multiply pattern.
+	 *
+	 * @param hop current hop being examined
+	 */
+	private void findRewriteCandidates(Hop hop) {
+		if (hop.isVisited()) {
+			return;
+		}
+
+		// Mark as visited to avoid processing the same hop multiple times
+		hop.setVisited(true);
+
+		// Recursively traverse the graph (depth-first)
+		for (Hop input : hop.getInput()) {
+			findRewriteCandidates(input);
+		}
+
+		// Check if this hop is a candidate for OOC Tee injection
+		if (isRewriteCandidate(hop)) {
+			rewriteCandidates.add(hop);
+		}
+	}
+
+	/**
+	 * Check if a hop should be considered for rewrite.
+	 *
+	 * @param hop the hop to check
+	 * @return true if the hop meets all criteria for rewrite
+	 */
+	private boolean isRewriteCandidate(Hop hop) {
+		// Skip if already handled
+		if (rewrittenHops.contains(hop.getHopID()) || handledHop.containsKey(hop.getHopID())) {
+			return false;
+		}
+
+		boolean multipleConsumers = hop.getParent().size() > 1;
+		boolean isNotAlreadyTee = isNotAlreadyTee(hop);
+		boolean isOOCEnabled = DMLScript.USE_OOC;
+		boolean isTransposeMM = isTransposePattern(hop);
+		boolean isMatrix = hop.getDataType() == Types.DataType.MATRIX;
+
+		return isOOCEnabled && multipleConsumers && isNotAlreadyTee && isTransposeMM && isMatrix;
+	}
+
+	/**
+	 * Second pass: Apply the TeeOp transformation to a candidate hop.
+	 * This safely rewires the graph by creating a TeeOp node and placeholders.
+	 *
+	 * @param sharedInput the hop to be rewritten
+	 */
+	private void applyTopDownTeeRewrite(Hop sharedInput) {
+		// Only process if not already handled
+		if (handledHop.containsKey(sharedInput.getHopID())) {
+			return;
+		}
+
+		// Take a defensive copy of consumers before modifying the graph
+		ArrayList<Hop> consumers = new ArrayList<>(sharedInput.getParent());
+
+		// Create the new TeeOp with the original hop as input
+		DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(),
+			sharedInput.getDataType(),
+			sharedInput.getValueType(),
+			Types.OpOpData.TRANSIENTWRITE,
+			null,
+			sharedInput.getDim1(),
+			sharedInput.getDim2(),
+			sharedInput.getNnz(),
+			sharedInput.getBlocksize()
+		);
+
+		teeOp.setIsTeeOp(true);
+		HopRewriteUtils.addChildReference(teeOp, sharedInput);
+
+		// Rewire the graph: replace original connections with TeeOp outputs
+		int i = 0;
+		for (Hop consumer : consumers) {
+			Hop placeholder = new DataOp("tee_out_" + sharedInput.getName() + "_" + i,
+				sharedInput.getDataType(),
+				sharedInput.getValueType(),
+				Types.OpOpData.TRANSIENTWRITE,
+				null,
+				sharedInput.getDim1(),
+				sharedInput.getDim2(),
+				sharedInput.getNnz(),
+				sharedInput.getBlocksize()
+			);
+
+			// Copy metadata
+			placeholder.setBeginLine(sharedInput.getBeginLine());
+			placeholder.setBeginColumn(sharedInput.getBeginColumn());
+			placeholder.setEndLine(sharedInput.getEndLine());
+			placeholder.setEndColumn(sharedInput.getEndColumn());
+
+			// Connect placeholder to TeeOp and consumer
+			HopRewriteUtils.addChildReference(placeholder, teeOp);
+			HopRewriteUtils.replaceChildReference(consumer, sharedInput, placeholder);
+
+			i++;
+		}
+
+		// Record that we've handled this hop
+		handledHop.put(sharedInput.getHopID(), teeOp);
+		rewrittenHops.add(sharedInput.getHopID());
+	}
+
+	private boolean isNotAlreadyTee(Hop hop) { // false if a multi-consumer hop already feeds a DataOp
+		if (hop.getParent().size() > 1) {
+			for (Hop consumer : hop.getParent()) {
+				if (consumer instanceof DataOp) {
+					return false;
+				}
+			}
+		}
+		return true;
+	}
+
+	private boolean isTransposePattern(Hop hop) { // true if hop feeds both a transpose and a matrix multiply
+		boolean hasTransposeConsumer = false; // t(X)
+		boolean hasMatrixMultiplyConsumer = false; // %*%
+
+		for (Hop parent: hop.getParent()) {
+			// classify each consumer as the transpose t(X) or the matrix multiply
+			if (parent instanceof ReorgOp) {
+				if (HopRewriteUtils.isTransposeOperation(parent)) {
+					hasTransposeConsumer = true;
+				}
+			}
+			else if (HopRewriteUtils.isMatrixMultiply(parent)) {
+				hasMatrixMultiplyConsumer = true;
+			}
+		}
+		return hasTransposeConsumer && hasMatrixMultiplyConsumer;
+	}
+}
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java index 447201a5fd3..7bfea43c2e5 100644 --- a/src/main/java/org/apache/sysds/lops/Lop.java +++ b/src/main/java/org/apache/sysds/lops/Lop.java @@ -63,7 +63,8 @@ public enum Type { PlusMult, MinusMult, //CP SpoofFused, //CP/SP generated fused operator Sql, //CP sql read - Federated //FED federated read + Federated, //FED federated read + Tee, //OOC Tee operator } diff --git a/src/main/java/org/apache/sysds/lops/Tee.java b/src/main/java/org/apache/sysds/lops/Tee.java new file mode 100644 index 00000000000..6734e69cf31 --- /dev/null +++ b/src/main/java/org/apache/sysds/lops/Tee.java @@ -0,0
+1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.lops; + +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; +
+public class Tee extends Lop {
+
+	public static final String OPCODE = "tee";
+
+	/**
+	 * Creates a tee lop on top of the given input lop; the execution type is fixed to OOC.
+	 *
+	 * @param input1 lop whose output is shared by the consumers of this tee
+	 * @param dt data type of the output
+	 * @param vt value type of the output
+	 */
+	public Tee(Lop input1, DataType dt, ValueType vt) {
+		super(Lop.Type.Tee, dt, vt);
+		this.addInput(input1);
+		input1.addOutput(this);
+		lps.setProperties(inputs, Types.ExecType.OOC);
+	}
+
+	@Override
+	public String toString() {
+		return "Operation = Tee";
+	}
+
+	/**
+	 * Generates the instruction string {@code OOC°tee°input°output1°output2}.
+	 *
+	 * @param input1 label of the input variable
+	 * @param outputs labels of the output variables, joined by {@link Lop#OPERAND_DELIMITOR}
+	 * @return the tee instruction string
+	 * @throws LopsException if {@code outputs} does not hold exactly two labels
+	 */
+	@Override
+	public String getInstructions(String input1, String outputs) {
+		String[] out = outputs.split(Lop.OPERAND_DELIMITOR);
+		if (out.length != 2) {
+			// only two outputs are encoded below; fail here instead of silently dropping a consumer
+			throw new LopsException("Tee expects exactly two outputs, but got "
+				+ out.length + ": " + outputs);
+		}
+
+		return InstructionUtils.concatOperands(
+			getExecType().name(), OPCODE,
+			getInputs().get(0).prepInputOperand(input1),
+			prepOutputOperand(out[0]),
+			prepOutputOperand(out[1]));
+	}
+}
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java b/src/main/java/org/apache/sysds/lops/compile/Dag.java index b26c539e9a8..05fd88731c3 100644 --- a/src/main/java/org/apache/sysds/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java @@ -548,6 +548,18 @@ else if ( node.getType() == Lop.Type.FunctionCallCP ) outputs[count++] = out.getOutputParameters().getLabel(); inst_string = node.getInstructions(inputs, outputs); } + else if ( node.getType() == Type.Tee ) { + String input = node.getInputs().get(0).getOutputParameters().getLabel(); + + ArrayList outputs = new ArrayList<>(); + for( Lop out : node.getOutputs() ) { + outputs.add(out.getOutputParameters().getLabel()); + } + + String packedOutputs = String.join(Lop.OPERAND_DELIMITOR, outputs); + + inst_string = node.getInstructions(input, packedOutputs); + } else if (node.getType() == Lop.Type.Nary) { String[] inputs = new String[node.getInputs().size()]; int count = 0; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java new file mode 100644 index 00000000000..957b03ad7f8 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.parfor; + +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; + +import java.util.ArrayList; +
+/**
+ * A wrapper around a LocalTaskQueue that caches the source stream while it is consumed
+ * the first time, so that it can be reset and replayed for other operators.
+ *
+ */
+public class ResettableStream extends LocalTaskQueue<IndexedMatrixValue> {
+
+	// original live stream
+	private final LocalTaskQueue<IndexedMatrixValue> _source;
+
+	// in-memory cache to store stream for re-play
+	private final ArrayList<IndexedMatrixValue> _cache;
+
+	// state flags
+	private boolean _cacheInProgress = true; // caching in progress, in the first pass.
+	private int _replayPosition = 0; // slider position in the stream
+
+	public ResettableStream(LocalTaskQueue<IndexedMatrixValue> source) {
+		this._source = source;
+		this._cache = new ArrayList<>();
+	}
+
+	/**
+	 * Dequeues a task. During the first pass it reads from the source stream and stores a copy
+	 * in the cache. For subsequent passes it reads from the in-memory cache.
+	 *
+	 * @return The next matrix value in the stream, or NO_MORE_TASKS
+	 * @throws InterruptedException if interrupted while dequeuing from the source stream
+	 */
+	@Override
+	public synchronized IndexedMatrixValue dequeueTask()
+		throws InterruptedException {
+		if (_cacheInProgress) {
+			// First pass: Read value from the source and cache it, and return.
+			IndexedMatrixValue task = _source.dequeueTask();
+			if (task != NO_MORE_TASKS) {
+				_cache.add(new IndexedMatrixValue(task));
+			} else {
+				_cacheInProgress = false; // caching is complete
+				_source.closeInput(); // close source stream
+			}
+			notifyAll(); // Wake threads blocked in reset(); they re-check whether caching has finished
+			return task;
+		} else {
+			// Replay pass: read directly from in-memory cache
+			if (_replayPosition < _cache.size()) {
+				// Return a copy to ensure consumer won't modify the cache
+				return new IndexedMatrixValue(_cache.get(_replayPosition++));
+			} else {
+				return (IndexedMatrixValue) NO_MORE_TASKS;
+			}
+		}
+	}
+
+	/**
+	 * Resets the stream to the beginning so that it can be read again from the start.
+	 * Blocks until the first pass has consumed and cached the complete source stream.
+	 */
+	public synchronized void reset() throws InterruptedException {
+		while (_cacheInProgress) {
+			// notifyAll() fires after every cached block, so re-check the flag after each wakeup
+			wait();
+		}
+		_replayPosition = 0;
+	}
+
+	@Override
+	public synchronized void closeInput() {
+
+		_source.closeInput();
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index 73b5ca02618..94e1a7ff180 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -30,6 +30,7 @@ import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.TransposeOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; public class OOCInstructionParser extends InstructionParser { protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName()); @@ -63,6 +64,8 @@ public static OOCInstruction
parseSingleInstruction(InstructionType ooctype, Str return MatrixVectorBinaryOOCInstruction.parseInstruction(str); case Reorg: return TransposeOOCInstruction.parseInstruction(str); + case Tee: + return TeeOOCInstruction.parseInstruction(str); default: throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index 0495dcfde51..c76fd4e4a4b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -30,7 +30,7 @@ public abstract class OOCInstruction extends Instruction { protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName()); public enum OOCType { - Reblock, AggregateUnary, Binary, Unary, MAPMM, Reorg, AggregateBinary + Reblock, AggregateUnary, Binary, Unary, MAPMM, Reorg, AggregateBinary, Tee } protected final OOCInstruction.OOCType _ooctype; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java new file mode 100644 index 00000000000..bb5c8e24acc --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.controlprogram.parfor.ResettableStream; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; + +import java.util.Arrays; +import java.util.List; + +public class TeeOOCInstruction extends ComputationOOCInstruction { + + private final List _outputs; // both output operands, in instruction order + + protected TeeOOCInstruction(OOCType type, CPOperand in1, CPOperand out, CPOperand out2, String opcode, String istr) { + super(type, null, in1, out, opcode, istr); // only the first output is handed to the base class + _outputs = Arrays.asList(out, out2); + } + + public static TeeOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(parts, 3); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); // single input + CPOperand out = new CPOperand(parts[2]); // first output + CPOperand out2 = new CPOperand(parts[3]); // second output + + return new TeeOOCInstruction(OOCType.Tee, in1, out, out2, opcode, str); + } + + public void processInstruction( ExecutionContext ec ) { + + // Look up the input matrix object and its stream handle (no thread is started here) + MatrixObject min = ec.getMatrixObject(input1); + LocalTaskQueue qIn = min.getStreamHandle(); + + // Create a single, 
shared, resettable stream (cached) + final ResettableStream sharedStream = new ResettableStream(qIn); + + LocalTaskQueue stream2 = new LocalTaskQueue() { + private boolean isFirstCall = true; + + @Override + public IndexedMatrixValue dequeueTask() + throws InterruptedException { + if (isFirstCall) { + sharedStream.reset(); // rewind the shared stream before this second view reads from it + isFirstCall = false; + } + return sharedStream.dequeueTask(); + } + + @Override + public void closeInput() { + // This is a no-op, since sharedStream is managed internally + } + }; + + CPOperand out1 = _outputs.get(0); + MatrixObject mout1 = ec.createMatrixObject(min.getDataCharacteristics()); + mout1.setStreamHandle(sharedStream); // first output reads the shared stream directly + ec.setVariable(out1.getName(), mout1); + + CPOperand out2 = _outputs.get(1); + MatrixObject mout2 = ec.createMatrixObject(min.getDataCharacteristics()); + mout2.setStreamHandle(stream2); // second output reads through the wrapper defined above + ec.setVariable(out2.getName(), mout2); + + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/TeeTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/TeeTest.java new file mode 100644 index 00000000000..63357719f01 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/TeeTest.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; +
+public class TeeTest extends AutomatedTestBase {
+
+	private static final String TEST_NAME = "Tee";
+	private static final String TEST_DIR = "functions/ooc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + TeeTest.class.getSimpleName() + "/";
+	private static final String INPUT_NAME = "X";
+	private static final String OUTPUT_NAME = "res";
+	private final static double eps = 1e-10;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+	}
+
+	@Test
+	public void testTeeNoRewrite() {
+		testTeeOperation(false);
+	}
+
+	@Test
+	public void testTeeRewrite() {
+		testTeeOperation(true);
+	}
+
+
+	public void testTeeOperation(boolean rewrite) // rewrite toggles ALLOW_ALGEBRAIC_SIMPLIFICATION for the run
+	{
+		ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
+		boolean oldRewrite = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
+
+		try {
+			getAndLoadTestConfiguration(TEST_NAME);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-explain", "-stats", "-ooc",
+				"-args", input(INPUT_NAME), output(OUTPUT_NAME)};
+
+			int rows = 1000, cols = 1000;
+			MatrixBlock mb = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7);
+			MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(FileFormat.BINARY);
+			writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows, cols, 1000, rows*cols);
+			HDFSTool.writeMetaDataFile(input(INPUT_NAME+".mtd"), ValueType.FP64,
+				new MatrixCharacteristics(rows,cols,1000,rows*cols), FileFormat.BINARY);
+
+			runTest(true, false, null, -1);
+
+
+			double[][] C1 = readMatrix(output(OUTPUT_NAME), FileFormat.BINARY, rows, cols, 1000, 1000);
+			double result = 0.0;
+			for(int i = 0; i < cols; i++) { // verify the results with Java
+				for(int j = 0; j < cols; j++) {
+					double expected = 0.0;
+					for (int k = 0; k < rows; k++) {
+						expected += mb.get(k, i) * mb.get(k, j);
+					}
+					result = C1[i][j];
+					Assert.assertEquals( "value mismatch at cell ("+i+","+j+")",expected, result, eps);
+				}
+			}
+
+			String prefix = Instruction.OOC_INST_PREFIX;
+			Assert.assertTrue("OOC wasn't used for RBLK",
+				heavyHittersContainsString(prefix + Opcodes.RBLK));
+			Assert.assertTrue("OOC wasn't used for TEE",
+				heavyHittersContainsString(prefix + "tee"));
+		}
+		catch(Exception ex) {
+			throw new AssertionError("Tee OOC test failed: " + ex.getMessage(), ex); // keep the cause and its stack trace
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrite;
+			resetExecMode(platformOld);
+		}
+	}
+
+	private static double[][] readMatrix( String fname, FileFormat fmt,
+		long rows, long cols, int brows, int bcols )
+		throws IOException
+	{
+		MatrixBlock mb = DataConverter.readMatrixFromHDFS(fname, fmt, rows, cols, brows, bcols);
+		double[][] C = DataConverter.convertToDoubleMatrix(mb);
+		return C;
+	}
+}
diff --git a/src/test/scripts/functions/ooc/Tee.dml b/src/test/scripts/functions/ooc/Tee.dml new file mode 100644 index
00000000000..e6faabfc7a1 --- /dev/null +++ b/src/test/scripts/functions/ooc/Tee.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read the input matrix as a stream +X = read($1); + +res = t(X) %*% X; + +write(res, $2, format="binary");