From ae62b29264ef24056c61a1d228349a9085cae023 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Wed, 27 Aug 2025 19:49:40 +0530 Subject: [PATCH 01/10] [SYSTEMDS-3907] OOC Tee operator --- .../apache/sysds/common/InstructionType.java | 2 +- .../java/org/apache/sysds/common/Opcodes.java | 1 + .../java/org/apache/sysds/hops/TeeOp.java | 170 ++++++++++++ .../sysds/hops/rewrite/ProgramRewriter.java | 1 + .../hops/rewrite/RewriteInjectOOCTee.java | 241 ++++++++++++++++++ src/main/java/org/apache/sysds/lops/Lop.java | 3 +- src/main/java/org/apache/sysds/lops/Tee.java | 65 +++++ .../org/apache/sysds/lops/compile/Dag.java | 12 + .../instructions/OOCInstructionParser.java | 3 + .../instructions/ooc/OOCInstruction.java | 2 +- .../instructions/ooc/TeeOOCInstruction.java | 98 +++++++ .../sysds/test/functions/ooc/TeeTest.java | 127 +++++++++ src/test/scripts/functions/ooc/Tee.dml | 27 ++ 13 files changed, 749 insertions(+), 3 deletions(-) create mode 100644 src/main/java/org/apache/sysds/hops/TeeOp.java create mode 100644 src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java create mode 100644 src/main/java/org/apache/sysds/lops/Tee.java create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/TeeTest.java create mode 100644 src/test/scripts/functions/ooc/Tee.dml diff --git a/src/main/java/org/apache/sysds/common/InstructionType.java b/src/main/java/org/apache/sysds/common/InstructionType.java index 29148f03e92..e0e77c46c59 100644 --- a/src/main/java/org/apache/sysds/common/InstructionType.java +++ b/src/main/java/org/apache/sysds/common/InstructionType.java @@ -89,5 +89,5 @@ public enum InstructionType { PMM, MatrixReshape, Write, - Init, + Init, Tee, } diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 6cd15611283..251f773a18c 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -220,6 +220,7 @@ public enum Opcodes { READ("read", InstructionType.Variable), WRITE("write", InstructionType.Variable, InstructionType.Write), CREATEVAR("createvar", InstructionType.Variable), + TEE("tee", InstructionType.Tee), //Reorg instruction opcodes TRANSPOSE("r'", InstructionType.Reorg), diff --git a/src/main/java/org/apache/sysds/hops/TeeOp.java b/src/main/java/org/apache/sysds/hops/TeeOp.java new file mode 100644 index 00000000000..113d839e2bb --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/TeeOp.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops; + +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.lops.Lop; +import org.apache.sysds.lops.Tee; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +import java.util.ArrayList; + + +public class TeeOp extends Hop { + + private final ArrayList _outputs = new ArrayList<>(); + + private TeeOp() { + // default constructor + } + + /** + * Takes in a single Hop input and gives two outputs + * + * @param input + */ + public TeeOp(Hop input) { + super(input.getName(), input.getDataType(), input._valueType); + + // add single input for this hop + getInput().add(0, input); + input.getParent().add(this); + + // output variables list to feed tee output into +// for (Hop out: outputs) { +// _outputs.add(out); +// } + + // This characteristics are same as the input + refreshSizeInformation(); + } + + @Override + public boolean allowsAllExecTypes() { + return false; + } + + /** + * Computes the output matrix characteristics (rows, cols, nnz) based on worst-case output + * and/or input estimates. Should return null if dimensions are unknown. + * + * @param memo memory table + * @return output characteristics + */ + @Override + protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) { + return null; + } + + @Override + public Lop constructLops() { + // return already created Lops + if (getLops() != null) { + return getLops(); + } + + Tee teeLop = new Tee(getInput().get(0).constructLops(), + getDataType(), getValueType()); + setOutputDimensions(teeLop); + setLineNumbers(teeLop); + setLops(teeLop); + + return getLops(); + } + + @Override + protected ExecType optFindExecType(boolean transitive) { + return ExecType.OOC; + } + + @Override + public String getOpString() { + return "tee"; + } + + /** + * In memory-based optimizer mode (see OptimizerUtils.isMemoryBasedOptLevel()), + * the exectype is determined by checking this method as well as memory budget of this Hop. + * Please see findExecTypeByMemEstimate for more detail. + *

+ * This method is necessary because not all operator are supported efficiently + * on GPU (for example: operations on frames and scalar as well as operations such as table). + * + * @return true if the Hop is eligible for GPU Exectype. + */ + @Override + public boolean isGPUEnabled() { + return false; + } + + /** + * Computes the hop-specific output memory estimate in bytes. Should be 0 if not + * applicable. + * + * @param dim1 dimension 1 + * @param dim2 dimension 2 + * @param nnz number of non-zeros + * @return memory estimate + */ + @Override + protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) { + return 0; + } + + /** + * Computes the hop-specific intermediate memory estimate in bytes. Should be 0 if not + * applicable. + * + * @param dim1 dimension 1 + * @param dim2 dimension 2 + * @param nnz number of non-zeros + * @return memory estimate + */ + @Override + protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) { + return 0; + } + + /** + * Update the output size information for this hop. + */ + @Override + public void refreshSizeInformation() { + Hop input1 = getInput().get(0); + setDim1(input1.getDim1()); + setDim2(input1.getDim2()); + setNnz(input1.getNnz()); + setBlocksize(input1.getBlocksize()); + } + + @Override + public Object clone() throws CloneNotSupportedException { + return null; + } + + @Override + public boolean compare(Hop that) { + return false; + } + + public Hop getOutput(int index) { + return _outputs.get(index); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java index 874ddae0347..c2602dba510 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java @@ -77,6 +77,7 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) //add static HOP DAG rewrite rules _dagRuleSet.add( new RewriteRemoveReadAfterWrite() ); //dependency: before blocksize _dagRuleSet.add( new RewriteBlockSizeAndReblock() ); + _dagRuleSet.add( new RewriteInjectOOCTee() ); if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java new file mode 100644 index 00000000000..7a3cd95c85c --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewrite; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.*; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * This Rewrite rule injects a Tee Operator for specific Out-Of-Core (OOC) patterns + * where a value or an intermediate result is shared twice. Since for OOC we data streams + * can only be consumed once. + * + *

+ * Pattern identified {@code t(X) %*% X}, where the data {@code X} will be shared by + * {@code t(X)} and {@code %*%} multiplication. + *

