Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions src/main/java/org/apache/sysds/hops/AggBinaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,25 @@ public Lop constructLops() {
throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing SPARK lops.");
}
} else if (et == ExecType.OOC) {
Lop in1 = getInput().get(0).constructLops();
Lop in2 = getInput().get(1).constructLops();
MatMultCP matmult = new MatMultCP(in1, in2, getDataType(), getValueType(),
et, OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
setOutputDimensions(matmult);
setLineNumbers(matmult);
setLops(matmult);
_method = optFindMMultMethodCP(input1.getDim1(), input1.getDim2(),
input2.getDim1(), input2.getDim2(), mmtsj, chain, _hasLeftPMInput);

switch (_method) {
case TSMM:
constructCPLopsTSMM(mmtsj, et);
break;
case MM:
Lop in1 = getInput().get(0).constructLops();
Lop in2 = getInput().get(1).constructLops();
MatMultCP matmult = new MatMultCP(in1, in2, getDataType(), getValueType(),
et, OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
setOutputDimensions(matmult);
setLineNumbers(matmult);
setLops(matmult);
break;
default:
throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing CP lops.");
}
}
} else
throw new HopsException(this.printErrorLocation() + "Invalid operation in AggBinary Hop, aggBin(" + innerOp + "," + outerOp + ") while constructing lops.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.TransposeSelfMMOOCInstruction;

public class OOCInstructionParser extends InstructionParser {
protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName());
Expand Down Expand Up @@ -60,6 +61,9 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
case AggregateBinary:
case MAPMM:
return MatrixVectorBinaryOOCInstruction.parseInstruction(str);
case TSMM:
case MMTSJ:
return TransposeSelfMMOOCInstruction.parseInstruction(str);

default:
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public abstract class OOCInstruction extends Instruction {
protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());

public enum OOCType {
Reblock, AggregateUnary, Binary, Unary, MAPMM, AggregateBinary
Reblock, AggregateUnary, Binary, Unary, MAPMM, AggregateBinary, TSMM, MMTSJ
}

protected final OOCInstruction.OOCType _ooctype;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.instructions.ooc;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.lops.MMTSJ.MMTSJType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.CommonThreadPool;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;

/**
 * Out-of-core instruction for transpose-self matrix multiplication.
 *
 * The single input matrix is consumed block by block from its stream handle.
 * For every block the local transpose-self product is computed and added into
 * one in-memory accumulator, which finally becomes the instruction output.
 *
 * NOTE(review): summing per-block products only yields the full result if every
 * streamed block spans the complete aggregated dimension (all columns for LEFT,
 * all rows for RIGHT). Blocks that do not match the accumulator are rejected
 * with an explicit error instead of failing inside the in-place addition.
 */
public class TransposeSelfMMOOCInstruction extends ComputationOOCInstruction {

	private final MMTSJType _tstype;

	protected TransposeSelfMMOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand out, MMTSJType tstype, String opcode, String istr) {
		super(type, op, in1, out, opcode, istr);
		_tstype = tstype;
	}

	/**
	 * Parses the string form of a transpose-self matrix multiplication instruction.
	 *
	 * Expected layout: opcode, input operand, output operand, MMTSJ type name.
	 *
	 * @param str instruction string
	 * @return the parsed instruction of OOC type TSMM
	 */
	public static TransposeSelfMMOOCInstruction parseInstruction(String str) {
		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
		InstructionUtils.checkNumFields(parts, 3);
		String opcode = parts[0];
		CPOperand in1 = new CPOperand(parts[1]); // the single input matrix (consumed as a stream)
		CPOperand out = new CPOperand(parts[2]);
		MMTSJType tstype = MMTSJType.valueOf(parts[3]);

		AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
		AggregateBinaryOperator ba = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);

		return new TransposeSelfMMOOCInstruction(OOCType.TSMM, ba, in1, out, tstype, opcode, str);
	}

	/**
	 * Streams all blocks of the input, accumulates their transpose-self products
	 * and binds the aggregated block as the output matrix.
	 *
	 * @param ec execution context holding the input and receiving the output
	 * @throws DMLRuntimeException if the output dimension is invalid, a partial
	 *         product does not fit the accumulator, or the stream processing fails
	 */
	@Override
	public void processInstruction( ExecutionContext ec ) {
		// 1. identify the input and its block stream
		MatrixObject min = ec.getMatrixObject(input1);
		LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
		BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());

		// 2. create an empty accumulator for the result; X %*% t(X) (RIGHT) is
		// rows x rows, whereas t(X) %*% X is cols x cols
		long dim = (_tstype == MMTSJType.RIGHT) ? min.getNumRows() : min.getNumColumns();
		if( dim < 0 || dim > Integer.MAX_VALUE )
			throw new DMLRuntimeException("Invalid output dimension " + dim
				+ " for OOC transpose-self matrix multiplication (" + _tstype + ").");
		int n = (int) dim;
		MatrixBlock result = new MatrixBlock(n, n, false);

		try {
			// 3. consume the stream of X blocks synchronously on the calling thread
			IndexedMatrixValue tmp;
			while( (tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS ) {
				MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue();

				// 4. compute the local transpose-self product of this block
				MatrixBlock partialResult = matrixBlock.transposeSelfMatrixMultOperations(new MatrixBlock(), _tstype);

				// 5. a partial product of different shape means the block does not span
				// the full aggregated dimension; summing it would not be meaningful
				if( partialResult.getNumRows() != n || partialResult.getNumColumns() != n )
					throw new DMLRuntimeException("Unsupported OOC transpose-self matrix multiplication ("
						+ _tstype + "): partial result of size " + partialResult.getNumRows() + "x"
						+ partialResult.getNumColumns() + " does not match the " + n + "x" + n + " output.");

				// 6. aggregate the partial result into the final accumulator block
				result.binaryOperationsInPlace(plus, partialResult);
			}

			// 7. once the stream is exhausted, bind the aggregated block as the output
			ec.setMatrixOutput(output.getName(), result); // single in-memory matrix block
		}
		catch(DMLRuntimeException ex) {
			throw ex; // already carries a meaningful message, avoid double wrapping
		}
		catch(Exception ex) {
			// keep the interruption visible to callers further up the stack
			if( ex instanceof InterruptedException )
				Thread.currentThread().interrupt();
			throw new DMLRuntimeException(ex);
		}
	}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.functions.ooc;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.io.MatrixWriter;
import org.apache.sysds.runtime.io.MatrixWriterFactory;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Test;

import java.io.IOException;

/**
 * Tests the out-of-core execution of {@code t(X) %*% X}.
 *
 * A random matrix is written as a binary block file, the DML script is run with
 * the {@code -ooc} flag, and the written result is compared cell by cell with a
 * plain Java reference computation. Finally the heavy hitters are checked for
 * the OOC reblock and tsmm instructions.
 */
public class TransposeSelfMMTest extends AutomatedTestBase {
	private final static String TEST_NAME1 = "TransposeSelfMM";
	private final static String TEST_DIR = "functions/ooc/";
	private final static String TEST_CLASS_DIR = TEST_DIR + TransposeSelfMMTest.class.getSimpleName() + "/";
	private final static double eps = 1e-10;
	private static final String INPUT_NAME = "X";
	private static final String OUTPUT_NAME = "res";

	private final static int rows = 5000;
	private final static int cols_wide = 2000;
	private final static int cols_skinny = 500;
	// block size used for writing the input and reading the output
	private final static int blocksize = 1000;

	private final static double sparsity1 = 0.7;
	private final static double sparsity2 = 0.1;

	@Override
	public void setUp() {
		TestUtils.clearAssertionInformation();
		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
		addTestConfiguration(TEST_NAME1, config);
	}

	@Test
	public void testTransposeSelfMatrixMultiplication1() {
		runTransposeSelfMMTest(cols_wide, false);
	}

	@Test
	public void testTransposeSelfMatrixMultiplication2() {
		runTransposeSelfMMTest(cols_skinny, false);
	}

	/**
	 * Generates the input, runs the DML script in OOC mode and verifies the result.
	 *
	 * @param cols   number of columns of the generated input matrix
	 * @param sparse if true the input is generated with sparsity2, otherwise sparsity1
	 */
	private void runTransposeSelfMMTest(int cols, boolean sparse)
	{
		Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);

		try
		{
			getAndLoadTestConfiguration(TEST_NAME1);
			String HOME = SCRIPT_DIR + TEST_DIR;
			fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
			programArgs = new String[]{"-explain", "-stats", "-ooc",
				"-args", input(INPUT_NAME), output(OUTPUT_NAME)};

			// 1. generate the data in-memory and convert it to a MatrixBlock
			double[][] A_data = getRandomMatrix(rows, cols, 0, 1, sparse ? sparsity2 : sparsity1, 10);
			MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data);

			// 2. write matrix A and its metadata file in binary block format
			MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
			writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, cols, blocksize, A_mb.getNonZeros());
			HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64,
				new MatrixCharacteristics(rows, cols, blocksize, A_mb.getNonZeros()), Types.FileFormat.BINARY);

			// 3. run the script, no exception expected
			boolean exceptionExpected = false;
			runTest(true, exceptionExpected, null, -1);

			// 4. read the result; t(X) %*% X is a cols x cols matrix
			double[][] C1 = readMatrix(output(OUTPUT_NAME), Types.FileFormat.BINARY, cols, cols, blocksize, blocksize);

			// 5. verify the results with Java
			assertTransposeSelfProduct(A_data, C1);

			// 6. verify that the OOC instructions were actually executed
			String prefix = Instruction.OOC_INST_PREFIX;
			Assert.assertTrue("OOC wasn't used for RBLK",
				heavyHittersContainsString(prefix + Opcodes.RBLK));
			Assert.assertTrue("OOC wasn't used for TSMM",
				heavyHittersContainsString(prefix + Opcodes.TSMM));
		}
		catch (IOException e) {
			throw new RuntimeException(e);
		}
		finally {
			resetExecMode(platformOld);
		}
	}

	/**
	 * Asserts that C equals t(X) %*% X within eps.
	 *
	 * The expected value of cell (i,j) is the dot product of columns i and j of X,
	 * summed over the rows in ascending order. Since that value is identical for
	 * (i,j) and (j,i), it is computed once and checked against both cells. The
	 * columns are copied into contiguous arrays first, which costs one extra copy
	 * of X but avoids strided access in the innermost loop.
	 *
	 * @param X input matrix in row-major layout
	 * @param C actual result to verify
	 */
	private static void assertTransposeSelfProduct(double[][] X, double[][] C) {
		final int m = X.length;
		final int n = (m > 0) ? X[0].length : 0;

		double[][] columns = new double[n][m];
		for(int k = 0; k < m; k++)
			for(int j = 0; j < n; j++)
				columns[j][k] = X[k][j];

		for(int i = 0; i < n; i++) {
			final double[] ci = columns[i];
			for(int j = i; j < n; j++) {
				final double[] cj = columns[j];
				double expected = 0.0;
				for(int k = 0; k < m; k++)
					expected += ci[k] * cj[k];
				Assert.assertEquals("value mismatch at cell ("+i+","+j+")", expected, C[i][j], eps);
				if( j != i )
					Assert.assertEquals("value mismatch at cell ("+j+","+i+")", expected, C[j][i], eps);
			}
		}
	}

	/**
	 * Reads a matrix file and converts it into a dense two-dimensional array.
	 *
	 * @param fname file name of the matrix
	 * @param fmt   file format
	 * @param rows  number of rows of the stored matrix
	 * @param cols  number of columns of the stored matrix
	 * @param brows block size passed to the reader for rows
	 * @param bcols block size passed to the reader for columns
	 * @return the matrix content as double array
	 * @throws IOException if reading the file fails
	 */
	private static double[][] readMatrix(String fname, Types.FileFormat fmt, long rows, long cols, int brows, int bcols )
		throws IOException
	{
		MatrixBlock mb = DataConverter.readMatrixFromHDFS(fname, fmt, rows, cols, brows, bcols);
		return DataConverter.convertToDoubleMatrix(mb);
	}
}
28 changes: 28 additions & 0 deletions src/test/scripts/functions/ooc/TransposeSelfMM.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# Read the input matrix X from the path given as first command line argument
X = read($1);

# Operation under test: transpose-self matrix multiplication t(X) %*% X
res = t(X) %*% X;

# Write the result in binary format to the path given as second command line argument
write(res, $2, format="binary")