Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/main/java/org/apache/sysds/hops/AggBinaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;


/* Aggregate binary (cell operations): Sum (aij + bij)
* Properties:
* Inner Symbol: *, -, +, ...
Expand Down Expand Up @@ -515,14 +514,14 @@ private void constructCPLopsMMChain(ChainType chain) {
if (chain == ChainType.XtXv) {
Hop hX = getInput().get(0).getInput().get(0);
Hop hv = getInput().get(1).getInput().get(1);
mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), getDataType(), getValueType(), ExecType.CP);
mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), getDataType(), getValueType(), DMLScript.USE_OOC ? ExecType.OOC : ExecType.CP);
} else { //ChainType.XtwXv / ChainType.XtwXvy
int wix = (chain == ChainType.XtwXv) ? 0 : 1;
int vix = (chain == ChainType.XtwXv) ? 1 : 0;
Hop hX = getInput().get(0).getInput().get(0);
Hop hw = getInput().get(1).getInput().get(wix);
Hop hv = getInput().get(1).getInput().get(vix).getInput().get(1);
mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(), ExecType.CP);
mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(), DMLScript.USE_OOC ? ExecType.OOC : ExecType.CP);
}

//set degree of parallelism
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@ public class DMLRuntimeException extends DMLException
{
private static final long serialVersionUID = 1L;

public static DMLRuntimeException of(Throwable t) {
return t instanceof DMLRuntimeException ? (DMLRuntimeException) t : new DMLRuntimeException(t);
}

public DMLRuntimeException(String string) {
super(string);
}

public DMLRuntimeException(Exception e) {
public DMLRuntimeException(Throwable e) {
super(e);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -469,10 +469,11 @@ public BroadcastObject<T> getBroadcastHandle() {
public boolean hasBroadcastHandle() {
return _bcHandle != null && _bcHandle.hasBackReference();
}

public synchronized OOCStream<IndexedMatrixValue> getStreamHandle() {
if( !hasStreamHandle() ) {
final SubscribableTaskQueue<IndexedMatrixValue> _mStream = new SubscribableTaskQueue<>();
_mStream.setData(this);
DataCharacteristics dc = getDataCharacteristics();
MatrixBlock src = (MatrixBlock)acquireReadAndRelease();
_streamHandle = _mStream;
Expand All @@ -489,7 +490,7 @@ public synchronized OOCStream<IndexedMatrixValue> getStreamHandle() {
}

OOCStream<IndexedMatrixValue> stream = _streamHandle.getReadStream();
if (!stream.hasStreamCache())
if(!stream.hasStreamCache())
_streamHandle = null; // To ensure read once
return stream;
}
Expand Down Expand Up @@ -539,6 +540,7 @@ public synchronized void removeGPUObject(GPUContext gCtx) {
}

public synchronized void setStreamHandle(OOCStreamable<IndexedMatrixValue> q) {
q.setData(this);
_streamHandle = q;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.InstructionType;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.ooc.AggregateTernaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
Expand All @@ -37,7 +38,8 @@
import org.apache.sysds.runtime.instructions.ooc.TernaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.MMultOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.MapMMChainOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;

Expand Down Expand Up @@ -72,11 +74,22 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
return UnaryOOCInstruction.parseInstruction(str);
case Binary:
return BinaryOOCInstruction.parseInstruction(str);
case Builtin:
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
if(parts[0].equals(Opcodes.LOG.toString()) || parts[0].equals(Opcodes.LOGNZ.toString())) {
if(parts.length == 3)
return UnaryOOCInstruction.parseInstruction(str);
else if(parts.length == 4)
return BinaryOOCInstruction.parseInstruction(str);
}
throw new DMLRuntimeException("Invalid Builtin Instruction: " + str);
case Ternary:
return TernaryOOCInstruction.parseInstruction(str);
case AggregateBinary:
case MAPMM:
return MatrixVectorBinaryOOCInstruction.parseInstruction(str);
return MMultOOCInstruction.parseInstruction(str);
case MAPMMCHAIN:
return MapMMChainOOCInstruction.parseInstruction(str);
case MMTSJ:
return TSMMOOCInstruction.parseInstruction(str);
case Reorg:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,10 @@ private void processMoveInstruction(ExecutionContext ec) {
// cleanup matrix/frame/list data if necessary
if( srcData.getDataType().isMatrix() || srcData.getDataType().isFrame() ) {
Data tgtData = ec.removeVariable(getInput2().getName());

if (DMLScript.USE_OOC && tgtData instanceof MatrixObject)
TeeOOCInstruction.incrRef(((MatrixObject) tgtData).getStreamable(), -1);

if( tgtData != null && srcData != tgtData )
ec.cleanupDataObject(tgtData);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@
import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.IndexRange;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;

public class AggregateTernaryOOCInstruction extends ComputationOOCInstruction {

Expand Down Expand Up @@ -111,17 +111,13 @@ private void processReduceAll(ExecutionContext ec, AggregateTernaryOperator abOp
if(qIn3 != null)
streams.add(qIn3);

List<Function<IndexedMatrixValue, MatrixIndexes>> keyFns = new ArrayList<>();
for(int i = 0; i < streams.size(); i++)
keyFns.add(IndexedMatrixValue::getIndexes);

CompletableFuture<Void> fut = joinOOC(streams, qMid, blocks -> {
MatrixBlock b1 = (MatrixBlock) blocks.get(0).getValue();
MatrixBlock b2 = (MatrixBlock) blocks.get(1).getValue();
MatrixBlock b3 = blocks.size() == 3 ? (MatrixBlock) blocks.get(2).getValue() : null;
MatrixBlock partial = MatrixBlock.aggregateTernaryOperations(b1, b2, b3, new MatrixBlock(), abOp, false);
return new IndexedMatrixValue(blocks.get(0).getIndexes(), partial);
}, keyFns);
}, IndexedMatrixValue::getIndexes);

try {
IndexedMatrixValue imv;
Expand Down Expand Up @@ -159,17 +155,26 @@ private void processReduceRow(ExecutionContext ec, AggregateTernaryOperator abOp
if(qIn3 != null)
streams.add(qIn3);

List<Function<IndexedMatrixValue, MatrixIndexes>> keyFns = new ArrayList<>();
for(int i = 0; i < streams.size(); i++)
keyFns.add(IndexedMatrixValue::getIndexes);
for (OOCStream<IndexedMatrixValue> stream : streams)
stream.setDownstreamMessageRelay(qOut::messageDownstream);

qOut.setUpstreamMessageRelay(msg ->
streams.forEach(stream -> stream.messageUpstream(streams.size() > 1 ? msg.split() : msg)));

qOut.setIXTransform((downstream, range) -> {
if (downstream)
return new IndexRange(1, 1, range.colStart, range.colEnd);
else
return new IndexRange(1, dc.getRows(), range.colStart, range.colEnd);
});

CompletableFuture<Void> fut = joinOOC(streams, qMid, blocks -> {
MatrixBlock b1 = (MatrixBlock) blocks.get(0).getValue();
MatrixBlock b2 = (MatrixBlock) blocks.get(1).getValue();
MatrixBlock b3 = blocks.size() == 3 ? (MatrixBlock) blocks.get(2).getValue() : null;
MatrixBlock partial = MatrixBlock.aggregateTernaryOperations(b1, b2, b3, new MatrixBlock(), abOp, false);
return new IndexedMatrixValue(blocks.get(0).getIndexes(), partial);
}, keyFns);
}, IndexedMatrixValue::getIndexes);

final Map<Long, MatrixBlock> aggMap = new HashMap<>();
final Map<Long, MatrixBlock> corrMap = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.ooc.stream.StreamContext;
import org.apache.sysds.runtime.util.IndexRange;

import java.util.HashMap;

Expand Down Expand Up @@ -90,6 +92,21 @@ public void processInstruction( ExecutionContext ec ) {

ec.getMatrixObject(output).setStreamHandle(qOut);

qIn.setDownstreamMessageRelay(qOut::messageDownstream);
qOut.setUpstreamMessageRelay(qIn::messageUpstream);
qOut.setIXTransform((downstream, range) -> {
if (downstream) {
if (aggun.isRowAggregate())
return new IndexRange(range.rowStart, range.rowEnd, 1, 1);
else
return new IndexRange(1, 1, range.colStart, range.colEnd);
}
if (aggun.isRowAggregate())
return new IndexRange(range.rowStart, range.rowEnd, 1, min.getNumColumns() - 1);
else
return new IndexRange(1, min.getNumRows() - 1, range.colStart, range.colEnd);
});

// per-block aggregation (parallel map)
mapOOC(qIn, qLocal, tmp -> {
MatrixIndexes midx = aggun.isRowAggregate() ?
Expand Down Expand Up @@ -134,7 +151,7 @@ public void processInstruction( ExecutionContext ec ) {
}
}
qOut.closeInput();
});
}, new StreamContext().addOutStream(qOut));
}
// full aggregation
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ protected void processMatrixMatrixInstruction(ExecutionContext ec) {
OOCStream<IndexedMatrixValue> qIn2 = m2.getStreamHandle();
OOCStream<IndexedMatrixValue> qOut = new SubscribableTaskQueue<>();
ec.getMatrixObject(output).setStreamHandle(qOut);
qIn1.setDownstreamMessageRelay(qOut::messageDownstream);
qIn2.setDownstreamMessageRelay(qOut::messageDownstream);
qOut.setUpstreamMessageRelay(msg -> {
qIn1.messageUpstream(msg.split());
qIn2.messageUpstream(msg.split());
});

if (m1.getNumRows() < 0 || m1.getNumColumns() < 0 || m2.getNumRows() < 0 || m2.getNumColumns() < 0)
throw new DMLRuntimeException("Cannot process (matrix, matrix) BinaryOOCInstruction with unknown dimensions.");
Expand Down Expand Up @@ -116,8 +122,6 @@ else if (isRowBroadcast && !isColBroadcast) {
return tmpOut;
}, IndexedMatrixValue::getIndexes);
}


}

protected void processScalarMatrixInstruction(ExecutionContext ec) {
Expand All @@ -131,6 +135,8 @@ protected void processScalarMatrixInstruction(ExecutionContext ec) {
OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
OOCStream<IndexedMatrixValue> qOut = createWritableStream();
ec.getMatrixObject(output).setStreamHandle(qOut);
qIn.setDownstreamMessageRelay(qOut::messageDownstream);
qOut.setUpstreamMessageRelay(qIn::messageUpstream);

mapOOC(qIn, qOut, tmp -> {
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.sysds.runtime.io.ReaderTextCSVParallel;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.ooc.stream.StreamContext;

public class CSVReblockOOCInstruction extends ComputationOOCInstruction {
private final int blen;
Expand Down Expand Up @@ -80,7 +81,7 @@ public void processInstruction(ExecutionContext ec) {
catch(Exception ex) {
throw (ex instanceof DMLRuntimeException) ? (DMLRuntimeException) ex : new DMLRuntimeException(ex);
}
}, qOut);
}, new StreamContext().addOutStream(qOut));

MatrixObject mout = ec.getMatrixObject(output);
mout.setStreamHandle(qOut);
Expand Down
Loading
Loading