+ * + * The rewrite uses a stable two-pass approach: + * 1. Find candidates (Read-Only): Traverse the entire HOP DAG to identify candidates + * the fit the target pattern. + * 2. Apply Rewrites (Modification): Iterate over the collected candidate and put + * {@code TeeOp}, and safely rewire the graph. + */ +public class RewriteInjectOOCTee extends HopRewriteRule { + + private static final Set rewrittenHops = new HashSet<>(); + private static final Map handledHop = new HashMap<>(); + + // Maintain a list of candidates to rewrite in the second pass + private final List rewriteCandidates = new ArrayList<>(); + + /** + * Handle a generic (last-level) hop DAG with multiple roots. + * + * @param roots high-level operator roots + * @param state program rewrite status + * @return list of high-level operators + */ + @Override + public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus state) { + if (roots == null) { + return null; + } + + // Clear candidates for this pass + rewriteCandidates.clear(); + + // PASS 1: Identify candidates without modifying the graph + for (Hop root : roots) { + root.resetVisitStatus(); + findRewriteCandidates(root); + } + + // PASS 2: Apply rewrites to identified candidates + for (Hop candidate : rewriteCandidates) { + applyTopDownTeeRewrite(candidate); + } + + return roots; + } + + /** + * Handle a predicate hop DAG with exactly one root. + * + * @param root high-level operator root + * @param state program rewrite status + * @return high-level operator + */ + @Override + public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { + if (root == null) { + return null; + } + + // Clear candidates for this pass + rewriteCandidates.clear(); + + // PASS 1: Identify candidates without modifying the graph + root.resetVisitStatus(); + findRewriteCandidates(root); + + // PASS 2: Apply rewrites to identified candidates + for (Hop candidate : rewriteCandidates) { + applyTopDownTeeRewrite(candidate); + } + + return root; + } + + /** + * First pass: Find candidates for rewrite without modifying the graph. + * This method traverses the graph and identifies nodes that need to be + * rewritten based on the transpose-matrix multiply pattern. + * + * @param hop current hop being examined + */ + private void findRewriteCandidates(Hop hop) { + if (hop.isVisited()) { + return; + } + + // Mark as visited to avoid processing the same hop multiple times + hop.setVisited(true); + + // Recursively traverse the graph (depth-first) + for (Hop input : hop.getInput()) { + findRewriteCandidates(input); + } + + // Check if this hop is a candidate for OOC Tee injection + if (isRewriteCandidate(hop)) { + rewriteCandidates.add(hop); + } + } + + /** + * Check if a hop should be considered for rewrite. + * + * @param hop the hop to check + * @return true if the hop meets all criteria for rewrite + */ + private boolean isRewriteCandidate(Hop hop) { + // Skip if already handled + if (rewrittenHops.contains(hop.getHopID()) || handledHop.containsKey(hop.getHopID())) { + return false; + } + + boolean multipleConsumers = hop.getParent().size() > 1; + boolean isNotAlreadyTee = isNotAlreadyTee(hop); + boolean isOOCEnabled = DMLScript.USE_OOC; + boolean isTransposeMM = isTranposePattern(hop); + boolean isMatrix = hop.getDataType() == Types.DataType.MATRIX; + + return isOOCEnabled && multipleConsumers && isNotAlreadyTee && isTransposeMM && isMatrix; + } + + /** + * Second pass: Apply the TeeOp transformation to a candidate hop. + * This safely rewires the graph by creating a TeeOp node and placeholders. + * + * @param sharedInput the hop to be rewritten + */ + private void applyTopDownTeeRewrite(Hop sharedInput) { + // Only process if not already handled + if (handledHop.containsKey(sharedInput.getHopID())) { + return; + } + + // Take a defensive copy of consumers before modifying the graph + ArrayList consumers = new ArrayList<>(sharedInput.getParent()); + + // Create the new TeeOp with the original hop as input + TeeOp teeOp = new TeeOp(sharedInput); + + // Rewire the graph: replace original connections with TeeOp outputs + int i = 0; + for (Hop consumer : consumers) { + Hop placeholder = new DataOp("tee_out_" + sharedInput.getName() + "_" + i, + sharedInput.getDataType(), + sharedInput.getValueType(), + Types.OpOpData.TRANSIENTWRITE, + null, + sharedInput.getDim1(), + sharedInput.getDim2(), + sharedInput.getNnz(), + sharedInput.getBlocksize() + ); + + // Copy metadata + placeholder.setBeginLine(sharedInput.getBeginLine()); + placeholder.setBeginColumn(sharedInput.getBeginColumn()); + placeholder.setEndLine(sharedInput.getEndLine()); + placeholder.setEndColumn(sharedInput.getEndColumn()); + + // Connect placeholder to TeeOp and consumer + HopRewriteUtils.addChildReference(placeholder, teeOp); + HopRewriteUtils.replaceChildReference(consumer, sharedInput, placeholder); + + i++; + } + + // Record that we've handled this hop + handledHop.put(sharedInput.getHopID(), teeOp); + rewrittenHops.add(sharedInput.getHopID()); + } + + private boolean isNotAlreadyTee(Hop hop) { + if (hop.getParent().size() > 1) { + for (Hop consumer : hop.getParent()) { + if (consumer instanceof TeeOp) { + return false; + } + } + } + return true; + } + + private boolean isTranposePattern (Hop hop) { + boolean hasTransposeConsumer = false; // t(X) + boolean hasMatrixMultiplyConsumer = false; // %*% + + for (Hop parent: hop.getParent()) { + String opString = parent.getOpString(); + if (parent instanceof ReorgOp) { + if (opString.contains("r'") || opString.contains("transpose")) { + hasTransposeConsumer = true; + } + } + else if (parent instanceof AggBinaryOp) + if (opString.contains("*") || opString.contains("ba+*")) { + hasMatrixMultiplyConsumer = true; + } + } + return hasTransposeConsumer && hasMatrixMultiplyConsumer; + } +} diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java index 447201a5fd3..7bfea43c2e5 100644 --- a/src/main/java/org/apache/sysds/lops/Lop.java +++ b/src/main/java/org/apache/sysds/lops/Lop.java @@ -63,7 +63,8 @@ public enum Type { PlusMult, MinusMult, //CP SpoofFused, //CP/SP generated fused operator Sql, //CP sql read - Federated //FED federated read + Federated, //FED federated read + Tee, //OOC Tee operator } diff --git a/src/main/java/org/apache/sysds/lops/Tee.java b/src/main/java/org/apache/sysds/lops/Tee.java new file mode 100644 index 00000000000..a9ce7ff970b --- /dev/null +++ b/src/main/java/org/apache/sysds/lops/Tee.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.lops; + +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; + +public class Tee extends Lop { + + public static final String OPCODE = "tee"; + /** + * Constructor to be invoked by base class. + * + * @param input1 lop type + * @param dt data type of the output + * @param vt value type of the output + */ + public Tee(Lop input1, DataType dt, ValueType vt) { + super(Lop.Type.Tee, dt, vt); + this.addInput(input1); + input1.addOutput(this); + lps.setProperties(inputs, Types.ExecType.OOC); + } + + @Override + public String toString() { + return "Operation = Tee"; + } + + @Override + public String getInstructions(String input1, String outputs) { + + String[] out = outputs.split(Lop.OPERAND_DELIMITOR); + String output2 = outputs + "_copy"; + + // This method generates the instruction string: OOC°tee°input°output1°output2... + String ret = InstructionUtils.concatOperands( + getExecType().name(), OPCODE, + getInputs().get(0).prepInputOperand(input1), + prepOutputOperand(out[0]), + prepOutputOperand(out[1]) + ); + + return ret; + } +} diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java b/src/main/java/org/apache/sysds/lops/compile/Dag.java index b26c539e9a8..05fd88731c3 100644 --- a/src/main/java/org/apache/sysds/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java @@ -548,6 +548,18 @@ else if ( node.getType() == Lop.Type.FunctionCallCP ) outputs[count++] = out.getOutputParameters().getLabel(); inst_string = node.getInstructions(inputs, outputs); } + else if ( node.getType() == Type.Tee ) { + String input = node.getInputs().get(0).getOutputParameters().getLabel(); + + ArrayList outputs = new ArrayList<>(); + for( Lop out : node.getOutputs() ) { + outputs.add(out.getOutputParameters().getLabel()); + } + + String packedOutputs = String.join(Lop.OPERAND_DELIMITOR, outputs); + + inst_string = node.getInstructions(input, packedOutputs); + } else if (node.getType() == Lop.Type.Nary) { String[] inputs = new String[node.getInputs().size()]; int count = 0; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index 73b5ca02618..94e1a7ff180 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -30,6 +30,7 @@ import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.TransposeOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; public class OOCInstructionParser extends InstructionParser { protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName()); @@ -63,6 +64,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str return MatrixVectorBinaryOOCInstruction.parseInstruction(str); case Reorg: return TransposeOOCInstruction.parseInstruction(str); + case Tee: + return TeeOOCInstruction.parseInstruction(str); default: throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index 0495dcfde51..c76fd4e4a4b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -30,7 +30,7 @@ public abstract class OOCInstruction extends Instruction { protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName()); public enum OOCType { - Reblock, AggregateUnary, Binary, Unary, MAPMM, Reorg, AggregateBinary + Reblock, AggregateUnary, Binary, Unary, MAPMM, Reorg, AggregateBinary, Tee } protected final OOCInstruction.OOCType _ooctype; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java new file mode 100644 index 00000000000..0248ef78b1b --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.util.CommonThreadPool; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutorService; + +public class TeeOOCInstruction extends ComputationOOCInstruction { + + private final List _outputs; + + protected TeeOOCInstruction(OOCType type, CPOperand in1, CPOperand out, CPOperand out2, String opcode, String istr) { + super(type, null, in1, out, opcode, istr); + _outputs = Arrays.asList(out, out2); + } + + public static TeeOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(parts, 3); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); + CPOperand out = new CPOperand(parts[2]); + CPOperand out2 = new CPOperand(parts[3]); + + return new TeeOOCInstruction(OOCType.Tee, in1, out, out2, opcode, str); + } + + public void processInstruction( ExecutionContext ec ) { + + // Create thread and process the tee operation + MatrixObject min = ec.getMatrixObject(input1); + LocalTaskQueue qIn = min.getStreamHandle(); + +// MatrixObject min = ec.getMatrixObject(input1); +// LocalTaskQueue qIn = min.getStreamHandle(); + List> qOuts = new ArrayList<>(); + for (CPOperand out : _outputs) { + MatrixObject mout = ec.createMatrixObject(min.getDataCharacteristics()); + ec.setVariable(out.getName(), mout); + LocalTaskQueue qOut = new LocalTaskQueue<>(); + mout.setStreamHandle(qOut); + qOuts.add(qOut); + } + + ExecutorService pool = CommonThreadPool.get(); + try { + pool.submit(() -> { + IndexedMatrixValue tmp = null; + try { + while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + + for (int i = 0; i < qOuts.size(); i++) { + qOuts.get(i).enqueueTask(new IndexedMatrixValue(tmp)); + } + } + for (LocalTaskQueue qOut : qOuts) { + qOut.closeInput(); + } + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + }); + } catch (Exception ex) { + throw new DMLRuntimeException(ex); + } finally { + pool.shutdown(); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/TeeTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/TeeTest.java new file mode 100644 index 00000000000..16e60288538 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/TeeTest.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; + +public class TeeTest extends AutomatedTestBase { + + private static final String TEST_NAME = "Tee"; + private static final String TEST_DIR = "functions/ooc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + TeeTest.class.getSimpleName() + "/"; + private static final String INPUT_NAME = "X"; + private static final String OUTPUT_NAME = "res"; + private final static double eps = 1e-10; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + @Test + public void testTeeNoRewrite() { + testTeeOperation(false); + } + + @Test + public void testTeeRewrite() { + testTeeOperation(true); + } + + + public void testTeeOperation(boolean rewrite) + { + ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE); + boolean oldRewrite = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite; + + try { + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", + "-args", input(INPUT_NAME), output(OUTPUT_NAME)}; + + int rows = 1000, cols = 1000; + MatrixBlock mb = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7); + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(FileFormat.BINARY); + writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows, cols, 1000, rows*cols); + HDFSTool.writeMetaDataFile(input(INPUT_NAME+".mtd"), ValueType.FP64, + new MatrixCharacteristics(rows,cols,1000,rows*cols), FileFormat.BINARY); + + runTest(true, false, null, -1); + + + double[][] C1 = readMatrix(output(OUTPUT_NAME), FileFormat.BINARY, rows, cols, 1000, 1000); + double result = 0.0; + for(int i = 0; i < cols; i++) { // verify the results with Java + for(int j = 0; j < cols; j++) { + double expected = 0.0; + for (int k = 0; k < rows; k++) { + expected += mb.get(k, i) * mb.get(k, j); + } + result = C1[i][j]; + Assert.assertEquals( "value mismatch at cell ("+i+","+j+")",expected, result, eps); + } + } + + String prefix = Instruction.OOC_INST_PREFIX; + Assert.assertTrue("OOC wasn't used for RBLK", + heavyHittersContainsString(prefix + Opcodes.RBLK)); + Assert.assertTrue("OOC wasn't used for TEE", + heavyHittersContainsString(prefix + "tee")); + } + catch(Exception ex) { + Assert.fail(ex.getMessage()); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrite; + resetExecMode(platformOld); + } + } + + private static double[][] readMatrix( String fname, FileFormat fmt, + long rows, long cols, int brows, int bcols ) + throws IOException + { + MatrixBlock mb = DataConverter.readMatrixFromHDFS(fname, fmt, rows, cols, brows, bcols); + double[][] C = DataConverter.convertToDoubleMatrix(mb); + return C; + } +} diff --git a/src/test/scripts/functions/ooc/Tee.dml b/src/test/scripts/functions/ooc/Tee.dml new file mode 100644 index 00000000000..e6faabfc7a1 --- /dev/null +++ b/src/test/scripts/functions/ooc/Tee.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read the input matrix as a stream +X = read($1); + +res = t(X) %*% X; + +write(res, $2, format="binary"); From 2b90ce068cb1b2e97dd0b8c3bab0cbb4df9f31d8 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Thu, 28 Aug 2025 08:42:43 +0530 Subject: [PATCH 02/10] fix formatting --- .../java/org/apache/sysds/hops/TeeOp.java | 260 ++++++------ .../hops/rewrite/RewriteInjectOOCTee.java | 380 +++++++++--------- src/main/java/org/apache/sysds/lops/Tee.java | 62 +-- .../instructions/ooc/TeeOOCInstruction.java | 100 ++--- .../sysds/test/functions/ooc/TeeTest.java | 166 ++++---- 5 files changed, 484 insertions(+), 484 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/TeeOp.java b/src/main/java/org/apache/sysds/hops/TeeOp.java index 113d839e2bb..e27cd043aad 100644 --- a/src/main/java/org/apache/sysds/hops/TeeOp.java +++ b/src/main/java/org/apache/sysds/hops/TeeOp.java @@ -29,142 +29,142 @@ public class TeeOp extends Hop { - private final ArrayList _outputs = new ArrayList<>(); + private final ArrayList _outputs = new ArrayList<>(); - private TeeOp() { - // default constructor - } + private TeeOp() { + // default constructor + } - /** - * Takes in a single Hop input and gives two outputs - * - * @param input - */ - public TeeOp(Hop input) { - super(input.getName(), input.getDataType(), input._valueType); + /** + * Takes in a single Hop input and gives two outputs + * + * @param input + */ + public TeeOp(Hop input) { + super(input.getName(), input.getDataType(), input._valueType); - // add single input for this hop - getInput().add(0, input); - input.getParent().add(this); + // add single input for this hop + getInput().add(0, input); + input.getParent().add(this); - // output variables list to feed tee output into + // output variables list to feed tee output into // for (Hop out: outputs) { // _outputs.add(out); // } - // This characteristics are same as the input - refreshSizeInformation(); - } - - @Override - public boolean allowsAllExecTypes() { - return false; - } - - /** - * Computes the output matrix characteristics (rows, cols, nnz) based on worst-case output - * and/or input estimates. Should return null if dimensions are unknown. - * - * @param memo memory table - * @return output characteristics - */ - @Override - protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) { - return null; - } - - @Override - public Lop constructLops() { - // return already created Lops - if (getLops() != null) { - return getLops(); - } - - Tee teeLop = new Tee(getInput().get(0).constructLops(), - getDataType(), getValueType()); - setOutputDimensions(teeLop); - setLineNumbers(teeLop); - setLops(teeLop); - - return getLops(); - } - - @Override - protected ExecType optFindExecType(boolean transitive) { - return ExecType.OOC; - } - - @Override - public String getOpString() { - return "tee"; - } - - /** - * In memory-based optimizer mode (see OptimizerUtils.isMemoryBasedOptLevel()), - * the exectype is determined by checking this method as well as memory budget of this Hop. - * Please see findExecTypeByMemEstimate for more detail. - *

- * This method is necessary because not all operator are supported efficiently - * on GPU (for example: operations on frames and scalar as well as operations such as table). - * - * @return true if the Hop is eligible for GPU Exectype. - */ - @Override - public boolean isGPUEnabled() { - return false; - } - - /** - * Computes the hop-specific output memory estimate in bytes. Should be 0 if not - * applicable. - * - * @param dim1 dimension 1 - * @param dim2 dimension 2 - * @param nnz number of non-zeros - * @return memory estimate - */ - @Override - protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) { - return 0; - } - - /** - * Computes the hop-specific intermediate memory estimate in bytes. Should be 0 if not - * applicable. - * - * @param dim1 dimension 1 - * @param dim2 dimension 2 - * @param nnz number of non-zeros - * @return memory estimate - */ - @Override - protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) { - return 0; - } - - /** - * Update the output size information for this hop. - */ - @Override - public void refreshSizeInformation() { - Hop input1 = getInput().get(0); - setDim1(input1.getDim1()); - setDim2(input1.getDim2()); - setNnz(input1.getNnz()); - setBlocksize(input1.getBlocksize()); - } - - @Override - public Object clone() throws CloneNotSupportedException { - return null; - } - - @Override - public boolean compare(Hop that) { - return false; - } - - public Hop getOutput(int index) { - return _outputs.get(index); - } + // This characteristics are same as the input + refreshSizeInformation(); + } + + @Override + public boolean allowsAllExecTypes() { + return false; + } + + /** + * Computes the output matrix characteristics (rows, cols, nnz) based on worst-case output + * and/or input estimates. Should return null if dimensions are unknown. + * + * @param memo memory table + * @return output characteristics + */ + @Override + protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) { + return null; + } + + @Override + public Lop constructLops() { + // return already created Lops + if (getLops() != null) { + return getLops(); + } + + Tee teeLop = new Tee(getInput().get(0).constructLops(), + getDataType(), getValueType()); + setOutputDimensions(teeLop); + setLineNumbers(teeLop); + setLops(teeLop); + + return getLops(); + } + + @Override + protected ExecType optFindExecType(boolean transitive) { + return ExecType.OOC; + } + + @Override + public String getOpString() { + return "tee"; + } + + /** + * In memory-based optimizer mode (see OptimizerUtils.isMemoryBasedOptLevel()), + * the exectype is determined by checking this method as well as memory budget of this Hop. + * Please see findExecTypeByMemEstimate for more detail. + *

+ * This method is necessary because not all operator are supported efficiently + * on GPU (for example: operations on frames and scalar as well as operations such as table). + * + * @return true if the Hop is eligible for GPU Exectype. + */ + @Override + public boolean isGPUEnabled() { + return false; + } + + /** + * Computes the hop-specific output memory estimate in bytes. Should be 0 if not + * applicable. + * + * @param dim1 dimension 1 + * @param dim2 dimension 2 + * @param nnz number of non-zeros + * @return memory estimate + */ + @Override + protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) { + return 0; + } + + /** + * Computes the hop-specific intermediate memory estimate in bytes. Should be 0 if not + * applicable. + * + * @param dim1 dimension 1 + * @param dim2 dimension 2 + * @param nnz number of non-zeros + * @return memory estimate + */ + @Override + protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) { + return 0; + } + + /** + * Update the output size information for this hop. + */ + @Override + public void refreshSizeInformation() { + Hop input1 = getInput().get(0); + setDim1(input1.getDim1()); + setDim2(input1.getDim2()); + setNnz(input1.getNnz()); + setBlocksize(input1.getBlocksize()); + } + + @Override + public Object clone() throws CloneNotSupportedException { + return null; + } + + @Override + public boolean compare(Hop that) { + return false; + } + + public Hop getOutput(int index) { + return _outputs.get(index); + } } diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java index 7a3cd95c85c..f7f5422cd15 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java @@ -48,194 +48,194 @@ */ public class RewriteInjectOOCTee extends HopRewriteRule { - private static final Set rewrittenHops = new HashSet<>(); - private static final Map handledHop = new HashMap<>(); - - // Maintain a list of candidates to rewrite in the second pass - private final List rewriteCandidates = new ArrayList<>(); - - /** - * Handle a generic (last-level) hop DAG with multiple roots. - * - * @param roots high-level operator roots - * @param state program rewrite status - * @return list of high-level operators - */ - @Override - public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus state) { - if (roots == null) { - return null; - } - - // Clear candidates for this pass - rewriteCandidates.clear(); - - // PASS 1: Identify candidates without modifying the graph - for (Hop root : roots) { - root.resetVisitStatus(); - findRewriteCandidates(root); - } - - // PASS 2: Apply rewrites to identified candidates - for (Hop candidate : rewriteCandidates) { - applyTopDownTeeRewrite(candidate); - } - - return roots; - } - - /** - * Handle a predicate hop DAG with exactly one root. - * - * @param root high-level operator root - * @param state program rewrite status - * @return high-level operator - */ - @Override - public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { - if (root == null) { - return null; - } - - // Clear candidates for this pass - rewriteCandidates.clear(); - - // PASS 1: Identify candidates without modifying the graph - root.resetVisitStatus(); - findRewriteCandidates(root); - - // PASS 2: Apply rewrites to identified candidates - for (Hop candidate : rewriteCandidates) { - applyTopDownTeeRewrite(candidate); - } - - return root; - } - - /** - * First pass: Find candidates for rewrite without modifying the graph. - * This method traverses the graph and identifies nodes that need to be - * rewritten based on the transpose-matrix multiply pattern. - * - * @param hop current hop being examined - */ - private void findRewriteCandidates(Hop hop) { - if (hop.isVisited()) { - return; - } - - // Mark as visited to avoid processing the same hop multiple times - hop.setVisited(true); - - // Recursively traverse the graph (depth-first) - for (Hop input : hop.getInput()) { - findRewriteCandidates(input); - } - - // Check if this hop is a candidate for OOC Tee injection - if (isRewriteCandidate(hop)) { - rewriteCandidates.add(hop); - } - } - - /** - * Check if a hop should be considered for rewrite. - * - * @param hop the hop to check - * @return true if the hop meets all criteria for rewrite - */ - private boolean isRewriteCandidate(Hop hop) { - // Skip if already handled - if (rewrittenHops.contains(hop.getHopID()) || handledHop.containsKey(hop.getHopID())) { - return false; - } - - boolean multipleConsumers = hop.getParent().size() > 1; - boolean isNotAlreadyTee = isNotAlreadyTee(hop); - boolean isOOCEnabled = DMLScript.USE_OOC; - boolean isTransposeMM = isTranposePattern(hop); - boolean isMatrix = hop.getDataType() == Types.DataType.MATRIX; - - return isOOCEnabled && multipleConsumers && isNotAlreadyTee && isTransposeMM && isMatrix; - } - - /** - * Second pass: Apply the TeeOp transformation to a candidate hop. - * This safely rewires the graph by creating a TeeOp node and placeholders. - * - * @param sharedInput the hop to be rewritten - */ - private void applyTopDownTeeRewrite(Hop sharedInput) { - // Only process if not already handled - if (handledHop.containsKey(sharedInput.getHopID())) { - return; - } - - // Take a defensive copy of consumers before modifying the graph - ArrayList consumers = new ArrayList<>(sharedInput.getParent()); - - // Create the new TeeOp with the original hop as input - TeeOp teeOp = new TeeOp(sharedInput); - - // Rewire the graph: replace original connections with TeeOp outputs - int i = 0; - for (Hop consumer : consumers) { - Hop placeholder = new DataOp("tee_out_" + sharedInput.getName() + "_" + i, - sharedInput.getDataType(), - sharedInput.getValueType(), - Types.OpOpData.TRANSIENTWRITE, - null, - sharedInput.getDim1(), - sharedInput.getDim2(), - sharedInput.getNnz(), - sharedInput.getBlocksize() - ); - - // Copy metadata - placeholder.setBeginLine(sharedInput.getBeginLine()); - placeholder.setBeginColumn(sharedInput.getBeginColumn()); - placeholder.setEndLine(sharedInput.getEndLine()); - placeholder.setEndColumn(sharedInput.getEndColumn()); - - // Connect placeholder to TeeOp and consumer - HopRewriteUtils.addChildReference(placeholder, teeOp); - HopRewriteUtils.replaceChildReference(consumer, sharedInput, placeholder); - - i++; - } - - // Record that we've handled this hop - handledHop.put(sharedInput.getHopID(), teeOp); - rewrittenHops.add(sharedInput.getHopID()); - } - - private boolean isNotAlreadyTee(Hop hop) { - if (hop.getParent().size() > 1) { - for (Hop consumer : hop.getParent()) { - if (consumer instanceof TeeOp) { - return false; - } - } - } - return true; - } - - private boolean isTranposePattern (Hop hop) { - boolean hasTransposeConsumer = false; // t(X) - boolean hasMatrixMultiplyConsumer = false; // %*% - - for (Hop parent: hop.getParent()) { - String opString = parent.getOpString(); - if (parent instanceof ReorgOp) { - if (opString.contains("r'") || opString.contains("transpose")) { - hasTransposeConsumer = true; - } - } - else if (parent instanceof AggBinaryOp) - if (opString.contains("*") || opString.contains("ba+*")) { - hasMatrixMultiplyConsumer = true; - } - } - return hasTransposeConsumer && hasMatrixMultiplyConsumer; - } + private static final Set rewrittenHops = new HashSet<>(); + private static final Map handledHop = new HashMap<>(); + + // Maintain a list of candidates to rewrite in the second pass + private final List rewriteCandidates = new ArrayList<>(); + + /** + * Handle a generic (last-level) hop DAG with multiple roots. + * + * @param roots high-level operator roots + * @param state program rewrite status + * @return list of high-level operators + */ + @Override + public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus state) { + if (roots == null) { + return null; + } + + // Clear candidates for this pass + rewriteCandidates.clear(); + + // PASS 1: Identify candidates without modifying the graph + for (Hop root : roots) { + root.resetVisitStatus(); + findRewriteCandidates(root); + } + + // PASS 2: Apply rewrites to identified candidates + for (Hop candidate : rewriteCandidates) { + applyTopDownTeeRewrite(candidate); + } + + return roots; + } + + /** + * Handle a predicate hop DAG with exactly one root. + * + * @param root high-level operator root + * @param state program rewrite status + * @return high-level operator + */ + @Override + public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { + if (root == null) { + return null; + } + + // Clear candidates for this pass + rewriteCandidates.clear(); + + // PASS 1: Identify candidates without modifying the graph + root.resetVisitStatus(); + findRewriteCandidates(root); + + // PASS 2: Apply rewrites to identified candidates + for (Hop candidate : rewriteCandidates) { + applyTopDownTeeRewrite(candidate); + } + + return root; + } + + /** + * First pass: Find candidates for rewrite without modifying the graph. + * This method traverses the graph and identifies nodes that need to be + * rewritten based on the transpose-matrix multiply pattern. + * + * @param hop current hop being examined + */ + private void findRewriteCandidates(Hop hop) { + if (hop.isVisited()) { + return; + } + + // Mark as visited to avoid processing the same hop multiple times + hop.setVisited(true); + + // Recursively traverse the graph (depth-first) + for (Hop input : hop.getInput()) { + findRewriteCandidates(input); + } + + // Check if this hop is a candidate for OOC Tee injection + if (isRewriteCandidate(hop)) { + rewriteCandidates.add(hop); + } + } + + /** + * Check if a hop should be considered for rewrite. + * + * @param hop the hop to check + * @return true if the hop meets all criteria for rewrite + */ + private boolean isRewriteCandidate(Hop hop) { + // Skip if already handled + if (rewrittenHops.contains(hop.getHopID()) || handledHop.containsKey(hop.getHopID())) { + return false; + } + + boolean multipleConsumers = hop.getParent().size() > 1; + boolean isNotAlreadyTee = isNotAlreadyTee(hop); + boolean isOOCEnabled = DMLScript.USE_OOC; + boolean isTransposeMM = isTranposePattern(hop); + boolean isMatrix = hop.getDataType() == Types.DataType.MATRIX; + + return isOOCEnabled && multipleConsumers && isNotAlreadyTee && isTransposeMM && isMatrix; + } + + /** + * Second pass: Apply the TeeOp transformation to a candidate hop. + * This safely rewires the graph by creating a TeeOp node and placeholders. + * + * @param sharedInput the hop to be rewritten + */ + private void applyTopDownTeeRewrite(Hop sharedInput) { + // Only process if not already handled + if (handledHop.containsKey(sharedInput.getHopID())) { + return; + } + + // Take a defensive copy of consumers before modifying the graph + ArrayList consumers = new ArrayList<>(sharedInput.getParent()); + + // Create the new TeeOp with the original hop as input + TeeOp teeOp = new TeeOp(sharedInput); + + // Rewire the graph: replace original connections with TeeOp outputs + int i = 0; + for (Hop consumer : consumers) { + Hop placeholder = new DataOp("tee_out_" + sharedInput.getName() + "_" + i, + sharedInput.getDataType(), + sharedInput.getValueType(), + Types.OpOpData.TRANSIENTWRITE, + null, + sharedInput.getDim1(), + sharedInput.getDim2(), + sharedInput.getNnz(), + sharedInput.getBlocksize() + ); + + // Copy metadata + placeholder.setBeginLine(sharedInput.getBeginLine()); + placeholder.setBeginColumn(sharedInput.getBeginColumn()); + placeholder.setEndLine(sharedInput.getEndLine()); + placeholder.setEndColumn(sharedInput.getEndColumn()); + + // Connect placeholder to TeeOp and consumer + HopRewriteUtils.addChildReference(placeholder, teeOp); + HopRewriteUtils.replaceChildReference(consumer, sharedInput, placeholder); + + i++; + } + + // Record that we've handled this hop + handledHop.put(sharedInput.getHopID(), teeOp); + rewrittenHops.add(sharedInput.getHopID()); + } + + private boolean isNotAlreadyTee(Hop hop) { + if (hop.getParent().size() > 1) { + for (Hop consumer : hop.getParent()) { + if (consumer instanceof TeeOp) { + return false; + } + } + } + return true; + } + + private boolean isTranposePattern (Hop hop) { + boolean hasTransposeConsumer = false; // t(X) + boolean hasMatrixMultiplyConsumer = false; // %*% + + for (Hop parent: hop.getParent()) { + String opString = parent.getOpString(); + if (parent instanceof ReorgOp) { + if (opString.contains("r'") || opString.contains("transpose")) { + hasTransposeConsumer = true; + } + } + else if (parent instanceof AggBinaryOp) + if (opString.contains("*") || opString.contains("ba+*")) { + hasMatrixMultiplyConsumer = true; + } + } + return hasTransposeConsumer && hasMatrixMultiplyConsumer; + } } diff --git a/src/main/java/org/apache/sysds/lops/Tee.java b/src/main/java/org/apache/sysds/lops/Tee.java index a9ce7ff970b..113bd37b12d 100644 --- a/src/main/java/org/apache/sysds/lops/Tee.java +++ b/src/main/java/org/apache/sysds/lops/Tee.java @@ -26,40 +26,40 @@ public class Tee extends Lop { - public static final String OPCODE = "tee"; - /** - * Constructor to be invoked by base class. - * - * @param input1 lop type - * @param dt data type of the output - * @param vt value type of the output - */ - public Tee(Lop input1, DataType dt, ValueType vt) { - super(Lop.Type.Tee, dt, vt); - this.addInput(input1); - input1.addOutput(this); - lps.setProperties(inputs, Types.ExecType.OOC); - } + public static final String OPCODE = "tee"; + /** + * Constructor to be invoked by base class. + * + * @param input1 lop type + * @param dt data type of the output + * @param vt value type of the output + */ + public Tee(Lop input1, DataType dt, ValueType vt) { + super(Lop.Type.Tee, dt, vt); + this.addInput(input1); + input1.addOutput(this); + lps.setProperties(inputs, Types.ExecType.OOC); + } - @Override - public String toString() { - return "Operation = Tee"; - } + @Override + public String toString() { + return "Operation = Tee"; + } - @Override - public String getInstructions(String input1, String outputs) { + @Override + public String getInstructions(String input1, String outputs) { - String[] out = outputs.split(Lop.OPERAND_DELIMITOR); - String output2 = outputs + "_copy"; + String[] out = outputs.split(Lop.OPERAND_DELIMITOR); + String output2 = outputs + "_copy"; - // This method generates the instruction string: OOC°tee°input°output1°output2... - String ret = InstructionUtils.concatOperands( - getExecType().name(), OPCODE, - getInputs().get(0).prepInputOperand(input1), - prepOutputOperand(out[0]), - prepOutputOperand(out[1]) - ); + // This method generates the instruction string: OOC°tee°input°output1°output2... + String ret = InstructionUtils.concatOperands( + getExecType().name(), OPCODE, + getInputs().get(0).prepInputOperand(input1), + prepOutputOperand(out[0]), + prepOutputOperand(out[1]) + ); - return ret; - } + return ret; + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java index 0248ef78b1b..52ee7778649 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java @@ -35,64 +35,64 @@ public class TeeOOCInstruction extends ComputationOOCInstruction { - private final List _outputs; + private final List _outputs; - protected TeeOOCInstruction(OOCType type, CPOperand in1, CPOperand out, CPOperand out2, String opcode, String istr) { - super(type, null, in1, out, opcode, istr); - _outputs = Arrays.asList(out, out2); - } + protected TeeOOCInstruction(OOCType type, CPOperand in1, CPOperand out, CPOperand out2, String opcode, String istr) { + super(type, null, in1, out, opcode, istr); + _outputs = Arrays.asList(out, out2); + } - public static TeeOOCInstruction parseInstruction(String str) { - String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); - InstructionUtils.checkNumFields(parts, 3); - String opcode = parts[0]; - CPOperand in1 = new CPOperand(parts[1]); - CPOperand out = new CPOperand(parts[2]); - CPOperand out2 = new CPOperand(parts[3]); + public static TeeOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(parts, 3); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); + CPOperand out = new CPOperand(parts[2]); + CPOperand out2 = new CPOperand(parts[3]); - return new TeeOOCInstruction(OOCType.Tee, in1, out, out2, opcode, str); - } + return new TeeOOCInstruction(OOCType.Tee, in1, out, out2, opcode, str); + } - public void processInstruction( ExecutionContext ec ) { + public void processInstruction( ExecutionContext ec ) { - // Create thread and process the tee operation - MatrixObject min = ec.getMatrixObject(input1); - LocalTaskQueue qIn = min.getStreamHandle(); + // Create thread and process the tee operation + MatrixObject min = ec.getMatrixObject(input1); + LocalTaskQueue qIn = min.getStreamHandle(); // MatrixObject min = ec.getMatrixObject(input1); // LocalTaskQueue qIn = min.getStreamHandle(); - List> qOuts = new ArrayList<>(); - for (CPOperand out : _outputs) { - MatrixObject mout = ec.createMatrixObject(min.getDataCharacteristics()); - ec.setVariable(out.getName(), mout); - LocalTaskQueue qOut = new LocalTaskQueue<>(); - mout.setStreamHandle(qOut); - qOuts.add(qOut); - } + List> qOuts = new ArrayList<>(); + for (CPOperand out : _outputs) { + MatrixObject mout = ec.createMatrixObject(min.getDataCharacteristics()); + ec.setVariable(out.getName(), mout); + LocalTaskQueue qOut = new LocalTaskQueue<>(); + mout.setStreamHandle(qOut); + qOuts.add(qOut); + } - ExecutorService pool = CommonThreadPool.get(); - try { - pool.submit(() -> { - IndexedMatrixValue tmp = null; - try { - while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + ExecutorService pool = CommonThreadPool.get(); + try { + pool.submit(() -> { + IndexedMatrixValue tmp = null; + try { + while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - for (int i = 0; i < qOuts.size(); i++) { - qOuts.get(i).enqueueTask(new IndexedMatrixValue(tmp)); - } - } - for (LocalTaskQueue qOut : qOuts) { - qOut.closeInput(); - } - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } - }); - } catch (Exception ex) { - throw new DMLRuntimeException(ex); - } finally { - pool.shutdown(); - } - } + for (int i = 0; i < qOuts.size(); i++) { + qOuts.get(i).enqueueTask(new IndexedMatrixValue(tmp)); + } + } + for (LocalTaskQueue qOut : qOuts) { + qOut.closeInput(); + } + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + }); + } catch (Exception ex) { + throw new DMLRuntimeException(ex); + } finally { + pool.shutdown(); + } + } } diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/TeeTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/TeeTest.java index 16e60288538..63357719f01 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/TeeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/TeeTest.java @@ -41,87 +41,87 @@ public class TeeTest extends AutomatedTestBase { - private static final String TEST_NAME = "Tee"; - private static final String TEST_DIR = "functions/ooc/"; - private static final String TEST_CLASS_DIR = TEST_DIR + TeeTest.class.getSimpleName() + "/"; - private static final String INPUT_NAME = "X"; - private static final String OUTPUT_NAME = "res"; - private final static double eps = 1e-10; - - @Override - public void setUp() { - TestUtils.clearAssertionInformation(); - addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); - } - - @Test - public void testTeeNoRewrite() { - testTeeOperation(false); - } - - @Test - public void testTeeRewrite() { - testTeeOperation(true); - } - - - public void testTeeOperation(boolean rewrite) - { - ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE); - boolean oldRewrite = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; - OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite; - - try { - getAndLoadTestConfiguration(TEST_NAME); - String HOME = SCRIPT_DIR + TEST_DIR; - fullDMLScriptName = HOME + TEST_NAME + ".dml"; - programArgs = new String[] {"-explain", "-stats", "-ooc", - "-args", input(INPUT_NAME), output(OUTPUT_NAME)}; - - int rows = 1000, cols = 1000; - MatrixBlock mb = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7); - MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(FileFormat.BINARY); - writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows, cols, 1000, rows*cols); - HDFSTool.writeMetaDataFile(input(INPUT_NAME+".mtd"), ValueType.FP64, - new MatrixCharacteristics(rows,cols,1000,rows*cols), FileFormat.BINARY); - - runTest(true, false, null, -1); - - - double[][] C1 = readMatrix(output(OUTPUT_NAME), FileFormat.BINARY, rows, cols, 1000, 1000); - double result = 0.0; - for(int i = 0; i < cols; i++) { // verify the results with Java - for(int j = 0; j < cols; j++) { - double expected = 0.0; - for (int k = 0; k < rows; k++) { - expected += mb.get(k, i) * mb.get(k, j); - } - result = C1[i][j]; - Assert.assertEquals( "value mismatch at cell ("+i+","+j+")",expected, result, eps); - } - } - - String prefix = Instruction.OOC_INST_PREFIX; - Assert.assertTrue("OOC wasn't used for RBLK", - heavyHittersContainsString(prefix + Opcodes.RBLK)); - Assert.assertTrue("OOC wasn't used for TEE", - heavyHittersContainsString(prefix + "tee")); - } - catch(Exception ex) { - Assert.fail(ex.getMessage()); - } - finally { - OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrite; - resetExecMode(platformOld); - } - } - - private static double[][] readMatrix( String fname, FileFormat fmt, - long rows, long cols, int brows, int bcols ) - throws IOException - { - MatrixBlock mb = DataConverter.readMatrixFromHDFS(fname, fmt, rows, cols, brows, bcols); - double[][] C = DataConverter.convertToDoubleMatrix(mb); - return C; - } + private static final String TEST_NAME = "Tee"; + private static final String TEST_DIR = "functions/ooc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + TeeTest.class.getSimpleName() + "/"; + private static final String INPUT_NAME = "X"; + private static final String OUTPUT_NAME = "res"; + private final static double eps = 1e-10; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + @Test + public void testTeeNoRewrite() { + testTeeOperation(false); + } + + @Test + public void testTeeRewrite() { + testTeeOperation(true); + } + + + public void testTeeOperation(boolean rewrite) + { + ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE); + boolean oldRewrite = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite; + + try { + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", + "-args", input(INPUT_NAME), output(OUTPUT_NAME)}; + + int rows = 1000, cols = 1000; + MatrixBlock mb = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7); + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(FileFormat.BINARY); + writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows, cols, 1000, rows*cols); + HDFSTool.writeMetaDataFile(input(INPUT_NAME+".mtd"), ValueType.FP64, + new MatrixCharacteristics(rows,cols,1000,rows*cols), FileFormat.BINARY); + + runTest(true, false, null, -1); + + + double[][] C1 = readMatrix(output(OUTPUT_NAME), FileFormat.BINARY, rows, cols, 1000, 1000); + double result = 0.0; + for(int i = 0; i < cols; i++) { // verify the results with Java + for(int j = 0; j < cols; j++) { + double expected = 0.0; + for (int k = 0; k < rows; k++) { + expected += mb.get(k, i) * mb.get(k, j); + } + result = C1[i][j]; + Assert.assertEquals( "value mismatch at cell ("+i+","+j+")",expected, result, eps); + } + } + + String prefix = Instruction.OOC_INST_PREFIX; + Assert.assertTrue("OOC wasn't used for RBLK", + heavyHittersContainsString(prefix + Opcodes.RBLK)); + Assert.assertTrue("OOC wasn't used for TEE", + heavyHittersContainsString(prefix + "tee")); + } + catch(Exception ex) { + Assert.fail(ex.getMessage()); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrite; + resetExecMode(platformOld); + } + } + + private static double[][] readMatrix( String fname, FileFormat fmt, + long rows, long cols, int brows, int bcols ) + throws IOException + { + MatrixBlock mb = DataConverter.readMatrixFromHDFS(fname, fmt, rows, cols, brows, bcols); + double[][] C = DataConverter.convertToDoubleMatrix(mb); + return C; + } } From 57890b29cd4b20bb74f8f75ddbdbfcfc58f8833d Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Thu, 28 Aug 2025 23:47:01 +0530 Subject: [PATCH 03/10] remove TeeOp class and use DataOp --- .../java/org/apache/sysds/hops/DataOp.java | 120 +++++++------ .../java/org/apache/sysds/hops/TeeOp.java | 170 ------------------ .../hops/rewrite/RewriteInjectOOCTee.java | 22 ++- 3 files changed, 87 insertions(+), 225 deletions(-) delete mode 100644 src/main/java/org/apache/sysds/hops/TeeOp.java diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java index eb0d1961cf5..c2fc3842b22 100644 --- a/src/main/java/org/apache/sysds/hops/DataOp.java +++ b/src/main/java/org/apache/sysds/hops/DataOp.java @@ -32,12 +32,8 @@ import org.apache.sysds.conf.CompilerConfig.ConfigType; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.rewrite.HopRewriteUtils; -import org.apache.sysds.lops.Data; -import org.apache.sysds.lops.Federated; -import org.apache.sysds.lops.Lop; +import org.apache.sysds.lops.*; import org.apache.sysds.common.Types.ExecType; -import org.apache.sysds.lops.LopsException; -import org.apache.sysds.lops.Sql; import org.apache.sysds.parser.DataExpression; import static org.apache.sysds.parser.DataExpression.FED_RANGES; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; @@ -60,6 +56,8 @@ public class DataOp extends Hop { private boolean _recompileRead = true; + private boolean _isTeeOp = false; + /** * List of "named" input parameters. They are maintained as a hashmap: * parameter names (String) are mapped as indices (Integer) into getInput() @@ -73,6 +71,10 @@ public class DataOp extends Hop { private DataOp() { //default constructor for clone } + + public void setIsTeeOp(boolean isTeeOp) { + this._isTeeOp = isTeeOp; + } /** * READ operation for Matrix w/ dim1, dim2. @@ -251,56 +253,66 @@ public Lop constructLops() ExecType et = optFindExecType(); Lop l = null; - - // construct lops for all input parameters - HashMap inputLops = new HashMap<>(); - for (Entry cur : _paramIndexMap.entrySet()) { - inputLops.put(cur.getKey(), getInput().get(cur.getValue()).constructLops()); + + if (_isTeeOp) { + Tee teeLop = new Tee(getInput().get(0).constructLops(), + getDataType(), getValueType()); + setLineNumbers(teeLop); + setLops(teeLop); + setOutputDimensions(teeLop); } + else { - // Create the lop - switch(_op) - { - case TRANSIENTREAD: - l = new Data(_op, null, inputLops, getName(), null, - getDataType(), getValueType(), getFileFormat()); - setOutputDimensions(l); - break; - - case PERSISTENTREAD: - l = new Data(_op, null, inputLops, getName(), null, - getDataType(), getValueType(), getFileFormat()); - l.getOutputParameters().setDimensions(getDim1(), getDim2(), _inBlocksize, getNnz(), getUpdateType()); - break; - - case PERSISTENTWRITE: - case FUNCTIONOUTPUT: - l = new Data(_op, getInput().get(0).constructLops(), inputLops, getName(), null, - getDataType(), getValueType(), getFileFormat()); - ((Data)l).setExecType(et); - setOutputDimensions(l); - break; - - case TRANSIENTWRITE: - l = new Data(_op, getInput().get(0).constructLops(), inputLops, getName(), null, - getDataType(), getValueType(), getFileFormat()); - setOutputDimensions(l); - break; - - case SQLREAD: - l = new Sql(inputLops, getDataType(), getValueType()); - break; - - case FEDERATED: - l = new Federated(inputLops, getDataType(), getValueType()); - break; - - default: - throw new LopsException("Invalid operation type for Data LOP: " + _op); + // construct lops for all input parameters + HashMap inputLops = new HashMap<>(); + for (Entry cur : _paramIndexMap.entrySet()) { + inputLops.put(cur.getKey(), getInput().get(cur.getValue()).constructLops()); + } + + // Create the lop + switch (_op) { + case TRANSIENTREAD: + l = new Data(_op, null, inputLops, getName(), null, + getDataType(), getValueType(), getFileFormat()); + setOutputDimensions(l); + break; + + case PERSISTENTREAD: + l = new Data(_op, null, inputLops, getName(), null, + getDataType(), getValueType(), getFileFormat()); + l.getOutputParameters().setDimensions(getDim1(), getDim2(), _inBlocksize, getNnz(), getUpdateType()); + break; + + case PERSISTENTWRITE: + case FUNCTIONOUTPUT: + l = new Data(_op, getInput().get(0).constructLops(), inputLops, getName(), null, + getDataType(), getValueType(), getFileFormat()); + ((Data) l).setExecType(et); + setOutputDimensions(l); + break; + + case TRANSIENTWRITE: + l = new Data(_op, getInput().get(0).constructLops(), inputLops, getName(), null, + getDataType(), getValueType(), getFileFormat()); + setOutputDimensions(l); + break; + + case SQLREAD: + l = new Sql(inputLops, getDataType(), getValueType()); + break; + + case FEDERATED: + l = new Federated(inputLops, getDataType(), getValueType()); + break; + + default: + throw new LopsException("Invalid operation type for Data LOP: " + _op); + } + setLineNumbers(l); + setLops(l); } - - setLineNumbers(l); - setLops(l); +// setLineNumbers(l); +// setLops(l); //add reblock/checkpoint lops if necessary constructAndSetLopsDataFlowProperties(); @@ -346,6 +358,9 @@ public boolean isFederatedDataOp(){ public String getOpString() { String s = new String(""); s += _op.toString(); + if (_isTeeOp) { + s += " tee"; + } s += " "+getName(); return s; } @@ -536,6 +551,7 @@ public Object clone() throws CloneNotSupportedException ret._inFormat = _inFormat; ret._inBlocksize = _inBlocksize; ret._recompileRead = _recompileRead; + ret._isTeeOp = _isTeeOp; // copy the Tee flag ret._paramIndexMap = (HashMap) _paramIndexMap.clone(); //note: no deep cp of params since read-only diff --git a/src/main/java/org/apache/sysds/hops/TeeOp.java b/src/main/java/org/apache/sysds/hops/TeeOp.java deleted file mode 100644 index e27cd043aad..00000000000 --- a/src/main/java/org/apache/sysds/hops/TeeOp.java +++ /dev/null @@ -1,170 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysds.hops; - -import org.apache.sysds.common.Types.ExecType; -import org.apache.sysds.lops.Lop; -import org.apache.sysds.lops.Tee; -import org.apache.sysds.runtime.meta.DataCharacteristics; - -import java.util.ArrayList; - - -public class TeeOp extends Hop { - - private final ArrayList _outputs = new ArrayList<>(); - - private TeeOp() { - // default constructor - } - - /** - * Takes in a single Hop input and gives two outputs - * - * @param input - */ - public TeeOp(Hop input) { - super(input.getName(), input.getDataType(), input._valueType); - - // add single input for this hop - getInput().add(0, input); - input.getParent().add(this); - - // output variables list to feed tee output into -// for (Hop out: outputs) { -// _outputs.add(out); -// } - - // This characteristics are same as the input - refreshSizeInformation(); - } - - @Override - public boolean allowsAllExecTypes() { - return false; - } - - /** - * Computes the output matrix characteristics (rows, cols, nnz) based on worst-case output - * and/or input estimates. Should return null if dimensions are unknown. - * - * @param memo memory table - * @return output characteristics - */ - @Override - protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) { - return null; - } - - @Override - public Lop constructLops() { - // return already created Lops - if (getLops() != null) { - return getLops(); - } - - Tee teeLop = new Tee(getInput().get(0).constructLops(), - getDataType(), getValueType()); - setOutputDimensions(teeLop); - setLineNumbers(teeLop); - setLops(teeLop); - - return getLops(); - } - - @Override - protected ExecType optFindExecType(boolean transitive) { - return ExecType.OOC; - } - - @Override - public String getOpString() { - return "tee"; - } - - /** - * In memory-based optimizer mode (see OptimizerUtils.isMemoryBasedOptLevel()), - * the exectype is determined by checking this method as well as memory budget of this Hop. - * Please see findExecTypeByMemEstimate for more detail. - *

- * This method is necessary because not all operator are supported efficiently - * on GPU (for example: operations on frames and scalar as well as operations such as table). - * - * @return true if the Hop is eligible for GPU Exectype. - */ - @Override - public boolean isGPUEnabled() { - return false; - } - - /** - * Computes the hop-specific output memory estimate in bytes. Should be 0 if not - * applicable. - * - * @param dim1 dimension 1 - * @param dim2 dimension 2 - * @param nnz number of non-zeros - * @return memory estimate - */ - @Override - protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) { - return 0; - } - - /** - * Computes the hop-specific intermediate memory estimate in bytes. Should be 0 if not - * applicable. - * - * @param dim1 dimension 1 - * @param dim2 dimension 2 - * @param nnz number of non-zeros - * @return memory estimate - */ - @Override - protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) { - return 0; - } - - /** - * Update the output size information for this hop. - */ - @Override - public void refreshSizeInformation() { - Hop input1 = getInput().get(0); - setDim1(input1.getDim1()); - setDim2(input1.getDim2()); - setNnz(input1.getNnz()); - setBlocksize(input1.getBlocksize()); - } - - @Override - public Object clone() throws CloneNotSupportedException { - return null; - } - - @Override - public boolean compare(Hop that) { - return false; - } - - public Hop getOutput(int index) { - return _outputs.get(index); - } -} diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java index f7f5422cd15..0198f7aca3f 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java @@ -21,7 +21,10 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types; -import org.apache.sysds.hops.*; +import org.apache.sysds.hops.AggBinaryOp; +import org.apache.sysds.hops.DataOp; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.ReorgOp; import java.util.ArrayList; import java.util.HashMap; @@ -175,7 +178,20 @@ private void applyTopDownTeeRewrite(Hop sharedInput) { ArrayList consumers = new ArrayList<>(sharedInput.getParent()); // Create the new TeeOp with the original hop as input - TeeOp teeOp = new TeeOp(sharedInput); +// TeeOp teeOp = new TeeOp(sharedInput); + DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(), + sharedInput.getDataType(), + sharedInput.getValueType(), + Types.OpOpData.TRANSIENTWRITE, + null, + sharedInput.getDim1(), + sharedInput.getDim2(), + sharedInput.getNnz(), + sharedInput.getBlocksize() + ); + + teeOp.setIsTeeOp(true); + HopRewriteUtils.addChildReference(teeOp, sharedInput); // Rewire the graph: replace original connections with TeeOp outputs int i = 0; @@ -212,7 +228,7 @@ private void applyTopDownTeeRewrite(Hop sharedInput) { private boolean isNotAlreadyTee(Hop hop) { if (hop.getParent().size() > 1) { for (Hop consumer : hop.getParent()) { - if (consumer instanceof TeeOp) { + if (consumer instanceof DataOp) { return false; } } From 112f337577530f008de0fe586f65de900cd512d6 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Tue, 2 Sep 2025 00:00:40 +0530 Subject: [PATCH 04/10] add resettable stream basic structure --- .../parfor/ResettableStream.java | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java new file mode 100644 index 00000000000..54109cec10e --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java @@ -0,0 +1,27 @@ +package org.apache.sysds.runtime.controlprogram.parfor; + +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; + +import java.util.ArrayList; + +/** + * A wrapper around LocalTaskQueue to consume the source stream and reset to + * consume again for other operators. + * + */ +public class ResettableStream extends LocalTaskQueue { + + // original live stream + private final LocalTaskQueue _source; + + // in-memory cache to store stream for re-play + private final ArrayList _cache; + + public ResettableStream(LocalTaskQueue source) { + this._source = source; + this._cache = new ArrayList<>(); + } + + + +} From 79d859b750cd129f3ed78a45b346465b0524d8b6 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Tue, 2 Sep 2025 00:10:31 +0530 Subject: [PATCH 05/10] implement reset, closeInput method signatures --- .../parfor/ResettableStream.java | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java index 54109cec10e..0491c2de565 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java @@ -17,11 +17,33 @@ public class ResettableStream extends LocalTaskQueue { // in-memory cache to store stream for re-play private final ArrayList _cache; + // state flags + private boolean _cacheInProgress = true; // caching in progress, in the first pass. + private int _replayPosition = 0; // slider position in the stream + public ResettableStream(LocalTaskQueue source) { this._source = source; this._cache = new ArrayList<>(); } + @Override + public synchronized IndexedMatrixValue dequeueTask() + throws InterruptedException { + // Implement dequeTask logic + return null; + } + /** + * Resets the stream to beginning to read the stream from start. + * This can only be called once the stream is fully consumed once. + */ + public void reset() { + // Implement reset logic + } + @Override + public synchronized void closeInput() { + + _source.closeInput(); + } } From 3fe964b15781239fee8d7599544a4b3436187100 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Tue, 2 Sep 2025 18:31:04 +0530 Subject: [PATCH 06/10] implement dequeueTask and reset --- .../parfor/ResettableStream.java | 35 +++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java index 0491c2de565..709a87bd767 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java @@ -1,5 +1,7 @@ package org.apache.sysds.runtime.controlprogram.parfor; +import org.apache.hadoop.yarn.webapp.hamlet.HamletSpec; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import java.util.ArrayList; @@ -26,11 +28,35 @@ public ResettableStream(LocalTaskQueue source) { this._cache = new ArrayList<>(); } + /** + * Dequeues a task. If it is the first, it reads from the disk and stores in the cache. + * For subsequent passes it reads from the memory. + * + * @return The next matrix value in the stream, or NO_MORE_TASKS + * @throws InterruptedException + */ @Override public synchronized IndexedMatrixValue dequeueTask() throws InterruptedException { - // Implement dequeTask logic - return null; + if (_cacheInProgress) { + // First pass: Read value from the source and cache it, and return. + IndexedMatrixValue task = _source.dequeueTask(); + if (task != NO_MORE_TASKS) { + _cache.add(new IndexedMatrixValue(task)); + } else { + _cacheInProgress = false; // caching is complete + _source.closeInput(); // close source stream + } + return task; + } else { + // Replay pass: read directly from in-memory cache + if (_replayPosition < _cache.size()) { + // Return a copy to ensure consumer won't modify the cache + return new IndexedMatrixValue(_cache.get(_replayPosition++)); + } else { + return (IndexedMatrixValue) NO_MORE_TASKS; + } + } } /** @@ -38,7 +64,10 @@ public synchronized IndexedMatrixValue dequeueTask() * This can only be called once the stream is fully consumed once. */ public void reset() { - // Implement reset logic + if (_cacheInProgress) { + throw new DMLRuntimeException("Attempted to reset a stream that's not been fully cached yet."); + } + _replayPosition = 0; } @Override From 1e69e4d9cda2976b75d03bd454ae736cb9e5abe1 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Tue, 2 Sep 2025 19:20:17 +0530 Subject: [PATCH 07/10] integrate resettable stream with teeooc --- .../parfor/ResettableStream.java | 7 +- .../instructions/ooc/TeeOOCInstruction.java | 101 ++++++++++++------ 2 files changed, 74 insertions(+), 34 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java index 709a87bd767..9f18cc37cf6 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java @@ -47,6 +47,7 @@ public synchronized IndexedMatrixValue dequeueTask() _cacheInProgress = false; // caching is complete _source.closeInput(); // close source stream } + notifyAll(); // Notify all the waiting consumers waiting for cache to fill with this stream return task; } else { // Replay pass: read directly from in-memory cache @@ -63,9 +64,11 @@ public synchronized IndexedMatrixValue dequeueTask() * Resets the stream to beginning to read the stream from start. * This can only be called once the stream is fully consumed once. */ - public void reset() { + public synchronized void reset() throws InterruptedException { if (_cacheInProgress) { - throw new DMLRuntimeException("Attempted to reset a stream that's not been fully cached yet."); + System.out.println("Attempted to reset a stream that's not been fully cached yet."); + wait(); +// throw new DMLRuntimeException("Attempted to reset a stream that's not been fully cached yet."); } _replayPosition = 0; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java index 52ee7778649..690970528b2 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java @@ -19,10 +19,12 @@ package org.apache.sysds.runtime.instructions.ooc; +import org.apache.spark.sql.sources.In; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.controlprogram.parfor.ResettableStream; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; @@ -61,38 +63,73 @@ public void processInstruction( ExecutionContext ec ) { // MatrixObject min = ec.getMatrixObject(input1); // LocalTaskQueue qIn = min.getStreamHandle(); - List> qOuts = new ArrayList<>(); - for (CPOperand out : _outputs) { - MatrixObject mout = ec.createMatrixObject(min.getDataCharacteristics()); - ec.setVariable(out.getName(), mout); - LocalTaskQueue qOut = new LocalTaskQueue<>(); - mout.setStreamHandle(qOut); - qOuts.add(qOut); - } - - ExecutorService pool = CommonThreadPool.get(); - try { - pool.submit(() -> { - IndexedMatrixValue tmp = null; - try { - while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - - for (int i = 0; i < qOuts.size(); i++) { - qOuts.get(i).enqueueTask(new IndexedMatrixValue(tmp)); - } - } - for (LocalTaskQueue qOut : qOuts) { - qOut.closeInput(); - } - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); + + // Create a single, shared, resettable stream (cached) + final ResettableStream sharedStream = new ResettableStream(qIn); + + LocalTaskQueue stream2 = new LocalTaskQueue() { + private boolean isFirstCall = true; + + @Override + public IndexedMatrixValue dequeueTask() + throws InterruptedException { + if (isFirstCall) { + sharedStream.reset(); + isFirstCall = false; } - }); - } catch (Exception ex) { - throw new DMLRuntimeException(ex); - } finally { - pool.shutdown(); - } + return sharedStream.dequeueTask(); + } + + @Override + public void closeInput() { + // This a no-op, since sharedStream is managed internally + } + }; + + CPOperand out1 = _outputs.get(0); +// MatrixObject mout1 = ec.getMatrixObject(min.getDataCharacteristics()); + MatrixObject mout1 = ec.createMatrixObject(min.getDataCharacteristics()); + mout1.setStreamHandle(sharedStream); + ec.setVariable(out1.getName(), mout1); + + CPOperand out2 = _outputs.get(1); +// MatrixObject mout2 = ec.getMatrixObject(out2); + MatrixObject mout2 = ec.createMatrixObject(min.getDataCharacteristics()); + mout2.setStreamHandle(stream2); + ec.setVariable(out2.getName(), mout2); + +// List> qOuts = new ArrayList<>(); +// for (CPOperand out : _outputs) { +// MatrixObject mout = ec.createMatrixObject(min.getDataCharacteristics()); +// ec.setVariable(out.getName(), mout); +// LocalTaskQueue qOut = new LocalTaskQueue<>(); +// mout.setStreamHandle(qOut); +// qOuts.add(qOut); +// } +// +// ExecutorService pool = CommonThreadPool.get(); +// try { +// pool.submit(() -> { +// IndexedMatrixValue tmp = null; +// try { +// while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { +// +// for (int i = 0; i < qOuts.size(); i++) { +// qOuts.get(i).enqueueTask(new IndexedMatrixValue(tmp)); +// } +// } +// for (LocalTaskQueue qOut : qOuts) { +// qOut.closeInput(); +// } +// } +// catch(Exception ex) { +// throw new DMLRuntimeException(ex); +// } +// }); +// } catch (Exception ex) { +// throw new DMLRuntimeException(ex); +// } finally { +// pool.shutdown(); +// } } } From 0fda2b9e2061434d056a467580926b46b4a44d31 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Tue, 2 Sep 2025 19:24:04 +0530 Subject: [PATCH 08/10] remove commented, redundant code --- .../parfor/ResettableStream.java | 24 +++++++++-- .../instructions/ooc/TeeOOCInstruction.java | 43 ------------------- 2 files changed, 20 insertions(+), 47 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java index 9f18cc37cf6..957b03ad7f8 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResettableStream.java @@ -1,7 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + package org.apache.sysds.runtime.controlprogram.parfor; -import org.apache.hadoop.yarn.webapp.hamlet.HamletSpec; -import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import java.util.ArrayList; @@ -66,9 +83,8 @@ public synchronized IndexedMatrixValue dequeueTask() */ public synchronized void reset() throws InterruptedException { if (_cacheInProgress) { - System.out.println("Attempted to reset a stream that's not been fully cached yet."); + // Attempted to reset a stream that's not been fully cached yet. wait(); -// throw new DMLRuntimeException("Attempted to reset a stream that's not been fully cached yet."); } _replayPosition = 0; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java index 690970528b2..bb5c8e24acc 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java @@ -19,8 +19,6 @@ package org.apache.sysds.runtime.instructions.ooc; -import org.apache.spark.sql.sources.In; -import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; @@ -28,12 +26,9 @@ import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; -import org.apache.sysds.runtime.util.CommonThreadPool; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.ExecutorService; public class TeeOOCInstruction extends ComputationOOCInstruction { @@ -61,9 +56,6 @@ public void processInstruction( ExecutionContext ec ) { MatrixObject min = ec.getMatrixObject(input1); LocalTaskQueue qIn = min.getStreamHandle(); -// MatrixObject min = ec.getMatrixObject(input1); -// LocalTaskQueue qIn = min.getStreamHandle(); - // Create a single, shared, resettable stream (cached) final ResettableStream sharedStream = new ResettableStream(qIn); @@ -87,49 +79,14 @@ public void closeInput() { }; CPOperand out1 = _outputs.get(0); -// MatrixObject mout1 = ec.getMatrixObject(min.getDataCharacteristics()); MatrixObject mout1 = ec.createMatrixObject(min.getDataCharacteristics()); mout1.setStreamHandle(sharedStream); ec.setVariable(out1.getName(), mout1); CPOperand out2 = _outputs.get(1); -// MatrixObject mout2 = ec.getMatrixObject(out2); MatrixObject mout2 = ec.createMatrixObject(min.getDataCharacteristics()); mout2.setStreamHandle(stream2); ec.setVariable(out2.getName(), mout2); -// List> qOuts = new ArrayList<>(); -// for (CPOperand out : _outputs) { -// MatrixObject mout = ec.createMatrixObject(min.getDataCharacteristics()); -// ec.setVariable(out.getName(), mout); -// LocalTaskQueue qOut = new LocalTaskQueue<>(); -// mout.setStreamHandle(qOut); -// qOuts.add(qOut); -// } -// -// ExecutorService pool = CommonThreadPool.get(); -// try { -// pool.submit(() -> { -// IndexedMatrixValue tmp = null; -// try { -// while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { -// -// for (int i = 0; i < qOuts.size(); i++) { -// qOuts.get(i).enqueueTask(new IndexedMatrixValue(tmp)); -// } -// } -// for (LocalTaskQueue qOut : qOuts) { -// qOut.closeInput(); -// } -// } -// catch(Exception ex) { -// throw new DMLRuntimeException(ex); -// } -// }); -// } catch (Exception ex) { -// throw new DMLRuntimeException(ex); -// } finally { -// pool.shutdown(); -// } } } From cbd6f4cf6ded732785a9969910c1716632cb0d9d Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Tue, 2 Sep 2025 19:31:35 +0530 Subject: [PATCH 09/10] address review comments, 1/3 --- src/main/java/org/apache/sysds/hops/DataOp.java | 2 -- .../apache/sysds/hops/rewrite/RewriteInjectOOCTee.java | 8 +++----- src/main/java/org/apache/sysds/lops/Tee.java | 1 - 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java index c2fc3842b22..64bdf2c2dc2 100644 --- a/src/main/java/org/apache/sysds/hops/DataOp.java +++ b/src/main/java/org/apache/sysds/hops/DataOp.java @@ -311,8 +311,6 @@ public Lop constructLops() setLineNumbers(l); setLops(l); } -// setLineNumbers(l); -// setLops(l); //add reblock/checkpoint lops if necessary constructAndSetLopsDataFlowProperties(); diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java index 0198f7aca3f..6fa8f1849f0 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java @@ -178,7 +178,6 @@ private void applyTopDownTeeRewrite(Hop sharedInput) { ArrayList consumers = new ArrayList<>(sharedInput.getParent()); // Create the new TeeOp with the original hop as input -// TeeOp teeOp = new TeeOp(sharedInput); DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(), sharedInput.getDataType(), sharedInput.getValueType(), @@ -243,14 +242,13 @@ private boolean isTranposePattern (Hop hop) { for (Hop parent: hop.getParent()) { String opString = parent.getOpString(); if (parent instanceof ReorgOp) { - if (opString.contains("r'") || opString.contains("transpose")) { + if (HopRewriteUtils.isTransposeOperation(parent)) { hasTransposeConsumer = true; } } - else if (parent instanceof AggBinaryOp) - if (opString.contains("*") || opString.contains("ba+*")) { + else if (HopRewriteUtils.isMatrixMultiply(parent)) { hasMatrixMultiplyConsumer = true; - } + } } return hasTransposeConsumer && hasMatrixMultiplyConsumer; } diff --git a/src/main/java/org/apache/sysds/lops/Tee.java b/src/main/java/org/apache/sysds/lops/Tee.java index 113bd37b12d..6734e69cf31 100644 --- a/src/main/java/org/apache/sysds/lops/Tee.java +++ b/src/main/java/org/apache/sysds/lops/Tee.java @@ -50,7 +50,6 @@ public String toString() { public String getInstructions(String input1, String outputs) { String[] out = outputs.split(Lop.OPERAND_DELIMITOR); - String output2 = outputs + "_copy"; // This method generates the instruction string: OOC°tee°input°output1°output2... String ret = InstructionUtils.concatOperands( From a2e51cbca43e2815be4cb9be623b07e7b80ade23 Mon Sep 17 00:00:00 2001 From: Janardhan Pulivarthi Date: Tue, 2 Sep 2025 19:35:15 +0530 Subject: [PATCH 10/10] address review comments, 2/3 --- src/main/java/org/apache/sysds/hops/DataOp.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java index 64bdf2c2dc2..0aac3ca70c8 100644 --- a/src/main/java/org/apache/sysds/hops/DataOp.java +++ b/src/main/java/org/apache/sysds/hops/DataOp.java @@ -32,8 +32,13 @@ import org.apache.sysds.conf.CompilerConfig.ConfigType; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.rewrite.HopRewriteUtils; -import org.apache.sysds.lops.*; +import org.apache.sysds.lops.Data; +import org.apache.sysds.lops.Federated; +import org.apache.sysds.lops.Lop; import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.lops.LopsException; +import org.apache.sysds.lops.Sql; +import org.apache.sysds.lops.Tee; import org.apache.sysds.parser.DataExpression; import static org.apache.sysds.parser.DataExpression.FED_RANGES; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;