diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java index e3cf95e5732..1f51cad5edc 100644 --- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java @@ -50,7 +50,6 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.util.UtilFunctions; - /* Aggregate binary (cell operations): Sum (aij + bij) * Properties: * Inner Symbol: *, -, +, ... @@ -515,14 +514,14 @@ private void constructCPLopsMMChain(ChainType chain) { if (chain == ChainType.XtXv) { Hop hX = getInput().get(0).getInput().get(0); Hop hv = getInput().get(1).getInput().get(1); - mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), getDataType(), getValueType(), ExecType.CP); + mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), getDataType(), getValueType(), DMLScript.USE_OOC ? ExecType.OOC : ExecType.CP); } else { //ChainType.XtwXv / ChainType.XtwXvy int wix = (chain == ChainType.XtwXv) ? 0 : 1; int vix = (chain == ChainType.XtwXv) ? 1 : 0; Hop hX = getInput().get(0).getInput().get(0); Hop hw = getInput().get(1).getInput().get(wix); Hop hv = getInput().get(1).getInput().get(vix).getInput().get(1); - mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(), ExecType.CP); + mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(), DMLScript.USE_OOC ? 
ExecType.OOC : ExecType.CP); } //set degree of parallelism diff --git a/src/main/java/org/apache/sysds/runtime/DMLRuntimeException.java b/src/main/java/org/apache/sysds/runtime/DMLRuntimeException.java index eee2e39bf95..1f7bdccdc4c 100644 --- a/src/main/java/org/apache/sysds/runtime/DMLRuntimeException.java +++ b/src/main/java/org/apache/sysds/runtime/DMLRuntimeException.java @@ -28,11 +28,15 @@ public class DMLRuntimeException extends DMLException { private static final long serialVersionUID = 1L; + public static DMLRuntimeException of(Throwable t) { + return t instanceof DMLRuntimeException ? (DMLRuntimeException) t : new DMLRuntimeException(t); + } + public DMLRuntimeException(String string) { super(string); } - public DMLRuntimeException(Exception e) { + public DMLRuntimeException(Throwable e) { super(e); } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java index 36637ee8959..f41b0511ee9 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java @@ -469,10 +469,11 @@ public BroadcastObject getBroadcastHandle() { public boolean hasBroadcastHandle() { return _bcHandle != null && _bcHandle.hasBackReference(); } - + public synchronized OOCStream getStreamHandle() { if( !hasStreamHandle() ) { final SubscribableTaskQueue _mStream = new SubscribableTaskQueue<>(); + _mStream.setData(this); DataCharacteristics dc = getDataCharacteristics(); MatrixBlock src = (MatrixBlock)acquireReadAndRelease(); _streamHandle = _mStream; @@ -489,7 +490,7 @@ public synchronized OOCStream getStreamHandle() { } OOCStream stream = _streamHandle.getReadStream(); - if (!stream.hasStreamCache()) + if(!stream.hasStreamCache()) _streamHandle = null; // To ensure read once return stream; } @@ -539,6 +540,7 @@ public synchronized void 
removeGPUObject(GPUContext gCtx) { } public synchronized void setStreamHandle(OOCStreamable q) { + q.setData(this); _streamHandle = q; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index a2e64dd0bac..607acbb3a0c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -22,6 +22,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.InstructionType; +import org.apache.sysds.common.Opcodes; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.instructions.ooc.AggregateTernaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction; @@ -37,7 +38,8 @@ import org.apache.sysds.runtime.instructions.ooc.TernaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction; -import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.MMultOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.MapMMChainOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; @@ -72,11 +74,22 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str return UnaryOOCInstruction.parseInstruction(str); case Binary: return BinaryOOCInstruction.parseInstruction(str); + case Builtin: + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + if(parts[0].equals(Opcodes.LOG.toString()) || parts[0].equals(Opcodes.LOGNZ.toString())) { + if(parts.length == 3) + return UnaryOOCInstruction.parseInstruction(str); + else 
if(parts.length == 4) + return BinaryOOCInstruction.parseInstruction(str); + } + throw new DMLRuntimeException("Invalid Builtin Instruction: " + str); case Ternary: return TernaryOOCInstruction.parseInstruction(str); case AggregateBinary: case MAPMM: - return MatrixVectorBinaryOOCInstruction.parseInstruction(str); + return MMultOOCInstruction.parseInstruction(str); + case MAPMMCHAIN: + return MapMMChainOOCInstruction.parseInstruction(str); case MMTSJ: return TSMMOOCInstruction.parseInstruction(str); case Reorg: diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index afc446f7479..9b161ea99d9 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -794,6 +794,10 @@ private void processMoveInstruction(ExecutionContext ec) { // cleanup matrix/frame/list data if necessary if( srcData.getDataType().isMatrix() || srcData.getDataType().isFrame() ) { Data tgtData = ec.removeVariable(getInput2().getName()); + + if (DMLScript.USE_OOC && tgtData instanceof MatrixObject) + TeeOOCInstruction.incrRef(((MatrixObject) tgtData).getStreamable(), -1); + if( tgtData != null && srcData != tgtData ) ec.cleanupDataObject(tgtData); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateTernaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateTernaryOOCInstruction.java index c85e17e4c50..2573ede14e4 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateTernaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateTernaryOOCInstruction.java @@ -39,13 +39,13 @@ import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator; import org.apache.sysds.runtime.matrix.operators.Operator; import 
org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.util.IndexRange; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.function.Function; public class AggregateTernaryOOCInstruction extends ComputationOOCInstruction { @@ -111,17 +111,13 @@ private void processReduceAll(ExecutionContext ec, AggregateTernaryOperator abOp if(qIn3 != null) streams.add(qIn3); - List> keyFns = new ArrayList<>(); - for(int i = 0; i < streams.size(); i++) - keyFns.add(IndexedMatrixValue::getIndexes); - CompletableFuture fut = joinOOC(streams, qMid, blocks -> { MatrixBlock b1 = (MatrixBlock) blocks.get(0).getValue(); MatrixBlock b2 = (MatrixBlock) blocks.get(1).getValue(); MatrixBlock b3 = blocks.size() == 3 ? (MatrixBlock) blocks.get(2).getValue() : null; MatrixBlock partial = MatrixBlock.aggregateTernaryOperations(b1, b2, b3, new MatrixBlock(), abOp, false); return new IndexedMatrixValue(blocks.get(0).getIndexes(), partial); - }, keyFns); + }, IndexedMatrixValue::getIndexes); try { IndexedMatrixValue imv; @@ -159,9 +155,18 @@ private void processReduceRow(ExecutionContext ec, AggregateTernaryOperator abOp if(qIn3 != null) streams.add(qIn3); - List> keyFns = new ArrayList<>(); - for(int i = 0; i < streams.size(); i++) - keyFns.add(IndexedMatrixValue::getIndexes); + for (OOCStream stream : streams) + stream.setDownstreamMessageRelay(qOut::messageDownstream); + + qOut.setUpstreamMessageRelay(msg -> + streams.forEach(stream -> stream.messageUpstream(streams.size() > 1 ? 
msg.split() : msg))); + + qOut.setIXTransform((downstream, range) -> { + if (downstream) + return new IndexRange(1, 1, range.colStart, range.colEnd); + else + return new IndexRange(1, dc.getRows(), range.colStart, range.colEnd); + }); CompletableFuture fut = joinOOC(streams, qMid, blocks -> { MatrixBlock b1 = (MatrixBlock) blocks.get(0).getValue(); @@ -169,7 +174,7 @@ private void processReduceRow(ExecutionContext ec, AggregateTernaryOperator abOp MatrixBlock b3 = blocks.size() == 3 ? (MatrixBlock) blocks.get(2).getValue() : null; MatrixBlock partial = MatrixBlock.aggregateTernaryOperations(b1, b2, b3, new MatrixBlock(), abOp, false); return new IndexedMatrixValue(blocks.get(0).getIndexes(), partial); - }, keyFns); + }, IndexedMatrixValue::getIndexes); final Map aggMap = new HashMap<>(); final Map corrMap = new HashMap<>(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java index 54d87dd3f2d..fa884b84d17 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java @@ -35,6 +35,8 @@ import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.ooc.stream.StreamContext; +import org.apache.sysds.runtime.util.IndexRange; import java.util.HashMap; @@ -90,6 +92,21 @@ public void processInstruction( ExecutionContext ec ) { ec.getMatrixObject(output).setStreamHandle(qOut); + qIn.setDownstreamMessageRelay(qOut::messageDownstream); + qOut.setUpstreamMessageRelay(qIn::messageUpstream); + qOut.setIXTransform((downstream, range) -> { + if (downstream) { + if (aggun.isRowAggregate()) + return new IndexRange(range.rowStart, range.rowEnd, 1, 1); 
+ else + return new IndexRange(1, 1, range.colStart, range.colEnd); + } + if (aggun.isRowAggregate()) + return new IndexRange(range.rowStart, range.rowEnd, 1, min.getNumColumns() - 1); + else + return new IndexRange(1, min.getNumRows() - 1, range.colStart, range.colEnd); + }); + // per-block aggregation (parallel map) mapOOC(qIn, qLocal, tmp -> { MatrixIndexes midx = aggun.isRowAggregate() ? @@ -134,7 +151,7 @@ public void processInstruction( ExecutionContext ec ) { } } qOut.closeInput(); - }); + }, new StreamContext().addOutStream(qOut)); } // full aggregation else { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java index 01c7a525bcd..3dfdce26113 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java @@ -68,6 +68,12 @@ protected void processMatrixMatrixInstruction(ExecutionContext ec) { OOCStream qIn2 = m2.getStreamHandle(); OOCStream qOut = new SubscribableTaskQueue<>(); ec.getMatrixObject(output).setStreamHandle(qOut); + qIn1.setDownstreamMessageRelay(qOut::messageDownstream); + qIn2.setDownstreamMessageRelay(qOut::messageDownstream); + qOut.setUpstreamMessageRelay(msg -> { + qIn1.messageUpstream(msg.split()); + qIn2.messageUpstream(msg.split()); + }); if (m1.getNumRows() < 0 || m1.getNumColumns() < 0 || m2.getNumRows() < 0 || m2.getNumColumns() < 0) throw new DMLRuntimeException("Cannot process (matrix, matrix) BinaryOOCInstruction with unknown dimensions."); @@ -116,8 +122,6 @@ else if (isRowBroadcast && !isColBroadcast) { return tmpOut; }, IndexedMatrixValue::getIndexes); } - - } protected void processScalarMatrixInstruction(ExecutionContext ec) { @@ -131,6 +135,8 @@ protected void processScalarMatrixInstruction(ExecutionContext ec) { OOCStream qIn = min.getStreamHandle(); OOCStream qOut = 
createWritableStream(); ec.getMatrixObject(output).setStreamHandle(qOut); + qIn.setDownstreamMessageRelay(qOut::messageDownstream); + qOut.setUpstreamMessageRelay(qIn::messageUpstream); mapOOC(qIn, qOut, tmp -> { IndexedMatrixValue tmpOut = new IndexedMatrixValue(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CSVReblockOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CSVReblockOOCInstruction.java index a4f8c497050..4ac54ee3a57 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CSVReblockOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CSVReblockOOCInstruction.java @@ -31,6 +31,7 @@ import org.apache.sysds.runtime.io.ReaderTextCSVParallel; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.ooc.stream.StreamContext; public class CSVReblockOOCInstruction extends ComputationOOCInstruction { private final int blen; @@ -80,7 +81,7 @@ public void processInstruction(ExecutionContext ec) { catch(Exception ex) { throw (ex instanceof DMLRuntimeException) ? 
(DMLRuntimeException) ex : new DMLRuntimeException(ex); } - }, qOut); + }, new StreamContext().addOutStream(qOut)); MatrixObject mout = ec.getMatrixObject(output); mout.setStreamHandle(qOut); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java index f9869b20f9a..d3e2b91630f 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java @@ -19,20 +19,28 @@ package org.apache.sysds.runtime.instructions.ooc; +import org.apache.commons.collections4.BidiMap; +import org.apache.commons.collections4.bidimap.DualHashBidiMap; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.runtime.ooc.cache.BlockKey; import org.apache.sysds.runtime.ooc.cache.OOCIOHandler; import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; -import org.apache.sysds.runtime.ooc.stream.OOCSourceStream; +import org.apache.sysds.runtime.ooc.stream.SourceOOCStream; +import org.apache.sysds.runtime.ooc.stream.message.OOCGetStreamTypeMessage; +import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage; +import org.apache.sysds.runtime.util.IndexRange; import shaded.parquet.it.unimi.dsi.fastutil.ints.IntArrayList; -import java.util.HashMap; -import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutionException; +import java.util.function.BiFunction; import 
java.util.function.Consumer; /** @@ -56,15 +64,16 @@ public class CachingStream implements OOCStreamable { private int _numBlocks = 0; private Consumer>[] _subscribers; + private CopyOnWriteArrayList> _downstreamRelays; // state flags private boolean _cacheInProgress = true; // caching in progress, in the first pass. - private Map _index; + private BidiMap _index; private DMLRuntimeException _failure; - private boolean deletable = false; - private int maxConsumptionCount = 0; + private boolean _deletable = false; + private int _maxConsumptionCount = 0; private String _watchdogId = null; public CachingStream(OOCStream source) { @@ -73,12 +82,14 @@ public CachingStream(OOCStream source) { public CachingStream(OOCStream source, long streamId) { _source = source; + _source.setDownstreamMessageRelay(this::messageDownstream); _streamId = streamId; if (OOCWatchdog.WATCH) { _watchdogId = "CS-" + hashCode(); // Capture a short context to help identify origin OOCWatchdog.registerOpen(_watchdogId, "CachingStream@" + hashCode(), getCtxMsg(), this); } + _downstreamRelays = null; source.setSubscriber(tmp -> { try (tmp) { final IndexedMatrixValue task = tmp.get(); @@ -89,25 +100,26 @@ public CachingStream(OOCStream source, long streamId) { synchronized (this) { mSubscribers = _subscribers; if(task != LocalTaskQueue.NO_MORE_TASKS) { - if (!_cacheInProgress) + if(!_cacheInProgress) throw new DMLRuntimeException("Stream is closed"); OOCIOHandler.SourceBlockDescriptor descriptor = null; - if (_source instanceof OOCSourceStream src) { + if(_source instanceof SourceOOCStream src) { descriptor = src.getDescriptor(task.getIndexes()); } - if (descriptor == null) { - if (mSubscribers == null || mSubscribers.length == 0) + if(descriptor == null) { + if(mSubscribers == null || mSubscribers.length == 0) OOCCacheManager.put(_streamId, _numBlocks, task); else mCallback = OOCCacheManager.putAndPin(_streamId, _numBlocks, task); } else { - if (mSubscribers == null || mSubscribers.length == 0) + 
if(mSubscribers == null || mSubscribers.length == 0) OOCCacheManager.putSourceBacked(_streamId, _numBlocks, task, descriptor); else - mCallback = OOCCacheManager.putAndPinSourceBacked(_streamId, _numBlocks, task, descriptor); + mCallback = OOCCacheManager.putAndPinSourceBacked(_streamId, _numBlocks, task, + descriptor); } - if (_index != null) + if(_index != null) _index.put(task.getIndexes(), _numBlocks); blk = _numBlocks; _numBlocks++; @@ -116,6 +128,12 @@ public CachingStream(OOCStream source, long streamId) { } else { _cacheInProgress = false; // caching is complete + try { + validateBlockCountOnClose(); + } + catch(Exception e) { + _failure = e instanceof DMLRuntimeException ? (DMLRuntimeException) e : new DMLRuntimeException(e); + } if (OOCWatchdog.WATCH) OOCWatchdog.registerClose(_watchdogId); notifyAll(); @@ -153,7 +171,7 @@ public CachingStream(OOCStream source, long streamId) { } Consumer>[] mSubscribers = _subscribers; - OOCStream.QueueCallback err = OOCStream.eos( _failure); + OOCStream.QueueCallback err = OOCStream.eos(_failure); if(mSubscribers != null) { for(Consumer> mSubscriber : mSubscribers) { try { @@ -181,13 +199,13 @@ private String getCtxMsg() { } public synchronized void scheduleDeletion() { - if (deletable) + if (_deletable) return; // Deletion already scheduled - if (_cacheInProgress && maxConsumptionCount == 0) + if (_cacheInProgress && _maxConsumptionCount == 0) throw new DMLRuntimeException("Cannot have a caching stream with no listeners"); - deletable = true; + _deletable = true; for (int i = 0; i < _consumptionCounts.size(); i++) { tryDeleteBlock(i); } @@ -199,21 +217,21 @@ public String toString() { private synchronized void tryDeleteBlock(int i) { int cnt = _consumptionCounts.getInt(i); - if (cnt > maxConsumptionCount) - throw new DMLRuntimeException("Cannot have more than " + maxConsumptionCount + " consumptions."); - if (cnt == maxConsumptionCount) + if (cnt > _maxConsumptionCount) + throw new DMLRuntimeException("Cannot have more 
than " + _maxConsumptionCount + " consumptions."); + if (cnt == _maxConsumptionCount) OOCCacheManager.forget(_streamId, i); } private synchronized boolean onConsumed(int blockIdx, int consumerIdx) { int newCount = _consumptionCounts.getInt(blockIdx)+1; - if (newCount > maxConsumptionCount) - throw new DMLRuntimeException("Cannot have more than " + maxConsumptionCount + " consumptions."); + if (newCount > _maxConsumptionCount) + throw new DMLRuntimeException("Cannot have more than " + _maxConsumptionCount + " consumptions."); _consumptionCounts.set(blockIdx, newCount); int newConsumerCount = _consumerConsumptionCounts.getInt(consumerIdx)+1; _consumerConsumptionCounts.set(consumerIdx, newConsumerCount); - if (deletable) + if (_deletable) tryDeleteBlock(blockIdx); return !_cacheInProgress && newConsumerCount == _numBlocks + 1; @@ -225,28 +243,32 @@ private synchronized boolean onNoMoreTasks(int consumerIdx) { return !_cacheInProgress && newConsumerCount == _numBlocks + 1; } - public synchronized OOCStream.QueueCallback get(int idx) throws InterruptedException, + public synchronized CompletableFuture> get(int idx) throws InterruptedException, ExecutionException { while (true) { - if (_failure != null) + if(_failure != null) throw _failure; - else if (idx < _numBlocks) { - OOCStream.QueueCallback out = OOCCacheManager.requestBlock(_streamId, idx).get(); - - if (_index != null) // Ensure index is up to date - _index.putIfAbsent(out.get().getIndexes(), idx); - - int newCount = _consumptionCounts.getInt(idx)+1; - if (newCount > maxConsumptionCount) - throw new DMLRuntimeException("Consumer overflow! 
Expected: " + maxConsumptionCount); - _consumptionCounts.set(idx, newCount); - - if (deletable) - tryDeleteBlock(idx); - - return out; - } else if (!_cacheInProgress) - return new OOCStream.SimpleQueueCallback<>(null, null); + else if(idx < _numBlocks) { + return OOCCacheManager.requestBlock(_streamId, idx) + .thenApply(cb -> { + synchronized(this) { + if(_index != null) // Ensure index is up to date + _index.putIfAbsent(cb.get().getIndexes(), idx); + + int newCount = _consumptionCounts.getInt(idx) + 1; + if(newCount > _maxConsumptionCount) + throw new DMLRuntimeException("Consumer overflow! Expected: " + _maxConsumptionCount); + _consumptionCounts.set(idx, newCount); + + if(_deletable) + tryDeleteBlock(idx); + } + return cb; + }); + } + else if(!_cacheInProgress) { + return CompletableFuture.completedFuture(new OOCStream.SimpleQueueCallback<>(null, _failure)); + } wait(); } @@ -263,8 +285,9 @@ public synchronized BlockKey peekCachedBlockKey(MatrixIndexes idx) { public synchronized OOCStream.QueueCallback findCached(MatrixIndexes idx) { int mIdx = _index.get(idx); int newCount = _consumptionCounts.getInt(mIdx)+1; - if (newCount > maxConsumptionCount) - throw new DMLRuntimeException("Consumer overflow in " + _streamId + "_" + mIdx + ". Expected: " + maxConsumptionCount); + if (newCount > _maxConsumptionCount) + throw new DMLRuntimeException("Consumer overflow in " + _streamId + "_" + mIdx + ". Expected: " + + _maxConsumptionCount); _consumptionCounts.set(mIdx, newCount); @@ -273,7 +296,7 @@ public synchronized OOCStream.QueueCallback findCached(Matri } catch (InterruptedException | ExecutionException e) { return new OOCStream.SimpleQueueCallback<>(null, new DMLRuntimeException(e)); } finally { - if (deletable) + if (_deletable) tryDeleteBlock(mIdx); } } @@ -283,16 +306,17 @@ public void findCachedAsync(MatrixIndexes idx, Consumer maxConsumptionCount) - throw new DMLRuntimeException("Consumer overflow in " + _streamId + "_" + mIdx + ". 
Expected: " + maxConsumptionCount); + if (newCount > _maxConsumptionCount) + throw new DMLRuntimeException("Consumer overflow in " + _streamId + "_" + mIdx + ". Expected: " + + _maxConsumptionCount); } OOCCacheManager.requestBlock(_streamId, mIdx).whenComplete((cb, r) -> { try (cb) { synchronized(CachingStream.this) { int newCount = _consumptionCounts.getInt(mIdx) + 1; - if(newCount > maxConsumptionCount) { + if(newCount > _maxConsumptionCount) { _failure = new DMLRuntimeException( - "Consumer overflow in " + _streamId + "_" + mIdx + ". Expected: " + maxConsumptionCount); + "Consumer overflow in " + _streamId + "_" + mIdx + ". Expected: " + _maxConsumptionCount); cb.fail(_failure); } else @@ -304,6 +328,17 @@ public void findCachedAsync(MatrixIndexes idx, Consumer 0) { + long expected = dc.getNumBlocks(); + if (expected >= 0 && _numBlocks != expected) { + throw new DMLRuntimeException("CachingStream block count mismatch: expected " + + expected + " but saw " + _numBlocks + " (" + dc.getRows() + "x" + dc.getCols() + ")"); + } + } + } + /** * Finds a cached item asynchronously without counting it as a consumption. 
*/ @@ -332,7 +367,7 @@ public OOCStream.QueueCallback peekCached(MatrixIndexes idx) public synchronized void activateIndexing() { if (_index == null) - _index = new HashMap<>(); + _index = new DualHashBidiMap<>(); } @Override @@ -350,23 +385,108 @@ public boolean isProcessed() { return false; } + @Override + public DataCharacteristics getDataCharacteristics() { + return _source.getDataCharacteristics(); + } + + @Override + public CacheableData getData() { + return _source.getData(); + } + + @Override + public void setData(CacheableData data) { + _source.setData(data); + } + + @Override + public void messageUpstream(OOCStreamMessage msg) { + if (msg.isCancelled()) + return; + if(msg instanceof OOCGetStreamTypeMessage) { + ((OOCGetStreamTypeMessage) msg).setCachedType(); + activateIndexing(); + return; + } + + _source.messageUpstream(msg); + } + + @Override + public void messageDownstream(OOCStreamMessage msg) { + CopyOnWriteArrayList> relays = _downstreamRelays; + if (relays != null) { + for (Consumer relay : relays) { + if (msg.isCancelled()) + break; + relay.accept(msg); + } + } + } + + @Override + public void setUpstreamMessageRelay(Consumer relay) { + throw new UnsupportedOperationException(); + } + + @Override + public void setDownstreamMessageRelay(Consumer relay) { + addDownstreamMessageRelay(relay); + } + + @Override + public void addUpstreamMessageRelay(Consumer relay) { + throw new UnsupportedOperationException(); + } + + @Override + public void addDownstreamMessageRelay(Consumer relay) { + if (relay == null) + throw new IllegalArgumentException("Cannot set downstream relay to null"); + CopyOnWriteArrayList> relays = _downstreamRelays; + if (relays == null) { + synchronized(this) { + if (_downstreamRelays == null) + _downstreamRelays = new CopyOnWriteArrayList<>(); + relays = _downstreamRelays; + } + } + relays.add(0, relay); + } + + @Override + public void clearUpstreamMessageRelays() { + // No upstream relays supported + } + + @Override + public void 
clearDownstreamMessageRelays() { + _downstreamRelays = null; + } + + @Override + public void setIXTransform(BiFunction transform) { + throw new UnsupportedOperationException(); + } + public void setSubscriber(Consumer> subscriber, boolean incrConsumers) { - if (deletable) + if(_deletable) throw new DMLRuntimeException("Cannot register a new subscriber on " + this + " because has been flagged for deletion"); - if (_failure != null) + if(_failure != null) throw _failure; int mNumBlocks; boolean cacheInProgress; int consumerIdx; - synchronized (this) { + synchronized(this) { mNumBlocks = _numBlocks; cacheInProgress = _cacheInProgress; consumerIdx = _consumerConsumptionCounts.size(); _consumerConsumptionCounts.add(0); - if (incrConsumers) - maxConsumptionCount++; - if (cacheInProgress) { + if(incrConsumers) + _maxConsumptionCount++; + if(cacheInProgress) { int newLen = _subscribers == null ? 1 : _subscribers.length + 1; Consumer>[] newSubscribers = new Consumer[newLen]; @@ -378,17 +498,21 @@ public void setSubscriber(Consumer> } } - for (int i = 0; i < mNumBlocks; i++) { + for(int i = 0; i < mNumBlocks; i++) { final int idx = i; OOCCacheManager.requestBlock(_streamId, i).whenComplete((cb, r) -> { - try (cb) { + if(r != null) { + subscriber.accept(OOCStream.eos(DMLRuntimeException.of(r))); + return; + } + try(cb) { synchronized(CachingStream.this) { if(_index != null) _index.put(cb.get().getIndexes(), idx); } subscriber.accept(cb); - if (onConsumed(idx, consumerIdx)) + if(onConsumed(idx, consumerIdx)) subscriber.accept(OOCStream.eos(_failure)); // NO_MORE_TASKS } }); @@ -403,10 +527,10 @@ public void setSubscriber(Consumer> * Only use if certain blocks are accessed more than once. 
*/ public synchronized void incrSubscriberCount(int count) { - if (deletable) + if (_deletable) throw new IllegalStateException("Cannot increment the subscriber count if flagged for deletion"); - maxConsumptionCount += count; + _maxConsumptionCount += count; } /** @@ -416,7 +540,7 @@ public synchronized void incrProcessingCount(int i, int count) { int cnt = _consumptionCounts.getInt(i)+count; _consumptionCounts.set(i, cnt); - if (deletable) + if (_deletable) tryDeleteBlock(i); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java index 81b3bb7b38d..9489a7d7238 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java @@ -36,6 +36,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.matrix.data.RandomMatrixGenerator; import org.apache.sysds.runtime.matrix.operators.UnaryOperator; +import org.apache.sysds.runtime.ooc.stream.StreamContext; import org.apache.sysds.runtime.util.UtilFunctions; public class DataGenOOCInstruction extends UnaryOOCInstruction { @@ -259,7 +260,7 @@ else if(method == Types.OpOpDG.SEQ) { } qOut.closeInput(); - }, qOut); + }, new StreamContext().addOutStream(qOut)); } else throw new NotImplementedException(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MMultOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MMultOOCInstruction.java new file mode 100644 index 00000000000..16338888a54 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MMultOOCInstruction.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.Operator; + +public class MMultOOCInstruction extends ComputationOOCInstruction { + + + protected MMultOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) { + super(type, op, in1, in2, out, opcode, istr); + } + + public static MMultOOCInstruction parseInstruction(String str) { + String[] 
parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(parts, 4); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); // the larget matrix (streamed) + CPOperand in2 = new CPOperand(parts[2]); // the small vector (in-memory) + CPOperand out = new CPOperand(parts[3]); + + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + AggregateBinaryOperator ba = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); + + return new MMultOOCInstruction(OOCType.MAPMM, ba, in1, in2, out, opcode, str); + } + + @Override + public void processInstruction( ExecutionContext ec ) { + // 1. Identify the inputs + MatrixObject min = ec.getMatrixObject(input1); // big matrix + MatrixObject vin = ec.getMatrixObject(input2); // streamed vector + + int emitLeftThreshold = (int)vin.getDataCharacteristics().getNumColBlocks(); + int emitRightThreshold = (int)min.getDataCharacteristics().getNumRowBlocks(); + + OOCStream intermediateStream = createWritableStream(); + OOCStream outStream = createWritableStream(); + ec.getMatrixObject(output).setStreamHandle(outStream); + + joinManyOOC(min.getStreamHandle(), vin.getStreamHandle(), intermediateStream, + (left, right) -> { + MatrixBlock leftBlock = (MatrixBlock) left.getValue(); + MatrixBlock rightBlock = (MatrixBlock) right.getValue(); + MatrixBlock partialResult = leftBlock.aggregateBinaryOperations(leftBlock, rightBlock, + new MatrixBlock(), (AggregateBinaryOperator) _optr); + return new IndexedMatrixValue(new MatrixIndexes(left.getIndexes().getRowIndex(), right.getIndexes().getColumnIndex()), partialResult); + }, + tmp -> tmp.getIndexes().getColumnIndex(), + tmp -> tmp.getIndexes().getRowIndex(), + emitLeftThreshold, emitRightThreshold); + + BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); + int emitAggThreshold = (int)min.getDataCharacteristics().getNumColBlocks(); + + groupedReduceOOC(intermediateStream, outStream, 
(left, right) -> { + MatrixBlock mb = ((MatrixBlock)left.getValue()).binaryOperationsInPlace(plus, right.getValue()); + left.setValue(mb); + return left; + }, emitAggThreshold); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MapMMChainOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MapMMChainOOCInstruction.java new file mode 100644 index 00000000000..2fd5585edd3 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MapMMChainOOCInstruction.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.instructions.ooc; + +import java.util.concurrent.CompletableFuture; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.lops.MapMultChain.ChainType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; + +public class MapMMChainOOCInstruction extends ComputationOOCInstruction { + private final ChainType _type; + + protected MapMMChainOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, + CPOperand out, ChainType chainType, String opcode, String istr) { + super(type, op, in1, in2, in3, out, opcode, istr); + _type = chainType; + } + + public static MapMMChainOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(parts, 4, 5); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + + if (parts.length == 5) { + CPOperand 
out = new CPOperand(parts[3]); + ChainType type = ChainType.valueOf(parts[4]); + return new MapMMChainOOCInstruction(OOCType.MAPMMCHAIN, null, in1, in2, null, out, type, opcode, str); + } + else { //parts.length==6 + CPOperand in3 = new CPOperand(parts[3]); + CPOperand out = new CPOperand(parts[4]); + ChainType type = ChainType.valueOf(parts[5]); + return new MapMMChainOOCInstruction(OOCType.MAPMMCHAIN, null, in1, in2, in3, out, type, opcode, str); + } + } + + @Override + public void processInstruction(ExecutionContext ec) { + MatrixObject min = ec.getMatrixObject(input1); + MatrixObject mv = ec.getMatrixObject(input2); + OOCStream qV = mv.getStreamHandle(); + + OOCStream qOut = createWritableStream(); + addOutStream(qOut); + ec.getMatrixObject(output).setStreamHandle(qOut); + + OOCStream qInX = min.getStreamHandle(); + boolean createdCache = !qInX.hasStreamCache(); + CachingStream xCache = createdCache ? new CachingStream(qInX) : qInX.getStreamCache(); + + long numRowBlocksL = min.getDataCharacteristics().getNumRowBlocks(); + long numColBlocksL = min.getDataCharacteristics().getNumColBlocks(); + int numRowBlocks = Math.toIntExact(numRowBlocksL); + int numColBlocks = Math.toIntExact(numColBlocksL); + long vRows = mv.getDataCharacteristics().getRows(); + + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + AggregateBinaryOperator mmOp = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); + BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); + + boolean hasV = !mv.getDataCharacteristics().rowsKnown() || vRows > 0; + if(!hasV && _type != ChainType.XtXvy) + throw new DMLRuntimeException("MMChain requires non-empty v for chain type " + _type); + + OOCStream qU; + CompletableFuture uFuture; + + if(!hasV && _type == ChainType.XtXvy) { + MatrixObject mw = ec.getMatrixObject(input3); + OOCStream qW = mw.getStreamHandle(); + OOCStream qNegW = createWritableStream(); + RightScalarOperator negOp = new 
RightScalarOperator(Multiply.getMultiplyFnObject(), -1); + + uFuture = mapOOC(qW, qNegW, tmp -> { + MatrixBlock wBlock = (MatrixBlock) tmp.getValue(); + MatrixBlock neg = wBlock.scalarOperations(negOp, new MatrixBlock()); + return new IndexedMatrixValue(new MatrixIndexes(tmp.getIndexes().getRowIndex(), 1L), neg); + }); + qU = qNegW; + } + else { + OOCStream qPartialXv = createWritableStream(); + OOCStream qXv = createWritableStream(); + OOCStream qInXv = xCache.getReadStream(); + + CompletableFuture mapXvFuture = broadcastJoinOOC(qInXv, qV, qPartialXv, (x, v) -> { + MatrixBlock xBlock = (MatrixBlock) x.getValue(); + MatrixBlock vBlock = (MatrixBlock) v.getValue().getValue(); + MatrixBlock partial = xBlock.aggregateBinaryOperations(xBlock, vBlock, new MatrixBlock(), mmOp); + return new IndexedMatrixValue(new MatrixIndexes(x.getIndexes().getRowIndex(), 1L), partial); + }, tmp -> tmp.getIndexes().getColumnIndex(), tmp -> tmp.getIndexes().getRowIndex()); + + CompletableFuture reduceXvFuture = groupedReduceOOC(qPartialXv, qXv, (left, right) -> { + MatrixBlock mb = ((MatrixBlock) left.getValue()).binaryOperationsInPlace(plus, right.getValue()); + left.setValue(mb); + return left; + }, numColBlocks); + + if(_type.isWeighted()) { + MatrixObject mw = ec.getMatrixObject(input3); + OOCStream qW = mw.getStreamHandle(); + OOCStream qWeighted = createWritableStream(); + BinaryOperator weightOp = InstructionUtils.parseBinaryOperator( + _type == ChainType.XtwXv ? 
Opcodes.MULT.toString() : Opcodes.MINUS.toString()); + + uFuture = broadcastJoinOOC(qXv, qW, qWeighted, (u, w) -> { + MatrixBlock uBlock = (MatrixBlock) u.getValue(); + MatrixBlock wBlock = (MatrixBlock) w.getValue().getValue(); + MatrixBlock updated = uBlock.binaryOperationsInPlace(weightOp, wBlock); + u.setValue(updated); + return u; + }, tmp -> tmp.getIndexes().getRowIndex(), tmp -> tmp.getIndexes().getRowIndex()); + qU = qWeighted; + } + else { + uFuture = reduceXvFuture; + qU = qXv; + } + + mapXvFuture.exceptionally(err -> { + qOut.propagateFailure(DMLRuntimeException.of(err)); + return null; + }); + } + + OOCStream qInXt = xCache.getReadStream(); + OOCStream qPartialXt = createWritableStream(); + CompletableFuture joinXtFuture = broadcastJoinOOC(qInXt, qU, qPartialXt, (x, u) -> { + MatrixBlock xBlock = (MatrixBlock) x.getValue(); + MatrixBlock uBlock = (MatrixBlock) u.getValue().getValue(); + MatrixBlock partial = multTransposeVector(xBlock, uBlock); + return new IndexedMatrixValue(new MatrixIndexes(x.getIndexes().getColumnIndex(), 1L), partial); + }, tmp -> tmp.getIndexes().getRowIndex(), tmp -> tmp.getIndexes().getRowIndex()); + + CompletableFuture outFuture = groupedReduceOOC(qPartialXt, qOut, (left, right) -> { + MatrixBlock mb = ((MatrixBlock) left.getValue()).binaryOperationsInPlace(plus, right.getValue()); + left.setValue(mb); + return left; + }, numRowBlocks); + + outFuture.whenComplete((res, err) -> { + if(createdCache) + xCache.scheduleDeletion(); + }); + + uFuture.exceptionally(err -> { + qOut.propagateFailure(DMLRuntimeException.of(err)); + return null; + }); + joinXtFuture.exceptionally(err -> { + qOut.propagateFailure(DMLRuntimeException.of(err)); + return null; + }); + outFuture.exceptionally(err -> { + qOut.propagateFailure(DMLRuntimeException.of(err)); + return null; + }); + } + + private static MatrixBlock multTransposeVector(MatrixBlock x, MatrixBlock u) { + int rows = x.getNumRows(); + int cols = x.getNumColumns(); + MatrixBlock out = new 
MatrixBlock(cols, 1, false); + out.allocateDenseBlock(); + double[] outVals = out.getDenseBlockValues(); + + if(x.isInSparseFormat()) { + SparseBlock a = x.getSparseBlock(); + if(a != null) { + if(u.isInSparseFormat()) { + for(int i = 0; i < rows; i++) { + if(a.isEmpty(i)) + continue; + double uval = u.get(i, 0); + if(uval == 0) + continue; + int apos = a.pos(i); + int alen = a.size(i); + int[] aix = a.indexes(i); + double[] avals = a.values(i); + for(int k = apos; k < apos + alen; k++) + outVals[aix[k]] += uval * avals[k]; + } + } + else { + double[] uvals = u.getDenseBlockValues(); + for(int i = 0; i < rows; i++) { + if(a.isEmpty(i)) + continue; + double uval = uvals[i]; + if(uval == 0) + continue; + int apos = a.pos(i); + int alen = a.size(i); + int[] aix = a.indexes(i); + double[] avals = a.values(i); + for(int k = apos; k < apos + alen; k++) + outVals[aix[k]] += uval * avals[k]; + } + } + } + } + else { + DenseBlock a = x.getDenseBlock(); + if(u.isInSparseFormat()) { + for(int i = 0; i < rows; i++) { + double uval = u.get(i, 0); + if(uval == 0) + continue; + double[] avals = a.values(i); + int apos = a.pos(i); + for(int j = 0; j < cols; j++) + outVals[j] += uval * avals[apos + j]; + } + } + else { + double[] uvals = u.getDenseBlockValues(); + for(int i = 0; i < rows; i++) { + double uval = uvals[i]; + if(uval == 0) + continue; + double[] avals = a.values(i); + int apos = a.pos(i); + for(int j = 0; j < cols; j++) + outVals[j] += uval * avals[apos + j]; + } + } + } + + out.recomputeNonZeros(); + out.examSparsity(); + return out; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java index fa0d0df55d3..116e65302f1 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java @@ -32,6 
+32,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.util.IndexRange; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; @@ -83,18 +84,49 @@ public void processInstruction(ExecutionContext ec) { throw new DMLRuntimeException("Desired block not found"); } + qIn.setDownstreamMessageRelay(qOut::messageDownstream); + qOut.setUpstreamMessageRelay(qIn::messageUpstream); + qOut.setIXTransform((downstream, range) -> { + if(downstream){ + long rs = range.rowStart-ix.rowStart+1; + long re = range.rowEnd-ix.rowStart+1; + long cs = range.colStart-ix.colStart+1; + long ce = range.colEnd-ix.colStart+1; + // TODO What happens if range is out of bounds? + rs = Math.max(1, rs); + cs = Math.max(1, cs); + re = Math.min(ix.rowSpan(), re); + ce = Math.min(ix.colSpan(), ce); + return new IndexRange(rs, re, cs, ce); + } + else{ + long rs = range.rowStart+ix.rowStart; + long re = range.rowEnd+ix.rowStart; + long cs = range.colStart+ix.colStart; + long ce = range.colEnd+ix.colStart; + return new IndexRange(rs, re, cs, ce); + } + }); + if(ix.rowStart % blocksize == 0 && ix.colStart % blocksize == 0) { // Aligned case: interior blocks can be forwarded directly, borders may require slicing final int outBlockRows = (int) Math.ceil((double) (ix.rowSpan() + 1) / blocksize); final int outBlockCols = (int) Math.ceil((double) (ix.colSpan() + 1) / blocksize); final int totalBlocks = outBlockRows * outBlockCols; + final boolean isCached = qIn.hasStreamCache(); final AtomicInteger producedBlocks = new AtomicInteger(0); CompletableFuture future = new CompletableFuture<>(); - filterOOC(qIn, tmp -> { - MatrixIndexes inIdx = tmp.getIndexes(); - long blockRow = inIdx.getRowIndex() - 1; - long blockCol = inIdx.getColumnIndex() - 1; + mapOptionalOOC(qIn, qOut, tmp -> { + if (future.isDone()) + return Optional.empty(); + + long blockRow = 
tmp.getIndexes().getRowIndex() - 1; + long blockCol = tmp.getIndexes().getColumnIndex() - 1; + boolean within = blockRow >= firstBlockRow && blockRow <= lastBlockRow && + blockCol >= firstBlockCol && blockCol <= lastBlockCol; + if(!within) + return Optional.empty(); MatrixBlock block = (MatrixBlock) tmp.getValue(); @@ -108,7 +140,8 @@ public void processInstruction(ExecutionContext ec) { MatrixBlock outBlock; if(rowStartLocal == 0 && rowEndLocal == block.getNumRows() - 1 && colStartLocal == 0 && colEndLocal == block.getNumColumns() - 1) { - outBlock = block; + // If the block is cached, we need to copy because otherwise it could lead to nullpointers + outBlock = isCached ? new MatrixBlock(block) : block; } else { outBlock = block.slice(rowStartLocal, rowEndLocal, colStartLocal, colEndLocal); @@ -116,19 +149,11 @@ public void processInstruction(ExecutionContext ec) { long outBlockRow = blockRow - firstBlockRow + 1; long outBlockCol = blockCol - firstBlockCol + 1; - qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(outBlockRow, outBlockCol), outBlock)); if(producedBlocks.incrementAndGet() >= totalBlocks) future.complete(null); - }, tmp -> { - if (future.isDone()) // Then we may skip blocks and avoid submitting tasks - return false; - - long blockRow = tmp.getIndexes().getRowIndex() - 1; - long blockCol = tmp.getIndexes().getColumnIndex() - 1; - return blockRow >= firstBlockRow && blockRow <= lastBlockRow && blockCol >= firstBlockCol && - blockCol <= lastBlockCol; - }, qOut::closeInput); + return Optional.of(new IndexedMatrixValue(new MatrixIndexes(outBlockRow, outBlockCol), outBlock)); + }); return; } @@ -137,18 +162,28 @@ public void processInstruction(ExecutionContext ec) { // We may need to construct our own intermediate stream to properly manage the cached items boolean hasIntermediateStream = !qIn.hasStreamCache(); - final CachingStream cachedStream = hasIntermediateStream ? 
new CachingStream(new SubscribableTaskQueue<>()) : qOut.getStreamCache(); - cachedStream.activateIndexing(); - cachedStream.incrSubscriberCount(1); // We may require re-consumption of blocks (up to 4 times) final CompletableFuture future = new CompletableFuture<>(); - filterOOC(qIn.getReadStream(), tmp -> { - if (hasIntermediateStream) { - // We write to an intermediate stream to ensure that these matrix blocks are properly cached - cachedStream.getWriteStream().enqueue(tmp); - } + OOCStream filteredStream = filteredOOCStream(qIn, tmp -> { + boolean pass = !future.isDone(); + // Pre-filter incoming blocks to avoid unnecessary task submission + long blockRow = tmp.getIndexes().getRowIndex() - 1; + long blockCol = tmp.getIndexes().getColumnIndex() - 1; + pass &= blockRow >= firstBlockRow && blockRow <= lastBlockRow && blockCol >= firstBlockCol && + blockCol <= lastBlockCol; + + if(!pass && !hasIntermediateStream) + qIn.getStreamCache().incrProcessingCount(qIn.getStreamCache().findCachedIndex(tmp.getIndexes()), 1); + return pass; + }); + + final CachingStream cachedStream = hasIntermediateStream ? 
new CachingStream(filteredStream) : qIn.getStreamCache(); + cachedStream.activateIndexing(); + cachedStream.incrSubscriberCount(1); // We may require re-consumption of blocks (up to 4 times) + OOCStream readStream = cachedStream.getReadStream(); - boolean completed = aligner.putNext(tmp.getIndexes(), tmp.getIndexes(), (idx, sector) -> { + submitOOCTasks(readStream, tmp -> { + boolean completed = aligner.putNext(tmp.get().getIndexes(), tmp.get().getIndexes(), (idx, sector) -> { int targetBlockRow = (int) (idx.getRowIndex() - 1); int targetBlockCol = (int) (idx.getColumnIndex() - 1); @@ -226,23 +261,15 @@ public void processInstruction(ExecutionContext ec) { if(completed) future.complete(null); - }, tmp -> { - if (future.isDone()) // Then we may skip blocks and avoid submitting tasks - return false; - - // Pre-filter incoming blocks to avoid unnecessary task submission - long blockRow = tmp.getIndexes().getRowIndex() - 1; - long blockCol = tmp.getIndexes().getColumnIndex() - 1; - return blockRow >= firstBlockRow && blockRow <= lastBlockRow && blockCol >= firstBlockCol && - blockCol <= lastBlockCol; - }, () -> { - aligner.close(); - qOut.closeInput(); - }, tmp -> { - // If elements are not processed in an existing caching stream, we increment the process counter to allow block deletion - if (!hasIntermediateStream) - cachedStream.incrProcessingCount(cachedStream.findCachedIndex(tmp.getIndexes()), 1); - }); + }) + .thenRun(() -> { + aligner.close(); + qOut.closeInput(); + }) + .exceptionally(err -> { + qOut.propagateFailure(DMLRuntimeException.of(err)); + return null; + }); if (hasIntermediateStream) cachedStream.scheduleDeletion(); // We can immediately delete blocks after consumption diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java deleted file mode 100644 index b0c08db2dca..00000000000 --- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysds.runtime.instructions.ooc; - -import java.util.HashMap; - -import org.apache.sysds.common.Opcodes; -import org.apache.sysds.conf.ConfigurationManager; -import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.functionobjects.Multiply; -import org.apache.sysds.runtime.functionobjects.Plus; -import org.apache.sysds.runtime.instructions.InstructionUtils; -import org.apache.sysds.runtime.instructions.cp.CPOperand; -import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; -import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.data.MatrixIndexes; -import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateOperator; -import org.apache.sysds.runtime.matrix.operators.BinaryOperator; -import org.apache.sysds.runtime.matrix.operators.Operator; - -public class 
MatrixVectorBinaryOOCInstruction extends ComputationOOCInstruction { - - - protected MatrixVectorBinaryOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) { - super(type, op, in1, in2, out, opcode, istr); - } - - public static MatrixVectorBinaryOOCInstruction parseInstruction(String str) { - String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); - InstructionUtils.checkNumFields(parts, 4); - String opcode = parts[0]; - CPOperand in1 = new CPOperand(parts[1]); // the larget matrix (streamed) - CPOperand in2 = new CPOperand(parts[2]); // the small vector (in-memory) - CPOperand out = new CPOperand(parts[3]); - - AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); - AggregateBinaryOperator ba = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); - - return new MatrixVectorBinaryOOCInstruction(OOCType.MAPMM, ba, in1, in2, out, opcode, str); - } - - @Override - public void processInstruction( ExecutionContext ec ) { - // 1. Identify the inputs - MatrixObject min = ec.getMatrixObject(input1); // big matrix - MatrixBlock vin = ec.getMatrixObject(input2) - .acquireReadAndRelease(); // in-memory vector - - // 2. 
Pre-partition the in-memory vector into a hashmap - HashMap partitionedVector = new HashMap<>(); - int blksize = vin.getDataCharacteristics().getBlocksize(); - if (blksize < 0) - blksize = ConfigurationManager.getBlocksize(); - for (int i=0; i qIn = min.getStreamHandle(); - OOCStream qOut = createWritableStream(); - BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); - addOutStream(qOut); - ec.getMatrixObject(output).setStreamHandle(qOut); - final Object lock = new Object(); - - submitOOCTasks(qIn, cb -> { - try(cb) { - IndexedMatrixValue tmp = cb.get(); - MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue(); - long rowIndex = tmp.getIndexes().getRowIndex(); - long colIndex = tmp.getIndexes().getColumnIndex(); - MatrixBlock vectorSlice = partitionedVector.get(colIndex); - - // Now, call the operation with the correct, specific operator. - MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations(matrixBlock, vectorSlice, - new MatrixBlock(), (AggregateBinaryOperator) _optr); - - // for single column block, no aggregation neeeded - if(emitThreshold == 1) { - qOut.enqueue(new IndexedMatrixValue(tmp.getIndexes(), partialResult)); - } - else { - // aggregation - synchronized(lock) { - MatrixBlock currAgg = aggTracker.get(rowIndex); - if(currAgg == null) { - aggTracker.putAndIncrementCount(rowIndex, partialResult); - } - else { - currAgg = currAgg.binaryOperations(plus, partialResult); - if(aggTracker.putAndIncrementCount(rowIndex, currAgg)) { - // early block output: emit aggregated block - MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L); - qOut.enqueue(new IndexedMatrixValue(idx, currAgg)); - aggTracker.remove(rowIndex); - } - } - } - } - } - }, qOut::closeInput); - - /*submitOOCTask(() -> { - IndexedMatrixValue tmp = null; - try { - while((tmp = qIn.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { - MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue(); - long rowIndex = tmp.getIndexes().getRowIndex(); - long colIndex = 
tmp.getIndexes().getColumnIndex(); - MatrixBlock vectorSlice = partitionedVector.get(colIndex); - - // Now, call the operation with the correct, specific operator. - MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations( - matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr); - - // for single column block, no aggregation neeeded - if(emitThreshold == 1) { - qOut.enqueue(new IndexedMatrixValue(tmp.getIndexes(), partialResult)); - } - else { - // aggregation - MatrixBlock currAgg = aggTracker.get(rowIndex); - if (currAgg == null) { - aggTracker.putAndIncrementCount(rowIndex, partialResult); - } - else { - currAgg = currAgg.binaryOperations(plus, partialResult); - if (aggTracker.putAndIncrementCount(rowIndex, currAgg)){ - // early block output: emit aggregated block - MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L); - qOut.enqueue(new IndexedMatrixValue(idx, currAgg)); - aggTracker.remove(rowIndex); - } - } - } - } - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } - finally { - qOut.closeInput(); - } - }, qIn, qOut);*/ - } -} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index 5a4ae19b613..9ce4c0eb9c4 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -30,24 +30,30 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.ooc.cache.BlockEntry; +import org.apache.sysds.runtime.ooc.cache.BlockKey; import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; import org.apache.sysds.runtime.ooc.stats.OOCEventLog; +import org.apache.sysds.runtime.ooc.stream.FilteredOOCStream; +import 
org.apache.sysds.runtime.ooc.stream.StreamContext; +import org.apache.sysds.runtime.ooc.stream.TaskContext; import org.apache.sysds.runtime.util.CommonThreadPool; -import org.apache.sysds.runtime.util.OOCJoin; import org.apache.sysds.utils.Statistics; +import scala.Tuple2; import scala.Tuple4; +import scala.Tuple5; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; +import java.util.LinkedList; import java.util.List; import java.util.Map; -import java.util.Set; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; +import java.util.concurrent.ForkJoinTask; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.LongAdder; @@ -55,23 +61,23 @@ import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; -import java.util.stream.Stream; public abstract class OOCInstruction extends Instruction { + public static final ExecutorService COMPUTE_EXECUTOR = CommonThreadPool.get(); + private static final AtomicInteger COMPUTE_IN_FLIGHT = new AtomicInteger(0); + private static final int COMPUTE_BACKPRESSURE_THRESHOLD = 100; protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName()); private static final AtomicInteger nextStreamId = new AtomicInteger(0); private long nanoTime; public enum OOCType { Reblock, Tee, Binary, Ternary, Unary, AggregateUnary, AggregateBinary, AggregateTernary, MAPMM, MMTSJ, - Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand + MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand } protected final OOCInstruction.OOCType _ooctype; protected final boolean _requiresLabelUpdate; - protected Set> _inQueues; - protected Set> _outQueues; - private boolean _failed; + protected 
StreamContext _streamContext; private LongAdder _localStatisticsAdder; public final int _callerId; @@ -86,13 +92,20 @@ protected OOCInstruction(OOCInstruction.OOCType type, Operator op, String opcode instOpcode = opcode; _requiresLabelUpdate = super.requiresLabelUpdate(); - _failed = false; if (DMLScript.STATISTICS) _localStatisticsAdder = new LongAdder(); _callerId = DMLScript.OOC_LOG_EVENTS ? OOCEventLog.registerCaller(getExtendedOpcode() + "_" + hashCode()) : 0; } + public static int getComputeInFlight() { + return COMPUTE_IN_FLIGHT.get(); + } + + public static int getComputeBackpressureThreshold() { + return COMPUTE_BACKPRESSURE_THRESHOLD; + } + @Override public IType getType() { return IType.OUT_OF_CORE; @@ -125,6 +138,7 @@ public Instruction preprocessInstruction(ExecutionContext ec) { @Override public void postprocessInstruction(ExecutionContext ec) { + _streamContext = null; if(DMLScript.LINEAGE_DEBUGGER) ec.maintainLineageDebuggerInfo(this); if (DMLScript.OOC_LOG_EVENTS) @@ -132,116 +146,161 @@ public void postprocessInstruction(ExecutionContext ec) { } protected void addInStream(OOCStream... queue) { - if (_inQueues == null) - _inQueues = new HashSet<>(); - _inQueues.addAll(List.of(queue)); + if(_streamContext == null) + _streamContext = new StreamContext(); + _streamContext.addInStream(queue); } protected void addOutStream(OOCStream... 
queue) { - if (queue.length == 0 && _outQueues == null) { - _outQueues = Collections.emptySet(); - return; - } + if(_streamContext == null) + _streamContext = new StreamContext(); + _streamContext.addOutStream(queue); + } - if (_outQueues == null || _outQueues.isEmpty()) - _outQueues = new HashSet<>(); - _outQueues.addAll(List.of(queue)); + protected boolean inStreamsDefined() { + return _streamContext != null && _streamContext.inStreamsDefined(); + } + + protected boolean outStreamsDefined() { + return _streamContext != null && _streamContext.outStreamsDefined(); } protected OOCStream createWritableStream() { return new SubscribableTaskQueue<>(); } - protected CompletableFuture filterOOC(OOCStream qIn, Consumer processor, Function predicate, Runnable finalizer) { - return filterOOC(qIn, processor, predicate, finalizer, null); + protected CompletableFuture filterOOC(OOCStream qIn, Consumer processor, Function predicate) { + return filterOOC(qIn, processor, predicate, null); } - protected CompletableFuture filterOOC(OOCStream qIn, Consumer processor, Function predicate, Runnable finalizer, Consumer onNotProcessed) { - if (_inQueues == null || _outQueues == null) + protected CompletableFuture filterOOC(OOCStream qIn, Consumer processor, Function predicate, Consumer onNotProcessed) { + if (!inStreamsDefined() || !outStreamsDefined()) throw new NotImplementedException("filterOOC requires manual specification of all input and output streams for error propagation"); - return submitOOCTasks(qIn, c -> processor.accept(c.get()), finalizer, p -> predicate.apply(p.get()), onNotProcessed != null ? (i, tmp) -> onNotProcessed.accept(tmp.get()) : null); + return submitOOCTasks(qIn, c -> processor.accept(c.get()), p -> predicate.apply(p.get()), onNotProcessed != null ? 
(i, tmp) -> onNotProcessed.accept(tmp.get()) : null); + } + + protected OOCStream filteredOOCStream(OOCStream qIn, Function predicate) { + return new FilteredOOCStream<>(qIn, predicate); } protected CompletableFuture mapOOC(OOCStream qIn, OOCStream qOut, Function mapper) { + return mapOptionalOOC(qIn, qOut, tmp -> Optional.of(mapper.apply(tmp))); + } + + protected CompletableFuture mapOptionalOOC(OOCStream qIn, OOCStream qOut, Function> optionalMapper) { addInStream(qIn); addOutStream(qOut); - return submitOOCTasks(qIn, tmp -> { - try (tmp) { - R r = mapper.apply(tmp.get()); - qOut.enqueue(r); - } catch (Exception e) { + AtomicInteger deferredCtr = new AtomicInteger(1); + CompletableFuture future = new CompletableFuture<>(); + + Consumer> exec = tmp -> { + Optional r; + try(tmp) { + r = optionalMapper.apply(tmp.get()); + } + catch(Exception e) { throw e instanceof DMLRuntimeException ? (DMLRuntimeException) e : new DMLRuntimeException(e); } - }, qOut::closeInput); + r.ifPresent(t -> { + deferredCtr.incrementAndGet(); + // Defer to clean the stack of large objects + TaskContext.defer(() -> { + qOut.enqueue(t); + if(deferredCtr.decrementAndGet() == 0) + future.complete(null); + }); + }); + }; + + submitOOCTasks(qIn, exec, tmp -> { + // Try to run as a predicate to prefer pipelining rather than fan-out + if(ForkJoinTask.getPool() == COMPUTE_EXECUTOR) { + exec.accept(tmp); + return false; + } + return true; + }, null) + .thenRun(() -> { + if(deferredCtr.decrementAndGet() == 0) + future.complete(null); + }) + .exceptionally(err -> { + future.completeExceptionally(err); + return null; + }); + + return future.thenRun(qOut::closeInput).exceptionally(err -> { + DMLRuntimeException dmlErr = DMLRuntimeException.of(err); + qOut.propagateFailure(dmlErr); + throw dmlErr; + }); } protected CompletableFuture broadcastJoinOOC(OOCStream qIn, OOCStream broadcast, OOCStream qOut, BiFunction mapper, Function on) { + return broadcastJoinOOC(qIn, broadcast, qOut, mapper, on, on); + } + + 
protected CompletableFuture broadcastJoinOOC(OOCStream qIn, OOCStream broadcast, OOCStream qOut, BiFunction mapper, Function onLeft, Function onRight) { addInStream(qIn, broadcast); addOutStream(qOut); - boolean explicitLeftCaching = !qIn.hasStreamCache(); - boolean explicitRightCaching = !broadcast.hasStreamCache(); - CachingStream leftCache = explicitLeftCaching ? new CachingStream(new SubscribableTaskQueue<>()) : qIn.getStreamCache(); - CachingStream rightCache = explicitRightCaching ? new CachingStream(new SubscribableTaskQueue<>()) : broadcast.getStreamCache(); + CachingStream leftCache = qIn.hasStreamCache() ? qIn.getStreamCache() : new CachingStream(qIn); + CachingStream rightCache = broadcast.hasStreamCache() ? broadcast.getStreamCache() : new CachingStream(broadcast); leftCache.activateIndexing(); rightCache.activateIndexing(); - if (!explicitLeftCaching) - leftCache.incrSubscriberCount(1); // Prevent early block deletion as we may read elements twice - - if (!explicitRightCaching) - rightCache.incrSubscriberCount(1); + leftCache.incrSubscriberCount(1); // Prevent early block deletion as we may read elements twice + rightCache.incrSubscriberCount(1); Map> availableLeftInput = new ConcurrentHashMap<>(); Map availableBroadcastInput = new ConcurrentHashMap<>(); OOCStream, OOCStream.QueueCallback, BroadcastedElement>> broadcastingQueue = createWritableStream(); AtomicInteger waitCtr = new AtomicInteger(1); - CompletableFuture fut1 = new CompletableFuture<>(); + Object lock = new Object(); - submitOOCTasks(List.of(qIn, broadcast), (i, tmp) -> { - try (tmp) { - P key = on.apply(tmp.get()); + CompletableFuture fut1 = submitOOCTasks(List.of(leftCache.getReadStream(), rightCache.getReadStream()), (i, tmp) -> { + try(tmp) { + P key = i == 0 ? 
onLeft.apply(tmp.get()) : onRight.apply(tmp.get()); if(i == 0) { // qIn stream - BroadcastedElement b = availableBroadcastInput.get(key); + BroadcastedElement b; - if(b == null) { - // Matching broadcast element is not available -> cache element - availableLeftInput.compute(key, (k, v) -> { - if(v == null) - v = new ArrayList<>(); - v.add(tmp.get().getIndexes()); - return v; - }); + synchronized(lock) { + b = availableBroadcastInput.get(key); - if(explicitLeftCaching) - leftCache.getWriteStream().enqueue(tmp.get()); + if(b == null) { + availableLeftInput.compute(key, (k, v) -> { + if(v == null) + v = new ArrayList<>(); + v.add(tmp.get().getIndexes()); + return v; + }); + return; + } } - else { - waitCtr.incrementAndGet(); - OOCCacheManager.requestManyBlocks( + // Then items are present in cache + waitCtr.incrementAndGet(); + OOCCacheManager.requestManyBlocks( List.of(leftCache.peekCachedBlockKey(tmp.get().getIndexes()), rightCache.peekCachedBlockKey(b.idx))) - .whenComplete((items, err) -> { - try { - broadcastingQueue.enqueue(new Tuple4<>(key, items.get(0).keepOpen(), items.get(1).keepOpen(), b)); - } finally { - items.forEach(OOCStream.QueueCallback::close); - } - }); - } + .whenComplete((items, err) -> { + try { + broadcastingQueue.enqueue(new Tuple4<>(key, items.get(0).keepOpen(), items.get(1).keepOpen(), b)); + } finally { + items.forEach(OOCStream.QueueCallback::close); + } + }); } else { // broadcast stream - if(explicitRightCaching) - rightCache.getWriteStream().enqueue(tmp.get()); - BroadcastedElement b = new BroadcastedElement(tmp.get().getIndexes()); - availableBroadcastInput.put(key, b); - - List queued = availableLeftInput.remove(key); + List queued; + synchronized(lock) { + availableBroadcastInput.put(key, b); + queued = availableLeftInput.remove(key); + } if(queued != null) { for(MatrixIndexes idx : queued) { @@ -260,41 +319,38 @@ protected CompletableFuture broadcastJoinOOC(OOCStream { - fut1.complete(null); - if (waitCtr.decrementAndGet() == 0) + 
}); + fut1 = fut1.thenApply(v -> { + if(waitCtr.decrementAndGet() == 0) broadcastingQueue.closeInput(); + return null; }); - CompletableFuture fut2 = new CompletableFuture<>(); - submitOOCTasks(List.of(broadcastingQueue), (i, tpl) -> { - try (tpl) { + CompletableFuture fut2 = submitOOCTasks(List.of(broadcastingQueue), (i, tpl) -> { + try(tpl) { final BroadcastedElement b = tpl.get()._4(); final OOCStream.QueueCallback lValue = tpl.get()._2(); final OOCStream.QueueCallback bValue = tpl.get()._3(); - try (lValue; bValue) { + try(lValue; bValue) { b.value = bValue.get(); - qOut.enqueue(mapper.apply(lValue.get(), b)); leftCache.incrProcessingCount(leftCache.findCachedIndex(lValue.get().getIndexes()), 1); + qOut.enqueue(mapper.apply(lValue.get(), b)); - if(b.canRelease()) { + if(b.tryRelease()) { availableBroadcastInput.remove(tpl.get()._1()); - - if(!explicitRightCaching) - rightCache.incrProcessingCount(rightCache.findCachedIndex(b.idx), - 1); // Correct for incremented subscriber count to allow block deletion + rightCache.incrProcessingCount(rightCache.findCachedIndex(b.idx), 1); // Correct for incremented subscriber count to allow block deletion } } if(waitCtr.decrementAndGet() == 0) broadcastingQueue.closeInput(); } - }, () -> fut2.complete(null)); + }); - if (explicitLeftCaching) + if(!qIn.hasStreamCache()) leftCache.scheduleDeletion(); - if (explicitRightCaching) + if(!broadcast.hasStreamCache()) rightCache.scheduleDeletion(); CompletableFuture fut = CompletableFuture.allOf(fut1, fut2); @@ -309,6 +365,103 @@ protected CompletableFuture broadcastJoinOOC(OOCStream CompletableFuture joinManyOOC(OOCStream left, + OOCStream right, OOCStream out, + BiFunction mapper, Function leftOn, + Function rightOn, int releaseLeftCount, int releaseRightCount) { + addInStream(left, right); + addOutStream(out); + + CachingStream leftCache = left.hasStreamCache() ? left.getStreamCache() : new CachingStream(left); + CachingStream rightCache = right.hasStreamCache() ? 
right.getStreamCache() : new CachingStream(right); + leftCache.activateIndexing(); + rightCache.activateIndexing(); + + leftCache.incrSubscriberCount(1); // Prevent early block deletion as we may read elements twice + rightCache.incrSubscriberCount(1); + + Map, List>> joinMap = new ConcurrentHashMap<>(); + + OOCStream, OOCStream.QueueCallback, BroadcastedElement, BroadcastedElement>> joinQueue = createWritableStream(); + AtomicInteger waitCtr = new AtomicInteger(1); + + CompletableFuture fut1 = submitOOCTasks(List.of(leftCache.getReadStream(), rightCache.getReadStream()), + (i, tmp) -> { + try(tmp) { + boolean leftItem = i == 0; + P key = (leftItem ? leftOn : rightOn).apply(tmp.get()); + Tuple2, List> tuple = joinMap.computeIfAbsent(key, + k -> new Tuple2<>(new ArrayList<>(releaseRightCount), new ArrayList<>(releaseLeftCount))); + BroadcastedElement b = new BroadcastedElement(tmp.get().getIndexes()); + List matches = leftItem ? tuple._2 : tuple._1; + List toInsert = leftItem ? tuple._1 : tuple._2; + boolean remove; + synchronized(tuple) { + toInsert.add(b); + + for(BroadcastedElement e : matches) { + waitCtr.incrementAndGet(); + OOCCacheManager.requestManyBlocks( + List.of(leftCache.peekCachedBlockKey(leftItem ? b.idx : e.idx), + rightCache.peekCachedBlockKey(leftItem ? e.idx : b.idx))).thenApply(joined -> { + try { + joinQueue.enqueue( + new Tuple5<>(key, joined.get(0).keepOpen(), joined.get(1).keepOpen(), + leftItem ? b : e, leftItem ? 
e : b)); + } + finally { + joined.forEach(OOCStream.QueueCallback::close); + } + return null; + }).exceptionally(t -> { + joinQueue.propagateFailure(DMLRuntimeException.of(t)); + return null; + }); + } + remove = tuple._1.size() == releaseRightCount && tuple._2.size() == releaseLeftCount; + } + if(remove) + joinMap.remove(key); + } + }); + fut1 = fut1.thenApply(v -> { + if(waitCtr.decrementAndGet() == 0) + joinQueue.closeInput(); + return null; + }); + + CompletableFuture fut2 = mapOOC(joinQueue, out, tpl -> { + final BroadcastedElement bLeft = tpl._4(); + final BroadcastedElement bRight = tpl._5(); + final OOCStream.QueueCallback lValue = tpl._2(); + final OOCStream.QueueCallback rValue = tpl._3(); + R res; + + try(lValue; rValue) { + res = mapper.apply(lValue.get(), rValue.get()); + int leftCtr = bLeft.incrProcessCtrAndGet(); + int rightCtr = bRight.incrProcessCtrAndGet(); + + if(leftCtr == releaseLeftCount) + leftCache.incrProcessingCount(leftCache.findCachedIndex(bLeft.idx), + 1); // Correct for incremented subscriber count to allow block deletion + if(rightCtr == releaseRightCount) + rightCache.incrProcessingCount(rightCache.findCachedIndex(bRight.idx), 1); + } + + if(waitCtr.decrementAndGet() == 0) + joinQueue.closeInput(); + return res; + }); + + if(!left.hasStreamCache()) + leftCache.scheduleDeletion(); + if(!right.hasStreamCache()) + rightCache.scheduleDeletion(); + + return CompletableFuture.allOf(fut1, fut2); + } + protected static class BroadcastedElement { private final MatrixIndexes idx; private IndexedMatrixValue value; @@ -328,6 +481,14 @@ public synchronized boolean canRelease() { return release; } + public synchronized boolean tryRelease() { + if(release) { + release = false; // To not double release + return true; + } + return false; + } + public synchronized int incrProcessCtrAndGet() { processCtr++; return processCtr; @@ -342,27 +503,28 @@ public IndexedMatrixValue getValue() { } } - protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream 
qIn2, OOCStream qOut, BiFunction mapper, Function on) { - return joinOOC(qIn1, qIn2, qOut, mapper, on, on); + protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream qIn2, OOCStream qOut, BiFunction mapper, Function on) { + return joinOOC(List.of(qIn1, qIn2), qOut, t -> mapper.apply(t.get(0), t.get(1)), on); } - @SuppressWarnings("unchecked") - protected CompletableFuture joinOOC(List> qIn, OOCStream qOut, Function, R> mapper, List> on) { - if (qIn == null || on == null || qIn.size() != on.size()) - throw new DMLRuntimeException("joinOOC(list) requires the same number of streams and key functions."); + protected CompletableFuture joinOOC(List> qIn, OOCStream qOut, Function, R> mapper, Function on) { + int inSize = qIn.size(); + return joinOOC(qIn, qOut, mapper, Collections.nCopies(inSize, on), t -> Collections.nCopies(inSize, t)); + } - addInStream(qIn.toArray(OOCStream[]::new)); - addOutStream(qOut); + protected CompletableFuture joinOOC(List> qIn, OOCStream qOut, Function, R> mapper, List> on, Function> invOn) { + if(qIn == null || on == null || qIn.size() != on.size()) + throw new DMLRuntimeException("joinOOC(list) requires the same number of streams and key functions."); final int n = qIn.size(); CachingStream[] caches = new CachingStream[n]; boolean[] explicitCaching = new boolean[n]; - for (int i = 0; i < n; i++) { - OOCStream s = qIn.get(i); + for(int i = 0; i < n; i++) { + OOCStream s = qIn.get(i); explicitCaching[i] = !s.hasStreamCache(); - caches[i] = explicitCaching[i] ? new CachingStream((OOCStream) s) : s.getStreamCache(); + caches[i] = explicitCaching[i] ? 
new CachingStream(s) : s.getStreamCache(); caches[i].activateIndexing(); // One additional consumption for the materialization when emitting caches[i].incrSubscriberCount(1); @@ -370,129 +532,273 @@ protected CompletableFuture joinOOC(List> qIn, OOCS Map seen = new ConcurrentHashMap<>(); - CompletableFuture future = submitOOCTasks( - Arrays.stream(caches).map(CachingStream::getReadStream).collect(java.util.stream.Collectors.toList()), - (i, tmp) -> { - Function keyFn = on.get(i); - P key = keyFn.apply((T)tmp.get()); - MatrixIndexes idx = tmp.get().getIndexes(); - - MatrixIndexes[] arr = seen.computeIfAbsent(key, k -> new MatrixIndexes[n]); - boolean ready; - synchronized (arr) { - arr[i] = idx; - ready = true; - for (MatrixIndexes ix : arr) { - if (ix == null) { - ready = false; - break; - } + OOCStream>> materialized = createWritableStream(); + + List> rStreams = new ArrayList<>(caches.length); + for(int i = 0; i < caches.length; i++) + rStreams.add(explicitCaching[i] ? caches[i].getReadStream() : qIn.get(i)); + + AtomicInteger processing = new AtomicInteger(1); + + addInStream(qIn.toArray(OOCStream[]::new)); + addOutStream(qOut); + + CompletableFuture future = pipeOOC(rStreams, (i, tmp) -> { + Function keyFn = on.get(i); + P key = keyFn.apply(tmp.get()); + MatrixIndexes idx = tmp.get().getIndexes(); + + MatrixIndexes[] arr = seen.computeIfAbsent(key, k -> new MatrixIndexes[n]); + boolean ready; + synchronized(arr) { + arr[i] = idx; + ready = true; + for(int j = 0; j < arr.length; j++) { + MatrixIndexes ix = arr[j]; + if (ix == null) { + ready = false; + break; } } + } - if (!ready || !seen.remove(key, arr)) - return; + if(!ready || !seen.remove(key, arr)) + return; - List> values = new java.util.ArrayList<>(n); - try { - for(int j = 0; j < n; j++) - values.add((OOCStream.QueueCallback) caches[j].findCached(arr[j])); + processing.incrementAndGet(); + List entries = new ArrayList<>(arr.length); + for(int j = 0; j < arr.length; j++) + 
entries.add(caches[j].peekCachedBlockKey(arr[j])); - qOut.enqueue(mapper.apply(values.stream().map(OOCStream.QueueCallback::get).toList())); - } finally { - values.forEach(OOCStream.QueueCallback::close); + var f = OOCCacheManager.requestManyBlocks(entries); + f.whenComplete((r, err) -> { + try { + if(err != null) { + if(err instanceof DMLRuntimeException) + materialized.propagateFailure((DMLRuntimeException) err); + else if(err instanceof Exception) + materialized.propagateFailure(new DMLRuntimeException(err)); + else + materialized.propagateFailure(new DMLRuntimeException(new Exception(err))); + return; + } + List> outList = new ArrayList<>(r.size()); + for(int j = 0; j < r.size(); j++) { + if(explicitCaching[j]) { + // Early forget item from cache + outList.add(new OOCStream.SimpleQueueCallback<>(r.get(j).get(), null)); + } + else { + outList.add(r.get(j).keepOpen()); + } + caches[j].incrProcessingCount(caches[j].findCachedIndex(r.get(j).get().getIndexes()), 1); + } + materialized.enqueue(outList); + r.forEach(OOCStream.QueueCallback::close); + } + catch(Throwable t) { + throw t; + } + finally { + if(processing.decrementAndGet() == 0) + materialized.closeInput(); } - }, qOut::closeInput); + }); + }); + + future.whenComplete((r, err) -> { + if (processing.decrementAndGet() == 0) { + materialized.closeInput(); + } + }); - for (int i = 0; i < n; i++) { + CompletableFuture outFuture = mapOOC(materialized, qOut, cb -> { + try { + List imvs = cb.stream().map(OOCStream.QueueCallback::get).toList(); + return mapper.apply(imvs); + } + finally { + cb.forEach(OOCStream.QueueCallback::close); + } + }); + + for(int i = 0; i < n; i++) { if (explicitCaching[i]) caches[i].scheduleDeletion(); } - return future; + return outFuture; } - @SuppressWarnings("unchecked") - protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream qIn2, OOCStream qOut, BiFunction mapper, Function onLeft, Function onRight) { - addInStream(qIn1, qIn2); + protected CompletableFuture 
groupedReduceOOC(OOCStream qIn, OOCStream qOut, BiFunction reduce, int emitCount) { + addInStream(qIn); addOutStream(qOut); - final CompletableFuture future = new CompletableFuture<>(); + if(qIn.hasStreamCache()) + throw new UnsupportedOperationException(); + Map aggregators = new ConcurrentHashMap<>(); + AtomicInteger busyCtr = new AtomicInteger(1); + CompletableFuture outFuture = new CompletableFuture<>(); + + CompletableFuture pipeFuture = pipeOOC(qIn, cb -> { + try(cb) { + Aggregator agg = aggregators.compute(cb.get().getIndexes(), (k, v) -> { + if(v == null) { + v = new Aggregator(reduce, emitCount); + busyCtr.incrementAndGet(); + v.getFuture().thenApply(imv -> { + qOut.enqueue(imv); + if(busyCtr.decrementAndGet() == 0) + outFuture.complete(null); + return null; + }) + .exceptionally(outFuture::completeExceptionally); + } + return v; + }); + agg.insert(cb.get()); + } + }); - boolean explicitLeftCaching = !qIn1.hasStreamCache(); - boolean explicitRightCaching = !qIn2.hasStreamCache(); + pipeFuture.thenRun(() -> { + if(busyCtr.decrementAndGet() == 0) + outFuture.complete(null); + }); - // We need to construct our own stream to properly manage the cached items in the hash join - CachingStream leftCache = explicitLeftCaching ? new CachingStream((OOCStream) qIn1) : qIn1.getStreamCache(); - CachingStream rightCache = explicitRightCaching ? 
new CachingStream((OOCStream) qIn2) : qIn2.getStreamCache(); - leftCache.activateIndexing(); - rightCache.activateIndexing(); + return outFuture.thenRun(qOut::closeInput); + } - leftCache.incrSubscriberCount(1); - rightCache.incrSubscriberCount(1); + private static class Aggregator { + private final long _streamId; + private final BiFunction _aggFn; + private final int _numTiles; + private final CompletableFuture _future; + private LinkedList _availableIntermediates; + private int _blockId; + private int _processed; + + + public Aggregator(BiFunction aggFn, int numTiles) { + _streamId = CachingStream._streamSeq.getNextID(); + _blockId = 0; + _aggFn = aggFn; + _numTiles = numTiles; + _future = new CompletableFuture<>(); + _availableIntermediates = new LinkedList<>(); + _processed = 0; + } - final OOCJoin join = new OOCJoin<>((idx, left, right) -> { - OOCStream.QueueCallback leftObj = (OOCStream.QueueCallback) leftCache.findCached(left); - OOCStream.QueueCallback rightObj = (OOCStream.QueueCallback) rightCache.findCached(right); - try (leftObj; rightObj) { - qOut.enqueue(mapper.apply(leftObj.get(), rightObj.get())); + public CompletableFuture getFuture() { + return _future; + } + + public void insert(IndexedMatrixValue imv) { + IndexedMatrixValue v = null; + CompletableFuture> future = null; + boolean finished = false; + synchronized(this) { + _processed++; + if(_processed == _numTiles * 2 - 1) { + // Then we are done + finished = true; + } + else { + if(!_availableIntermediates.isEmpty()) { + List sel = new ArrayList<>(1); + List entries = OOCCacheManager.getCache().tryRequestAnyOf(_availableIntermediates, 1, sel); + + if(entries == null) { + BlockEntry entry = OOCCacheManager.getCache() + .putAndPin(new BlockKey(_streamId, _blockId++), imv, + ((MatrixBlock) imv.getValue()).getExactSerializedSize()); + entry.addRetainHint(10); + future = OOCCacheManager.getCache() + .request(List.of(entry.getKey(), _availableIntermediates.removeFirst())); + 
OOCCacheManager.getCache().unpin(entry); + } + else { + v = (IndexedMatrixValue)entries.get(0).getData(); + _availableIntermediates.remove(sel.get(0)); + OOCCacheManager.getCache().forget(sel.get(0)); + if(v == null) + throw new IllegalStateException(); + } + } + else { + BlockEntry entry = OOCCacheManager.getCache() + .putAndPin(new BlockKey(_streamId, _blockId++), imv, + ((MatrixBlock) imv.getValue()).getExactSerializedSize()); + entry.addRetainHint(10); + OOCCacheManager.getCache().unpin(entry); + _availableIntermediates.add(entry.getKey()); + return; + } + } } - }); - submitOOCTasks(List.of(leftCache.getReadStream(), rightCache.getReadStream()), (i, tmp) -> { - try (tmp) { - if(i == 0) - join.addLeft(onLeft.apply((T) tmp.get()), tmp.get().getIndexes()); - else - join.addRight(onRight.apply((T) tmp.get()), tmp.get().getIndexes()); + if(finished) { + _availableIntermediates = null; + _future.complete(imv); + return; } - }, () -> { - join.close(); - qOut.closeInput(); - future.complete(null); - }); - if (explicitLeftCaching) - leftCache.scheduleDeletion(); - if (explicitRightCaching) - rightCache.scheduleDeletion(); + if(v != null) { + imv = _aggFn.apply(v, imv); + insert(imv); + return; + } - return future; + future.thenApply(l -> { + IndexedMatrixValue agg = _aggFn.apply((IndexedMatrixValue)l.get(0).getData(), + (IndexedMatrixValue)l.get(1).getData()); + OOCCacheManager.getCache().forget(l.get(0).getKey()); + OOCCacheManager.getCache().forget(l.get(1).getKey()); + insert(agg); + return null; + }); + } } - protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer> consumer, Runnable finalizer) { - return submitOOCTasks(queues, consumer, finalizer, null); + protected CompletableFuture pipeOOC(OOCStream queue, Consumer> consumer) { + return pipeOOC(List.of(queue), (i, tmp) -> consumer.accept(tmp)); } - protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer> consumer, Runnable finalizer, BiConsumer> onNotProcessed) { - List> 
futures = new ArrayList<>(queues.size()); - - for (int i = 0; i < queues.size(); i++) - futures.add(new CompletableFuture<>()); + protected CompletableFuture pipeOOC(List> queues, BiConsumer> consumer) { + return submitOOCTasks(queues, consumer, (i, tmp) -> { + // Try to run as a predicate to prefer pipelining rather than fan-out + if(ForkJoinTask.getPool() == COMPUTE_EXECUTOR) { + consumer.accept(i, tmp); + return false; + } + return true; + }, (i, tmp) -> {}); + } - return submitOOCTasks(queues, consumer, finalizer, futures, null, onNotProcessed); + protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer> consumer) { + return submitOOCTasks(queues, consumer, null, null); } - protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer> consumer, Runnable finalizer, List> futures, BiFunction, Boolean> predicate, BiConsumer> onNotProcessed) { + protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer> consumer, BiFunction, Boolean> predicate, BiConsumer> onNotProcessed) { addInStream(queues.toArray(OOCStream[]::new)); - if (_outQueues == null) + if(!outStreamsDefined()) throw new IllegalArgumentException("Explicit specification of all output streams is required before submitting tasks. 
If no output streams are present use addOutStream()."); - ExecutorService pool = CommonThreadPool.get(); final List activeTaskCtrs = new ArrayList<>(queues.size()); + final List> futures = new ArrayList<>(queues.size()); - for (int i = 0; i < queues.size(); i++) + for(int i = 0; i < queues.size(); i++) { activeTaskCtrs.add(new AtomicInteger(1)); + futures.add(new CompletableFuture<>()); + } final CompletableFuture globalFuture = CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)); - if (_outQueues == null) - _outQueues = Collections.emptySet(); - final Runnable oocFinalizer = oocTask(finalizer, null, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new)); + final StreamContext streamContext = _streamContext.copy(); // Snapshot of the current stream context + if(streamContext == null || !streamContext.inStreamsDefined() || !streamContext.outStreamsDefined()) + throw new IllegalArgumentException("Explicit specification of all output streams is required before submitting tasks. If no output streams are present use addOutStream()."); int i = 0; @SuppressWarnings("unused") final int streamId = nextStreamId.getAndIncrement(); - //System.out.println("New stream: (id " + streamId + ", size " + queues.size() + ", initiator '" + this.getClass().getSimpleName() + "')"); for (OOCStream queue : queues) { final int k = i; @@ -500,10 +806,9 @@ protected CompletableFuture submitOOCTasks(final List> qu final CompletableFuture localFuture = futures.get(k); final AtomicBoolean closeRaceWatchdog = new AtomicBoolean(false); - //System.out.println("Substream (k " + k + ", id " + streamId + ", type '" + queue.getClass().getSimpleName() + "', stream_id " + queue.hashCode() + ")"); queue.setSubscriber(oocTask(callback -> { long startTime = DMLScript.STATISTICS ? 
System.nanoTime() : 0; - try (callback) { + try(callback) { if(callback.isEos()) { if(!closeRaceWatchdog.compareAndSet(false, true)) throw new DMLRuntimeException( @@ -534,29 +839,45 @@ protected CompletableFuture submitOOCTasks(final List> qu // The item needs to be pinned in memory to be accessible in the executor thread final OOCStream.QueueCallback pinned = callback.keepOpen(); - pool.submit(oocTask(() -> { - long taskStartTime = DMLScript.STATISTICS ? System.nanoTime() : 0; - try (pinned) { - consumer.accept(k, pinned); - - if(localTaskCtr.decrementAndGet() == 0) - localFuture.complete(null); - } finally { - if (DMLScript.STATISTICS) { - _localStatisticsAdder.add(System.nanoTime() - taskStartTime); - if (globalFuture.isDone()) { - Statistics.maintainOOCHeavyHitter(getExtendedOpcode(), _localStatisticsAdder.sum()); - _localStatisticsAdder.reset(); + COMPUTE_IN_FLIGHT.incrementAndGet(); + try { + Runnable oocTask = oocTask(() -> { + long taskStartTime = DMLScript.STATISTICS ? System.nanoTime() : 0; + try(pinned) { + consumer.accept(k, pinned); + + if(localTaskCtr.decrementAndGet() == 0) { + TaskContext.defer(() -> localFuture.complete(null)); } - if (DMLScript.OOC_LOG_EVENTS) - OOCEventLog.onComputeEvent(_callerId, taskStartTime, System.nanoTime()); } - } - }, localFuture, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new))); + finally { + COMPUTE_IN_FLIGHT.decrementAndGet(); + if (DMLScript.STATISTICS) { + _localStatisticsAdder.add(System.nanoTime() - taskStartTime); + if (globalFuture.isDone()) { + Statistics.maintainOOCHeavyHitter(getExtendedOpcode(), _localStatisticsAdder.sum()); + _localStatisticsAdder.reset(); + } + if (DMLScript.OOC_LOG_EVENTS) + OOCEventLog.onComputeEvent(_callerId, taskStartTime, System.nanoTime()); + } + } + }, localFuture, streamContext); + COMPUTE_EXECUTOR.submit(oocTask); + } + catch (Exception e) { + COMPUTE_IN_FLIGHT.decrementAndGet(); + throw e; + } if(closeRaceWatchdog.get()) // Sanity check throw 
new DMLRuntimeException("Race condition observed"); - } finally { + } + catch(Throwable t) { + streamContext.failAll(DMLRuntimeException.of(t)); + throw t; + } + finally { if (DMLScript.STATISTICS) { _localStatisticsAdder.add(System.nanoTime() - startTime); if (globalFuture.isDone()) { @@ -565,12 +886,12 @@ protected CompletableFuture submitOOCTasks(final List> qu } } } - }, null, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new))); + }, null, streamContext)); i++; } - globalFuture.whenComplete((res, e) -> { + return globalFuture.handle((res, e) -> { if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally()) { futures.forEach(f -> { if(!f.isDone()) { @@ -582,34 +903,42 @@ protected CompletableFuture submitOOCTasks(final List> qu }); } - oocFinalizer.run(); + streamContext.clear(); + return null; }); - return globalFuture; } - protected CompletableFuture submitOOCTasks(OOCStream queue, Consumer> consumer, Runnable finalizer) { - return submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), finalizer); + protected CompletableFuture submitOOCTasks(OOCStream queue, Consumer> consumer) { + return submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), null, null); } - protected CompletableFuture submitOOCTasks(OOCStream queue, Consumer> consumer, Runnable finalizer, Function, Boolean> predicate, BiConsumer> onNotProcessed) { - return submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), finalizer, List.of(new CompletableFuture()), (i, tmp) -> predicate.apply(tmp), onNotProcessed); + protected CompletableFuture submitOOCTasks(OOCStream queue, Consumer> consumer, Function, Boolean> predicate, BiConsumer> onNotProcessed) { + return submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), (i, tmp) -> predicate.apply(tmp), onNotProcessed); } - protected CompletableFuture submitOOCTask(Runnable r, OOCStream... 
queues) { + protected CompletableFuture submitOOCTask(Runnable r, StreamContext ctx) { ExecutorService pool = CommonThreadPool.get(); final CompletableFuture future = new CompletableFuture<>(); try { + COMPUTE_IN_FLIGHT.incrementAndGet(); pool.submit(oocTask(() -> { long startTime = DMLScript.STATISTICS || DMLScript.OOC_LOG_EVENTS ? System.nanoTime() : 0; - r.run(); - future.complete(null); - if (DMLScript.STATISTICS) - Statistics.maintainOOCHeavyHitter(getExtendedOpcode(), System.nanoTime() - startTime); - if (DMLScript.OOC_LOG_EVENTS) - OOCEventLog.onComputeEvent(_callerId, startTime, System.nanoTime()); - }, future, queues)); + try { + r.run(); + future.complete(null); + ctx.clear(); + if (DMLScript.STATISTICS) + Statistics.maintainOOCHeavyHitter(getExtendedOpcode(), System.nanoTime() - startTime); + if (DMLScript.OOC_LOG_EVENTS) + OOCEventLog.onComputeEvent(_callerId, startTime, System.nanoTime()); + } + finally { + COMPUTE_IN_FLIGHT.decrementAndGet(); + } + }, future, ctx)); } catch (Exception ex) { + COMPUTE_IN_FLIGHT.decrementAndGet(); throw new DMLRuntimeException(ex); } finally { @@ -619,29 +948,23 @@ protected CompletableFuture submitOOCTask(Runnable r, OOCStream... queu return future; } - private Runnable oocTask(Runnable r, CompletableFuture future, OOCStream... queues) { + private Runnable oocTask(Runnable r, CompletableFuture future, StreamContext ctx) { return () -> { + boolean setContext = TaskContext.getContext() == null; + if(setContext) + TaskContext.setContext(new TaskContext()); long startTime = DMLScript.STATISTICS ? System.nanoTime() : 0; try { r.run(); + if(setContext) { + while(TaskContext.runDeferred()) { + } + } } catch (Exception ex) { - DMLRuntimeException re = ex instanceof DMLRuntimeException ? 
(DMLRuntimeException) ex : new DMLRuntimeException(ex); + DMLRuntimeException re = DMLRuntimeException.of(ex); - synchronized(this) { - if(_failed) // Do avoid infinite cycles - throw re; - - _failed = true; - } - - for(OOCStream q : queues) { - try { - q.propagateFailure(re); - } catch(Throwable ignore) { - // Should not happen, but catch just in case - } - } + ctx.failAll(re); if (future != null) future.completeExceptionally(re); @@ -649,34 +972,23 @@ private Runnable oocTask(Runnable r, CompletableFuture future, OOCStream< // Rethrow to ensure proper future handling throw re; } finally { + if(setContext) + TaskContext.clearContext(); if (DMLScript.STATISTICS) _localStatisticsAdder.add(System.nanoTime() - startTime); } }; } - private Consumer> oocTask(Consumer> c, CompletableFuture future, OOCStream... queues) { + private Consumer> oocTask(Consumer> c, CompletableFuture future, StreamContext ctx) { return callback -> { try { c.accept(callback); } catch (Exception ex) { - DMLRuntimeException re = ex instanceof DMLRuntimeException ? 
(DMLRuntimeException) ex : new DMLRuntimeException(ex); - - synchronized(this) { - if (_failed) // Do avoid infinite cycles - throw re; + DMLRuntimeException re = DMLRuntimeException.of(ex); - _failed = true; - } - - for(OOCStream q : queues) { - try { - q.propagateFailure(re); - } catch(Throwable ignored) { - // Should not happen, but catch just in case - } - } + ctx.failAll(re); if (future != null) future.completeExceptionally(re); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java index 27dd9515acf..7ee12e9f025 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java @@ -61,6 +61,8 @@ interface QueueCallback extends AutoCloseable { void fail(DMLRuntimeException failure); boolean isEos(); + + boolean isFailure(); } class SimpleQueueCallback implements QueueCallback { @@ -94,7 +96,12 @@ public void close() {} @Override public boolean isEos() { - return get() == null; + return _result == null && _failure == null; + } + + @Override + public boolean isFailure() { + return _failure != null; } } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java index af2c0afa660..26fd227e86a 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java @@ -19,10 +19,42 @@ package org.apache.sysds.runtime.instructions.ooc; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage; +import org.apache.sysds.runtime.util.IndexRange; + +import java.util.function.BiFunction; +import java.util.function.Consumer; + public 
interface OOCStreamable { OOCStream getReadStream(); OOCStream getWriteStream(); boolean isProcessed(); + + DataCharacteristics getDataCharacteristics(); + + CacheableData getData(); + + void setData(CacheableData data); + + void messageUpstream(OOCStreamMessage msg); + + void messageDownstream(OOCStreamMessage msg); + + void setUpstreamMessageRelay(Consumer relay); + + void setDownstreamMessageRelay(Consumer relay); + + void addUpstreamMessageRelay(Consumer relay); + + void addDownstreamMessageRelay(Consumer relay); + + void clearUpstreamMessageRelays(); + + void clearDownstreamMessageRelays(); + + void setIXTransform(BiFunction transform); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java index d70fc3ccb94..b1d397d919a 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java @@ -90,6 +90,9 @@ public void processInstruction(ExecutionContext ec) { double pattern = Double.parseDouble(params.get("pattern")); double replacement = Double.parseDouble(params.get("replacement")); + qIn.setDownstreamMessageRelay(qOut::messageDownstream); + qOut.setUpstreamMessageRelay(qIn::messageUpstream); + mapOOC(qIn, qOut, tmp -> new IndexedMatrixValue(tmp.getIndexes(), tmp.getValue().replaceOperations(new MatrixBlock(), pattern, replacement))); ec.getMatrixObject(output).setStreamHandle(qOut); @@ -114,12 +117,12 @@ else if(instOpcode.equalsIgnoreCase(Opcodes.CONTAINS.toString())) { CompletableFuture future = new CompletableFuture<>(); filterOOC(qIn, tmp -> { - boolean contains = ((MatrixBlock)tmp.getValue()).containsValue(((ScalarObject)finalPattern).getDoubleValue()); + boolean contains = 
((MatrixBlock)tmp.getValue()).containsValue(((ScalarObject)finalPattern).getDoubleValue()); - if (contains) - future.complete(true); - }, tmp -> !future.isDone(), // Don't start a separate worker if result already known - () -> future.complete(false)); // Then the pattern was not found + if (contains) + future.complete(true); + }, tmp -> !future.isDone()) // Don't start a separate worker if result already known + .whenComplete((v, err) -> future.complete(false)); // Then the pattern was not found boolean ret; try { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java index bd725e5dd44..6a67c6602b6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java @@ -20,18 +20,25 @@ package org.apache.sysds.runtime.instructions.ooc; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage; +import org.apache.sysds.runtime.util.IndexRange; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; import java.util.function.Consumer; -public class PlaybackStream implements OOCStream, OOCStreamable { +public class PlaybackStream implements OOCStream { private final CachingStream _streamCache; private final AtomicInteger _streamIdx; private final AtomicBoolean _subscriberSet; private QueueCallback _lastDequeue; + private volatile CopyOnWriteArrayList> _downstreamRelays; public PlaybackStream(CachingStream 
streamCache) { this._streamCache = streamCache; @@ -58,7 +65,7 @@ public synchronized IndexedMatrixValue dequeue() { try { if (_lastDequeue != null) _lastDequeue.close(); - _lastDequeue = _streamCache.get(_streamIdx.getAndIncrement()); + _lastDequeue = _streamCache.get(_streamIdx.getAndIncrement()).get(); return _lastDequeue.get(); } catch (InterruptedException | ExecutionException e) { throw new DMLRuntimeException(e); @@ -80,6 +87,42 @@ public boolean isProcessed() { return false; } + @Override + public DataCharacteristics getDataCharacteristics() { + return _streamCache.getDataCharacteristics(); + } + + @Override + public CacheableData getData() { + return _streamCache.getData(); + } + + @Override + public void setData(CacheableData data) { + _streamCache.setData(data); + } + + @Override + public void messageUpstream(OOCStreamMessage msg) { + if(msg.isCancelled()) + return; + _streamCache.messageUpstream(msg); + } + + @Override + public void messageDownstream(OOCStreamMessage msg) { + if(msg.isCancelled()) + return; + CopyOnWriteArrayList> relays = _downstreamRelays; + if (relays != null) { + for (Consumer relay : relays) { + if (msg.isCancelled()) + break; + relay.accept(msg); + } + } + } + @Override public void setSubscriber(Consumer> subscriber) { if (!_subscriberSet.compareAndSet(false, true)) @@ -102,4 +145,51 @@ public boolean hasStreamCache() { public CachingStream getStreamCache() { return _streamCache; } + + @Override + public void setUpstreamMessageRelay(Consumer relay) { + throw new UnsupportedOperationException(); + } + + @Override + public void setDownstreamMessageRelay(Consumer relay) { + addDownstreamMessageRelay(relay); + } + + @Override + public void addUpstreamMessageRelay(Consumer relay) { + throw new UnsupportedOperationException(); + } + + @Override + public void addDownstreamMessageRelay(Consumer relay) { + if (relay == null) + throw new IllegalArgumentException("Cannot set downstream relay to null"); + CopyOnWriteArrayList> relays = 
_downstreamRelays; + if (relays == null) { + synchronized(this) { + if (_downstreamRelays == null) + _downstreamRelays = new CopyOnWriteArrayList<>(); + relays = _downstreamRelays; + } + } + relays.add(0, relay); + _streamCache.addDownstreamMessageRelay(relay); + } + + @Override + public void clearUpstreamMessageRelays() { + // No upstream relays supported + } + + @Override + public void clearDownstreamMessageRelays() { + _downstreamRelays = null; + _streamCache.clearDownstreamMessageRelays(); + } + + @Override + public void setIXTransform(BiFunction transform) { + throw new UnsupportedOperationException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java index f744b97506b..4270836b755 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java @@ -31,7 +31,7 @@ import org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; import org.apache.sysds.runtime.ooc.cache.OOCIOHandler; -import org.apache.sysds.runtime.ooc.stream.OOCSourceStream; +import org.apache.sysds.runtime.ooc.stream.SourceOOCStream; public class ReblockOOCInstruction extends ComputationOOCInstruction { private int blen; @@ -69,7 +69,7 @@ public void processInstruction(ExecutionContext ec) { //TODO support other formats than binary //create queue, spawn thread for asynchronous reading, and return - OOCStream q = new OOCSourceStream(); + OOCStream q = new SourceOOCStream(); OOCIOHandler io = OOCCacheManager.getIOHandler(); OOCIOHandler.SourceReadRequest req = new OOCIOHandler.SourceReadRequest( min.getFileName(), Types.FileFormat.BINARY, mc.getRows(), mc.getCols(), blen, mc.getNonZeros(), diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java index a87a3498329..e861f7afc57 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java @@ -34,6 +34,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.matrix.operators.ReorgOperator; import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.IndexRange; public class ReorgOOCInstruction extends ComputationOOCInstruction { // sort-specific attributes (to enable variable attributes) @@ -110,6 +111,12 @@ public void processInstruction( ExecutionContext ec ) { OOCStream qIn = min.getStreamHandle(); OOCStream qOut = createWritableStream(); ec.getMatrixObject(output).setStreamHandle(qOut); + + qIn.setDownstreamMessageRelay(qOut::messageDownstream); + qOut.setUpstreamMessageRelay(qIn::messageUpstream); + qOut.setIXTransform((downstream, range) -> + new IndexRange(range.colStart, range.colEnd, range.rowStart, range.rowEnd)); + // Transpose operation mapOOC(qIn, qOut, tmp -> { MatrixBlock inBlock = (MatrixBlock) tmp.getValue(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java index 605a78178fa..058a61c208c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java @@ -20,18 +20,30 @@ package org.apache.sysds.runtime.instructions.ooc; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.meta.DataCharacteristics; +import 
org.apache.sysds.runtime.ooc.stream.message.OOCGetStreamTypeMessage; +import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage; +import org.apache.sysds.runtime.util.IndexRange; import java.util.LinkedList; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; import java.util.function.Consumer; public class SubscribableTaskQueue extends LocalTaskQueue implements OOCStream { private final AtomicInteger _availableCtr = new AtomicInteger(1); private final AtomicBoolean _closed = new AtomicBoolean(false); + private final AtomicInteger _blockCount = new AtomicInteger(0); + private CacheableData _cdata; private volatile Consumer> _subscriber = null; + private volatile CopyOnWriteArrayList> _upstreamMsgRelays = null; + private volatile CopyOnWriteArrayList> _downstreamMsgRelays = null; + private volatile BiFunction _ixTransform = null; private String _watchdogId; public SubscribableTaskQueue() { @@ -68,10 +80,13 @@ public void enqueue(T t) { throw new DMLRuntimeException("Cannot enqueue into closed SubscribableTaskQueue"); } + _blockCount.incrementAndGet(); + Consumer> s = _subscriber; + final Consumer> fS = s; - if (s != null) { - s.accept(new SimpleQueueCallback<>(t, _failure)); + if (fS != null) { + fS.accept(new SimpleQueueCallback<>(t, _failure)); onDeliveryFinished(); return; } @@ -126,11 +141,24 @@ public synchronized void closeInput() { if (_closed.compareAndSet(false, true)) { super.closeInput(); onDeliveryFinished(); + _upstreamMsgRelays = null; + _downstreamMsgRelays = null; } else { throw new IllegalStateException("Multiple close input calls"); } } + private void validateBlockCountOnClose() { + DataCharacteristics dc = getDataCharacteristics(); + if (dc != null && dc.dimsKnown() && dc.getBlocksize() > 0) { + long expected = dc.getNumBlocks(); + if (expected >= 0 && _blockCount.get() != expected) { + throw new 
DMLRuntimeException("OOCStream block count mismatch: expected " + + expected + " but saw " + _blockCount.get() + " (" + dc.getRows() + "x" + dc.getCols() + ")"); + } + } + } + @Override public void setSubscriber(Consumer> subscriber) { if(subscriber == null) @@ -159,6 +187,7 @@ private void onDeliveryFinished() { int ctr = _availableCtr.decrementAndGet(); if (ctr == 0) { + validateBlockCountOnClose(); Consumer> s = _subscriber; if (s != null) s.accept(new SimpleQueueCallback<>((T) LocalTaskQueue.NO_MORE_TASKS, _failure)); @@ -186,6 +215,43 @@ public OOCStream getWriteStream() { return this; } + @Override + public void messageUpstream(OOCStreamMessage msg) { + if(msg.isCancelled()) + return; + msg.addIXTransform(_ixTransform); + if (msg.isCancelled()) + return; + if (msg instanceof OOCGetStreamTypeMessage) { + if (_cdata != null) + ((OOCGetStreamTypeMessage) msg).setInMemoryType(); + return; + } + CopyOnWriteArrayList> relays = _upstreamMsgRelays; + if(relays != null) { + for (Consumer relay : relays) { + if (msg.isCancelled()) + break; + relay.accept(msg); + } + } + } + + @Override + public void messageDownstream(OOCStreamMessage msg) { + if(!msg.isCancelled()) + return; + msg.addIXTransform(_ixTransform); + CopyOnWriteArrayList> relays = _downstreamMsgRelays; + if(relays != null) { + for (Consumer relay : relays) { + if (msg.isCancelled()) + break; + relay.accept(msg); + } + } + } + @Override public boolean hasStreamCache() { return false; @@ -195,4 +261,81 @@ public boolean hasStreamCache() { public CachingStream getStreamCache() { return null; } + + @Override + public DataCharacteristics getDataCharacteristics() { + return _cdata == null ? 
null : _cdata.getDataCharacteristics(); + } + + @Override + public CacheableData getData() { + return _cdata; + } + + @Override + public void setData(CacheableData data) { + if(_cdata == null && _closed.get()) + System.out.println("[WARN] Data type was defined after closing, which may bypass validation checks"); + _cdata = data; + } + + @Override + public void setUpstreamMessageRelay(Consumer relay) { + addUpstreamMessageRelay(relay); + } + + @Override + public void setDownstreamMessageRelay(Consumer relay) { + addDownstreamMessageRelay(relay); + } + + @Override + public void addUpstreamMessageRelay(Consumer relay) { + if(relay == null) + throw new IllegalArgumentException("Cannot set upstream relay to null"); + CopyOnWriteArrayList> relays = _upstreamMsgRelays; + if(relays == null) { + synchronized(this) { + if(_upstreamMsgRelays == null) + _upstreamMsgRelays = new CopyOnWriteArrayList<>(); + relays = _upstreamMsgRelays; + } + } + relays.add(0, relay); + } + + @Override + public void addDownstreamMessageRelay(Consumer relay) { + if(relay == null) + throw new IllegalArgumentException("Cannot set downstream relay to null"); + CopyOnWriteArrayList> relays = _downstreamMsgRelays; + if(relays == null) { + synchronized(this) { + if(_downstreamMsgRelays == null) + _downstreamMsgRelays = new CopyOnWriteArrayList<>(); + relays = _downstreamMsgRelays; + } + } + relays.add(0, relay); + } + + @Override + public void clearUpstreamMessageRelays() { + _upstreamMsgRelays = null; + } + + @Override + public void clearDownstreamMessageRelays() { + _downstreamMsgRelays = null; + } + + @Override + public void setIXTransform(BiFunction transform) { + _ixTransform = transform; + } + + @Override + public String toString() { + return "STQ-" + hashCode(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TernaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TernaryOOCInstruction.java index da5c37c50ef..6dfedfc1ff2 100644 --- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TernaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TernaryOOCInstruction.java @@ -33,7 +33,6 @@ import org.apache.sysds.runtime.instructions.cp.StringObject; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.matrix.operators.TernaryOperator; @@ -110,6 +109,8 @@ private void processSingleMatrixInstruction(ExecutionContext ec, int matrixPos) OOCStream qIn = mo.getStreamHandle(); OOCStream qOut = createWritableStream(); ec.getMatrixObject(output).setStreamHandle(qOut); + qIn.setDownstreamMessageRelay(qOut::messageDownstream); + qOut.setUpstreamMessageRelay(qIn::messageUpstream); mapOOC(qIn, qOut, tmp -> { IndexedMatrixValue outVal = new IndexedMatrixValue(); @@ -125,6 +126,8 @@ private void processSingleMatrixInstruction(ExecutionContext ec, int matrixPos) private void processTwoMatrixInstruction(ExecutionContext ec, int leftPos, int rightPos) { MatrixObject left = getMatrixObject(ec, leftPos); MatrixObject right = getMatrixObject(ec, rightPos); + OOCStream leftStream = left.getStreamHandle(); + OOCStream rightStream = right.getStreamHandle(); MatrixBlock s1 = input1.isMatrix() ? null : getScalarInputBlock(ec, input1); MatrixBlock s2 = input2.isMatrix() ? 
null : getScalarInputBlock(ec, input2); @@ -132,8 +135,14 @@ private void processTwoMatrixInstruction(ExecutionContext ec, int leftPos, int r OOCStream qOut = createWritableStream(); ec.getMatrixObject(output).setStreamHandle(qOut); + qOut.setUpstreamMessageRelay(msg -> { + leftStream.messageUpstream(msg.split()); + rightStream.messageUpstream(msg.split()); + }); + leftStream.setDownstreamMessageRelay(qOut::messageDownstream); + rightStream.setDownstreamMessageRelay(qOut::messageDownstream); - joinOOC(left.getStreamHandle(), right.getStreamHandle(), qOut, (l, r) -> { + joinOOC(leftStream, rightStream, qOut, (l, r) -> { IndexedMatrixValue outVal = new IndexedMatrixValue(); MatrixBlock op1 = resolveOperandBlock(1, l, r, leftPos, rightPos, s1, s2, s3); MatrixBlock op2 = resolveOperandBlock(2, l, r, leftPos, rightPos, s1, s2, s3); @@ -155,8 +164,8 @@ private void processThreeMatrixInstruction(ExecutionContext ec) { List> streams = List.of( m1.getStreamHandle(), m2.getStreamHandle(), m3.getStreamHandle()); - List> keyFns = - List.of(IndexedMatrixValue::getIndexes, IndexedMatrixValue::getIndexes, IndexedMatrixValue::getIndexes); + streams.forEach(s -> s.setDownstreamMessageRelay(qOut::messageDownstream)); + qOut.setUpstreamMessageRelay(msg -> streams.forEach(s -> s.messageUpstream(msg))); joinOOC(streams, qOut, blocks -> { IndexedMatrixValue b1 = blocks.get(0); @@ -166,7 +175,7 @@ private void processThreeMatrixInstruction(ExecutionContext ec) { outVal.set(b1.getIndexes(), ((MatrixBlock)b1.getValue()).ternaryOperations((TernaryOperator)_optr, (MatrixBlock)b2.getValue(), (MatrixBlock)b3.getValue(), new MatrixBlock())); return outVal; - }, keyFns); + }, IndexedMatrixValue::getIndexes); } private MatrixObject getMatrixObject(ExecutionContext ec, int pos) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java index d45d00db592..b9c7612bfe9 100644 --- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java @@ -60,6 +60,8 @@ public void processInstruction( ExecutionContext ec ) { OOCStream qIn = min.getStreamHandle(); OOCStream qOut = createWritableStream(); ec.getMatrixObject(output).setStreamHandle(qOut); + qIn.setDownstreamMessageRelay(qOut::messageDownstream); + qOut.setUpstreamMessageRelay(qIn::messageUpstream); mapOOC(qIn, qOut, tmp -> { IndexedMatrixValue tmpOut = new IndexedMatrixValue(); diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java index eea76c808a2..b5da05a598d 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java @@ -27,6 +27,7 @@ public final class BlockEntry { private volatile int _pinCount; private volatile BlockState _state; private Object _data; + private int _retainHintCount; BlockEntry(BlockKey key, long size, Object data) { this._key = key; @@ -34,6 +35,7 @@ public final class BlockEntry { this._pinCount = 0; this._state = BlockState.HOT; this._data = data; + this._retainHintCount = 0; } public BlockKey getKey() { @@ -70,6 +72,30 @@ synchronized void setState(BlockState state) { _state = state; } + public synchronized void addRetainHint(int cnt) { + _retainHintCount += cnt; + } + + public synchronized void addRetainHint() { + _retainHintCount++; + } + + public synchronized void removeRetainHint(int cnt) { + _retainHintCount -= cnt; + if(_retainHintCount < 0) + _retainHintCount = 0; + } + + public synchronized void removeRetainHint() { + if (_retainHintCount <= 0) + return; + _retainHintCount--; + } + + public synchronized int getRetainHintCount() { + return _retainHintCount; + } + /** * Tries to clear the underlying data if it is not pinned * @return the number of cleared bytes (or 0 if 
could not clear or data was already cleared) @@ -80,6 +106,7 @@ synchronized long clear() { if (_data instanceof IndexedMatrixValue) ((IndexedMatrixValue)_data).setValue(null); // Explicitly clear _data = null; + _retainHintCount = 0; return _size; } @@ -104,4 +131,8 @@ synchronized boolean unpin() { _pinCount--; return _pinCount == 0; } + + public String toString() { + return "Entry" + _key.toString(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/CloseableQueue.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/CloseableQueue.java index b8c312d2a3d..4f1c5799736 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/CloseableQueue.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/CloseableQueue.java @@ -45,6 +45,7 @@ public boolean enqueueIfOpen(T task) throws InterruptedException { return true; } + @SuppressWarnings("unchecked") public T take() throws InterruptedException { if (closed && queue.isEmpty()) return null; diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/DeferredReadQueue.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/DeferredReadQueue.java new file mode 100644 index 00000000000..b5564430fe2 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/DeferredReadQueue.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.cache; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.IdentityHashMap; +import java.util.Map; +import java.util.Set; + +class DeferredReadQueue { + private final ArrayList _heap; + private final Map _index; + private final Map> _byKey; + private long _seqCounter; + + DeferredReadQueue() { + this._heap = new ArrayList<>(); + this._index = new IdentityHashMap<>(); + this._byKey = new HashMap<>(); + this._seqCounter = 0; + } + + boolean isEmpty() { + return _heap.isEmpty(); + } + + int size() { + return _heap.size(); + } + + DeferredReadRequest peek() { + return _heap.isEmpty() ? 
null : _heap.get(0); + } + + DeferredReadRequest poll() { + if(_heap.isEmpty()) + return null; + DeferredReadRequest req = _heap.get(0); + removeAt(0); + removeFromIndex(req); + return req; + } + + void add(DeferredReadRequest req) { + req.setSequence(_seqCounter++); + _heap.add(req); + _index.put(req, _heap.size() - 1); + addToIndex(req); + heapifyUp(_heap.size() - 1); + } + + void remove(DeferredReadRequest req) { + Integer idx = _index.get(req); + if(idx == null) + return; + removeAt(idx); + removeFromIndex(req); + } + + void clear() { + _heap.clear(); + _index.clear(); + _byKey.clear(); + _seqCounter = 0; + } + + boolean boost(BlockKey key, double priority) { + if(priority == 0) + return false; + Set requests = _byKey.get(key); + if(requests == null || requests.isEmpty()) + return false; + for(DeferredReadRequest req : requests) { + double delta = priority / req.getEntries().size(); + req.addPriorityScore(delta); + updatePosition(req); + } + return true; + } + + private void updatePosition(DeferredReadRequest req) { + Integer idx = _index.get(req); + if(idx == null) + return; + int parent = (idx - 1) / 2; + if(idx > 0 && compare(req, _heap.get(parent)) > 0) + heapifyUp(idx); + else + heapifyDown(idx); + } + + private void addToIndex(DeferredReadRequest req) { + for(BlockEntry entry : req.getEntries()) { + BlockKey key = entry.getKey(); + Set set = _byKey.get(key); + if(set == null) { + set = Collections.newSetFromMap(new IdentityHashMap<>()); + _byKey.put(key, set); + } + set.add(req); + } + } + + private void removeFromIndex(DeferredReadRequest req) { + for(BlockEntry entry : req.getEntries()) { + BlockKey key = entry.getKey(); + Set set = _byKey.get(key); + if(set == null) + continue; + set.remove(req); + if(set.isEmpty()) + _byKey.remove(key); + } + } + + private DeferredReadRequest removeAt(int idx) { + int lastIdx = _heap.size() - 1; + DeferredReadRequest removed = _heap.get(idx); + DeferredReadRequest last = _heap.get(lastIdx); + _heap.set(idx, last); + 
_heap.remove(lastIdx); + _index.remove(removed); + if(idx < _heap.size()) { + _index.put(last, idx); + updatePosition(last); + } + return removed; + } + + private void heapifyUp(int idx) { + int i = idx; + while(i > 0) { + int parent = (i - 1) / 2; + if(compare(_heap.get(i), _heap.get(parent)) <= 0) + break; + swap(i, parent); + i = parent; + } + } + + private void heapifyDown(int idx) { + int i = idx; + int size = _heap.size(); + while(true) { + int left = i * 2 + 1; + if(left >= size) + break; + int right = left + 1; + int best = left; + if(right < size && compare(_heap.get(right), _heap.get(left)) > 0) + best = right; + if(compare(_heap.get(best), _heap.get(i)) <= 0) + break; + swap(i, best); + i = best; + } + } + + private void swap(int i, int j) { + DeferredReadRequest tmp = _heap.get(i); + _heap.set(i, _heap.get(j)); + _heap.set(j, tmp); + _index.put(_heap.get(i), i); + _index.put(_heap.get(j), j); + } + + private int compare(DeferredReadRequest a, DeferredReadRequest b) { + int byPriority = Double.compare(a.getPriorityScore(), b.getPriorityScore()); + if(byPriority != 0) + return byPriority; + return Long.compare(b.getSequence(), a.getSequence()); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/DeferredReadRequest.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/DeferredReadRequest.java new file mode 100644 index 00000000000..0ca6cbd2eab --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/DeferredReadRequest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.cache; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; + +class DeferredReadRequest { + static final short NOT_SCHEDULED = 0; + static final short SCHEDULED = 1; + static final short PINNED = 2; + + private final CompletableFuture> _future; + private final List _entries; + private final short[] _pinned; + private final boolean[] _retainHinted; + private final AtomicInteger _availableCount; + private double _priorityScore; + private long _sequence; + + DeferredReadRequest(CompletableFuture> future, List entries) { + this._future = future; + this._entries = entries; + this._pinned = new short[entries.size()]; + this._retainHinted = new boolean[entries.size()]; + this._availableCount = new AtomicInteger(0); + this._priorityScore = 0; + this._sequence = 0; + } + + CompletableFuture> getFuture() { + return _future; + } + + List getEntries() { + return _entries; + } + + synchronized void setPriorityScore(double score) { + _priorityScore = score; + } + + synchronized void addPriorityScore(double delta) { + _priorityScore += delta; + } + + synchronized double getPriorityScore() { + return _priorityScore; + } + + synchronized void setSequence(long sequence) { + _sequence = sequence; + } + + synchronized long getSequence() { + return _sequence; + } + + synchronized boolean actionRequired(int idx) { + return _pinned[idx] == NOT_SCHEDULED; + } + + synchronized boolean setPinned(int idx) { + if(_pinned[idx] == PINNED) + return false; // 
already pinned + _pinned[idx] = PINNED; + return _availableCount.incrementAndGet() == _entries.size(); + } + + synchronized void schedule(int idx) { + _pinned[idx] = SCHEDULED; + } + + synchronized void markRetainHinted(int idx) { + _retainHinted[idx] = true; + } + + synchronized boolean isRetainHinted(int idx) { + return _retainHinted[idx]; + } + + boolean isComplete() { + return _availableCount.get() == _entries.size(); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java index bbf4cfb314c..ef6824022cc 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java @@ -21,6 +21,7 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; import org.apache.sysds.runtime.instructions.ooc.OOCStream; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -147,6 +148,10 @@ public static OOCStream.QueueCallback putAndPinSourceBacked( descriptor), null); } + public static void prioritize(BlockKey key, int priority) { + getCache().prioritize(key, priority); + } + public static CompletableFuture> requestBlock(long streamId, long blockId) { BlockKey key = new BlockKey(streamId, blockId); return getCache().request(key).thenApply(e -> new CachedQueueCallback<>(e, null)); @@ -157,6 +162,23 @@ public static CompletableFuture l -> l.stream().map(e -> (OOCStream.QueueCallback)new CachedQueueCallback(e, null)).toList()); } + public static List> tryRequestManyBlocks(List keys) { + List entries = getCache().tryRequest(keys); + if(entries == null) + return null; + return entries.stream().map(e -> (OOCStream.QueueCallback)new CachedQueueCallback(e, null)).toList(); + } + + public static 
CompletableFuture>> requestAnyOf(List keys, int n, List sel) { + return getCache().requestAnyOf(keys, n, sel) + .thenApply( + l -> l.stream().map(e -> (OOCStream.QueueCallback)new CachedQueueCallback(e, null)).toList()); + } + + public static boolean canClaimMemory() { + return getCache().isWithinSoftLimits() && OOCInstruction.getComputeInFlight() <= OOCInstruction.getComputeBackpressureThreshold(); + } + private static void pin(BlockEntry entry) { getCache().pin(entry); } @@ -170,11 +192,15 @@ private static void unpin(BlockEntry entry) { static class CachedQueueCallback implements OOCStream.QueueCallback { private final BlockEntry _result; - private DMLRuntimeException _failure; private final AtomicBoolean _pinned; + private T _data; + private DMLRuntimeException _failure; + private CompletableFuture _future; + @SuppressWarnings("unchecked") CachedQueueCallback(BlockEntry result, DMLRuntimeException failure) { this._result = result; + this._data = (T)result.getData(); this._failure = failure; this._pinned = new AtomicBoolean(true); } @@ -182,19 +208,16 @@ static class CachedQueueCallback implements OOCStream.QueueCallback { @SuppressWarnings("unchecked") @Override public T get() { - if (_failure != null) + if(_failure != null) throw _failure; - if (!_pinned.get()) + if(!_pinned.get()) throw new IllegalStateException("Cannot get cached item of a closed callback"); - T ret = (T)_result.getData(); - if (ret == null) - throw new IllegalStateException("Cannot get a cached item if it is not pinned in memory: " + _result.getState()); - return ret; + return _data; } @Override public OOCStream.QueueCallback keepOpen() { - if (!_pinned.get()) + if(!_pinned.get()) throw new IllegalStateException("Cannot keep open an already closed callback"); pin(_result); return new CachedQueueCallback<>(_result, _failure); @@ -210,10 +233,18 @@ public boolean isEos() { return get() == null; } + @Override + public boolean isFailure() { + return _failure != null; + } + @Override public void 
close() { - if (_pinned.compareAndSet(true, false)) { + if(_pinned.compareAndSet(true, false)) { + _data = null; unpin(_result); + if(_future != null) + _future.complete(null); } } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java index cd04f9879aa..bafda48f4d4 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java @@ -38,6 +38,30 @@ public interface OOCCacheScheduler { */ CompletableFuture> request(List keys); + /** + * Tries to request a list of blocks from the cache that must be available at the same time. + * Immediately returns the list of entries if present, otherwise null without scheduling reads. + * @param keys the requested keys associated to the block + * @return the list of available BlockEntries + */ + List tryRequest(List keys); + + /** + * Requests any n entries of the list of blocks, preferring an available item. + */ + CompletableFuture> requestAnyOf(List keys, int n, List selectionOut); + + /** + * Requests any n entries of the list of blocks, preferring an available item. + */ + List tryRequestAnyOf(List keys, int n, List selectionOut); + + /** + * Adds the given priority to any pending request accessing the key. + * Multi-requests are prioritized partially. + */ + void prioritize(BlockKey key, double priority); + /** * Places a new block in the cache. Note that objects are immutable and cannot be overwritten. * The object data should now only be accessed via cache, as ownership has been transferred. @@ -45,7 +69,7 @@ public interface OOCCacheScheduler { * @param data the block data * @param size the size of the data */ - void put(BlockKey key, Object data, long size); + BlockKey put(BlockKey key, Object data, long size); /** * Places a new block in the cache and returns a pinned handle. 
@@ -96,8 +120,28 @@ BlockEntry putAndPinSourceBacked(BlockKey key, Object data, long size, */ void unpin(BlockEntry entry); + /** + * Returns the current cache size in bytes. + */ + long getCacheSize(); + + /** + * Returns if the current cache size is within its defined memory limits. + */ + boolean isWithinLimits(); + + /** + * Returns if the current cache size is within its soft memory limits. + */ + boolean isWithinSoftLimits(); + /** * Shuts down the cache scheduler. */ void shutdown(); + + /** + * Updates the cache limits. + */ + void updateLimits(long evictionLimit, long hardLimit); } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java index b4d14646e0e..0699597c8b7 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java @@ -29,6 +29,11 @@ public interface OOCIOHandler { CompletableFuture scheduleRead(BlockEntry block); + /** + * Increase priority for a pending scheduled read if it has not started yet. 
+ */ + void prioritizeRead(BlockKey key, double priority); + CompletableFuture scheduleDeletion(BlockEntry block); /** diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java index 0f30914770a..00da1681813 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java @@ -19,13 +19,15 @@ package org.apache.sysds.runtime.ooc.cache; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; import org.apache.sysds.runtime.ooc.stats.OOCEventLog; import org.apache.sysds.utils.Statistics; import scala.Tuple2; -import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.ArrayDeque; import java.util.Collection; import java.util.Collections; import java.util.Deque; @@ -33,18 +35,18 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicInteger; public class OOCLRUCacheScheduler implements OOCCacheScheduler { private static final boolean SANITY_CHECKS = false; - + private static final Log LOG = LogFactory.getLog(OOCLRUCacheScheduler.class.getName()); private final OOCIOHandler _ioHandler; private final LinkedHashMap _cache; private final HashMap _evictionCache; - private final Deque _deferredReadRequests; + private final DeferredReadQueue _deferredReadRequests; private final Deque _processingReadRequests; - private final long _hardLimit; - private final long _evictionLimit; + private final HashMap _blockReads; + private long _hardLimit; + private long _evictionLimit; private final int _callerId; private long _cacheSize; private long _bytesUpForEviction; @@ -55,8 +57,9 @@ public OOCLRUCacheScheduler(OOCIOHandler ioHandler, long evictionLimit, long har this._ioHandler = ioHandler; this._cache = 
new LinkedHashMap<>(1024, 0.75f, true); this._evictionCache = new HashMap<>(); - this._deferredReadRequests = new ArrayDeque<>(); + this._deferredReadRequests = new DeferredReadQueue(); this._processingReadRequests = new ArrayDeque<>(); + this._blockReads = new HashMap<>(); this._hardLimit = hardLimit; this._evictionLimit = evictionLimit; this._cacheSize = 0; @@ -101,24 +104,58 @@ public CompletableFuture request(BlockKey key) { return CompletableFuture.completedFuture(entry); } - //System.out.println("Requesting deferred: " + key); // Schedule deferred read otherwise - final CompletableFuture future = new CompletableFuture<>(); final CompletableFuture> requestFuture = new CompletableFuture<>(); - requestFuture.whenComplete((r, t) -> future.complete(r.get(0))); + CompletableFuture future = requestFuture.thenApply(l -> l.get(0)); scheduleDeferredRead(new DeferredReadRequest(requestFuture, Collections.singletonList(entry))); return future; } + @Override + public List tryRequest(List keys) { + CompletableFuture> f = request(keys, true); + if(f == null) + return null; + return f.getNow(null); + } + + @Override + public CompletableFuture> requestAnyOf(List keys, int n, List selectionOut) { + List l = tryRequestAnyOf(keys, n, selectionOut); + if(l != null) + return CompletableFuture.completedFuture(l); + return request(keys.subList(0, n)); + } + + @Override + public List tryRequestAnyOf(List keys, int n, List selectionOut) { + List present = new ArrayList<>(n); + for(BlockKey key : keys) { + List l = tryRequest(List.of(key)); + if(l != null) { + present.add(l.get(0)); + selectionOut.add(l.get(0).getKey()); + if(l.size() == n) + return l; + } + } + present.forEach(this::unpin); + return null; + } + @Override public CompletableFuture> request(List keys) { + return request(keys, false); + } + + public CompletableFuture> request(List keys, boolean onlyIfAvailable) { if (!this._running) throw new IllegalStateException("Cache scheduler has been shut down."); 
Statistics.incrementOOCEvictionGet(keys.size()); List entries = new ArrayList<>(keys.size()); - boolean couldPinAll = true; + boolean allAvailable = true; synchronized(this) { for (BlockKey key : keys) { @@ -128,48 +165,91 @@ public CompletableFuture> request(List keys) { if (entry == null) throw new IllegalArgumentException("Could not find requested block with key " + key); - if (couldPinAll) { - synchronized(entry) { - if(entry.getState().isAvailable()) { - if(entry.pin() == 0) - throw new IllegalStateException(); - } - else { - couldPinAll = false; - } - } + synchronized(entry) { + if(!entry.getState().isAvailable()) + allAvailable = false; + } + entries.add(entry); + } - if (!couldPinAll) { - // Undo pin for all previous entries - for (BlockEntry e : entries) - e.unpin(); // Do not unpin using unpin(...) method to avoid explicit eviction on memory pressure + if(allAvailable) { + for(BlockEntry entry : entries) { + synchronized(entry) { + if(entry.pin() == 0) + throw new IllegalStateException(); } } - entries.add(entry); } } - if (couldPinAll) { + if (allAvailable) { // Then we could pin all entries return CompletableFuture.completedFuture(entries); } + if(onlyIfAvailable) + return null; + // Schedule deferred read otherwise final CompletableFuture> future = new CompletableFuture<>(); - scheduleDeferredRead(new DeferredReadRequest(future, entries)); + DeferredReadRequest request = new DeferredReadRequest(future, entries); + for (int i = 0; i < entries.size(); i++) { + BlockEntry entry = entries.get(i); + synchronized(entry) { + if (entry.getState().isAvailable()) { + entry.addRetainHint(); + request.markRetainHinted(i); + } + } + } + scheduleDeferredRead(request); return future; } + @Override + public void prioritize(BlockKey key, double priority) { + if (!this._running) + return; + if (priority == 0) + return; + + synchronized(this) { + boolean matched = _deferredReadRequests.boost(key, priority); + if(matched) { + BlockReadState state = 
_blockReads.computeIfAbsent(key, k -> new BlockReadState()); + state.priority += priority; + } + } + _ioHandler.prioritizeRead(key, priority); + } + private void scheduleDeferredRead(DeferredReadRequest deferredReadRequest) { synchronized(this) { + double score = 0; + int readyCount = 0; + for (BlockEntry entry : deferredReadRequest.getEntries()) { + synchronized(entry) { + if (entry.getState().isAvailable()) + readyCount++; + } + BlockReadState state = _blockReads.get(entry.getKey()); + if (state != null) + score += state.priority; + } + if (!deferredReadRequest.getEntries().isEmpty()) + score /= deferredReadRequest.getEntries().size(); + if (!deferredReadRequest.getEntries().isEmpty()) + score += ((double) readyCount) / deferredReadRequest.getEntries().size(); + deferredReadRequest.setPriorityScore(score); + _deferredReadRequests.add(deferredReadRequest); } onCacheSizeChanged(false); // To schedule deferred reads if possible } @Override - public void put(BlockKey key, Object data, long size) { - put(key, data, size, false, null); + public BlockKey put(BlockKey key, Object data, long size) { + return put(key, data, size, false, null).getKey(); } @Override @@ -229,6 +309,7 @@ public void forget(BlockKey key) { shouldScheduleDeletion = entry.getState().isBackedByDisk() || entry.getState() == BlockState.EVICTING; cacheSizeDelta = transitionMemState(entry, BlockState.REMOVED); + entry.setDataUnsafe(null); } } @@ -259,6 +340,7 @@ public void unpin(BlockEntry entry) { if (couldFree) { long cacheSizeDelta = 0; + boolean shouldCheckEviction = false; synchronized(this) { if (_cacheSize <= _evictionLimit) return; // Nothing to do @@ -268,22 +350,49 @@ public void unpin(BlockEntry entry) { return; // Pin state changed so we cannot evict if (entry.getState().isAvailable() && entry.getState().isBackedByDisk()) { - cacheSizeDelta = transitionMemState(entry, BlockState.COLD); - long cleared = entry.clear(); - if (cleared != entry.getSize()) - throw new IllegalStateException(); - 
_cache.remove(entry.getKey()); - _evictionCache.put(entry.getKey(), entry); + if (entry.getRetainHintCount() > 0) { + shouldCheckEviction = true; + } + else { + cacheSizeDelta = transitionMemState(entry, BlockState.COLD); + long cleared = entry.clear(); + if (cleared != entry.getSize()) + throw new IllegalStateException(); + _cache.remove(entry.getKey()); + _evictionCache.put(entry.getKey(), entry); + } } else if (entry.getState() == BlockState.HOT) { - cacheSizeDelta = onUnpinnedHotBlockUnderMemoryPressure(entry); + if (entry.getRetainHintCount() > 0) { + shouldCheckEviction = true; + } + else { + cacheSizeDelta = onUnpinnedHotBlockUnderMemoryPressure(entry); + } } } } if (cacheSizeDelta != 0) onCacheSizeChanged(cacheSizeDelta > 0); + else if (shouldCheckEviction) + onCacheSizeChanged(true); } } + @Override + public synchronized long getCacheSize() { + return _cacheSize; + } + + @Override + public boolean isWithinLimits() { + return _cacheSize < _hardLimit; + } + + @Override + public boolean isWithinSoftLimits() { + return _cacheSize < _evictionLimit; + } + @Override public synchronized void shutdown() { this._running = false; @@ -291,10 +400,17 @@ public synchronized void shutdown() { _evictionCache.clear(); _processingReadRequests.clear(); _deferredReadRequests.clear(); + _blockReads.clear(); _cacheSize = 0; _bytesUpForEviction = 0; } + @Override + public synchronized void updateLimits(long evictionLimit, long hardLimit) { + _evictionLimit = evictionLimit; + _hardLimit = hardLimit; + } + /** * Must be called while this cache and the corresponding entry are locked */ @@ -379,21 +495,33 @@ private void onCacheSizeIncremented() { List toRemove = new ArrayList<>(); upForEviction = new ArrayList<>(); - for(BlockEntry entry : entries) { - if(_cacheSize - _bytesUpForEviction <= _evictionLimit) - break; + for(int pass = 0; pass < 2; pass++) { + boolean allowRetainHint = pass == 1; + for(BlockEntry entry : entries) { + if(_cacheSize - _bytesUpForEviction <= 
_evictionLimit) + break; - synchronized(entry) { - if(!entry.isPinned() && entry.getState().isBackedByDisk()) { - cacheSizeDelta += transitionMemState(entry, BlockState.COLD); - entry.clear(); - toRemove.add(entry); - } - else if(entry.getState() != BlockState.EVICTING && !entry.getState().isBackedByDisk()) { - cacheSizeDelta += transitionMemState(entry, BlockState.EVICTING); - upForEviction.add(entry); + synchronized(entry) { + if(entry.isPinned()) + continue; + if(!allowRetainHint && entry.getRetainHintCount() > 0) + continue; + if(entry.getState() == BlockState.COLD || entry.getState() == BlockState.EVICTING) + continue; + + if(entry.getState().isBackedByDisk()) { + cacheSizeDelta += transitionMemState(entry, BlockState.COLD); + entry.clear(); + toRemove.add(entry); + } + else { + cacheSizeDelta += transitionMemState(entry, BlockState.EVICTING); + upForEviction.add(entry); + } } } + if(_cacheSize - _bytesUpForEviction <= _evictionLimit) + break; } for(BlockEntry entry : toRemove) { @@ -414,6 +542,7 @@ else if(entry.getState() != BlockState.EVICTING && !entry.getState().isBackedByD private boolean onCacheSizeDecremented() { boolean allReserved = true; + boolean reading = false; List> toRead; DeferredReadRequest req; synchronized(this) { @@ -435,11 +564,18 @@ private boolean onCacheSizeDecremented() { throw new IllegalStateException(); req.setPinned(idx); } + else if (entry.getState() == BlockState.READING) { + req.schedule(idx); + registerWaiter(entry.getKey(), req, idx); + reading = true; + } else { if(_cacheSize + entry.getSize() <= _hardLimit) { transitionMemState(entry, BlockState.READING); toRead.add(new Tuple2<>(idx, entry)); req.schedule(idx); + registerWaiter(entry.getKey(), req, idx); + reading = true; } else { allReserved = false; @@ -448,7 +584,7 @@ private boolean onCacheSizeDecremented() { } } - if (allReserved) { + if(allReserved) { _deferredReadRequests.poll(); if (!toRead.isEmpty()) _processingReadRequests.add(req); @@ -457,36 +593,71 @@ private 
boolean onCacheSizeDecremented() { sanityCheck(); } - if (allReserved && toRead.isEmpty()) { + if(allReserved && !reading) { + clearRetainHints(req); + req.getFuture().complete(req.getEntries()); + return true; + } + else if(allReserved && reading && req.isComplete()) { + clearRetainHints(req); + synchronized(this) { + _processingReadRequests.remove(req); + _deferredReadRequests.remove(req); + } req.getFuture().complete(req.getEntries()); return true; } - for (Tuple2 tpl : toRead) { - final int idx = tpl._1; + for(Tuple2 tpl : toRead) { final BlockEntry entry = tpl._2; CompletableFuture future = _ioHandler.scheduleRead(entry); future.whenComplete((r, t) -> { - boolean allAvailable; + if(t != null) { + BlockReadState state; + synchronized(OOCLRUCacheScheduler.this) { + state = _blockReads.remove(entry.getKey()); + + } + if(state != null) { + for(DeferredReadWaiter waiter : state.waiters) + waiter.request.getFuture().completeExceptionally(t); + } + else { + LOG.error("Uncaught CacheError", t); + t.printStackTrace(); + } + return; + } + java.util.Set completedRequests = new java.util.HashSet<>(); synchronized(this) { synchronized(r) { transitionMemState(r, BlockState.WARM); - if (r.pin() == 0) - throw new IllegalStateException(); _evictionCache.remove(r.getKey()); _cache.put(r.getKey(), r); - allAvailable = req.setPinned(idx); } - if (allAvailable) { - _processingReadRequests.remove(req); + BlockReadState state = _blockReads.remove(r.getKey()); + if(state != null) { + for(DeferredReadWaiter waiter : state.waiters) { + synchronized(r) { + if(r.pin() == 0) + throw new IllegalStateException(); + if(waiter.request.setPinned(waiter.index) || waiter.request.isComplete()) + completedRequests.add(waiter.request); + } + } + } + + for(DeferredReadRequest done : completedRequests) { + clearRetainHints(done); + _processingReadRequests.remove(done); + _deferredReadRequests.remove(done); } sanityCheck(); } - if (allAvailable) { - req.getFuture().complete(req.getEntries()); - } + 
for(DeferredReadRequest done : completedRequests) + done.getFuture().complete(done.getEntries()); }); } @@ -502,6 +673,8 @@ private void onEvicted(final BlockEntry entry) { long cacheSizeDelta; synchronized(this) { synchronized(entry) { + if(entry.getState() == BlockState.REMOVED) + return; if(entry.isPinned()) { transitionMemState(entry, BlockState.WARM); return; // Then we cannot clear the data @@ -521,6 +694,17 @@ private void onEvicted(final BlockEntry entry) { onCacheSizeChanged(cacheSizeDelta > 0); } + private void clearRetainHints(DeferredReadRequest request) { + for (int i = 0; i < request.getEntries().size(); i++) { + if (!request.isRetainHinted(i)) + continue; + BlockEntry entry = request.getEntries().get(i); + synchronized(entry) { + entry.removeRetainHint(); + } + } + } + /** * Cleanly transitions state of a BlockEntry and handles accounting. * Requires both the scheduler object and the entry to be locked: @@ -574,46 +758,28 @@ private long transitionMemState(BlockEntry entry, BlockState newState) { return _cacheSize - oldCacheSize; } + private void registerWaiter(BlockKey key, DeferredReadRequest request, int index) { + BlockReadState state = _blockReads.computeIfAbsent(key, k -> new BlockReadState()); + state.waiters.add(new DeferredReadWaiter(request, index)); + } + private static class BlockReadState { + private double priority; + private final List waiters; - private static class DeferredReadRequest { - private static final short NOT_SCHEDULED = 0; - private static final short SCHEDULED = 1; - private static final short PINNED = 2; - - private final CompletableFuture> _future; - private final List _entries; - private final short[] _pinned; - private final AtomicInteger _availableCount; - - DeferredReadRequest(CompletableFuture> future, List entries) { - this._future = future; - this._entries = entries; - this._pinned = new short[entries.size()]; - this._availableCount = new AtomicInteger(0); - } - - CompletableFuture> getFuture() { - return 
_future; - } - - List getEntries() { - return _entries; - } - - public synchronized boolean actionRequired(int idx) { - return _pinned[idx] == NOT_SCHEDULED; + private BlockReadState() { + this.priority = 0; + this.waiters = new ArrayList<>(); } + } - public synchronized boolean setPinned(int idx) { - if (_pinned[idx] == PINNED) - return false; // already pinned - _pinned[idx] = PINNED; - return _availableCount.incrementAndGet() == _entries.size(); - } + private static class DeferredReadWaiter { + private final DeferredReadRequest request; + private final int index; - public synchronized void schedule(int idx) { - _pinned[idx] = SCHEDULED; + private DeferredReadWaiter(DeferredReadRequest request, int index) { + this.request = request; + this.index = index; } } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java index a9da3ccd294..ea508274402 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java @@ -33,7 +33,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.ooc.stats.OOCEventLog; -import org.apache.sysds.runtime.ooc.stream.OOCSourceStream; +import org.apache.sysds.runtime.ooc.stream.SourceOOCStream; import org.apache.sysds.runtime.util.FastBufferedDataInputStream; import org.apache.sysds.runtime.util.FastBufferedDataOutputStream; import org.apache.sysds.runtime.util.LocalFileUtils; @@ -55,6 +55,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.PriorityBlockingQueue; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -74,12 
+75,17 @@ public class OOCMatrixIOHandler implements OOCIOHandler { private final String _spillDir; private final ThreadPoolExecutor _writeExec; private final ThreadPoolExecutor _readExec; + private final ThreadPoolExecutor _srcReadExec; + private final ThreadPoolExecutor _deleteExec; + private final ConcurrentHashMap _pendingReads = new ConcurrentHashMap<>(); + private final AtomicLong _readSeq = new AtomicLong(0); // Spill related structures private final ConcurrentHashMap _spillLocations = new ConcurrentHashMap<>(); private final ConcurrentHashMap _partitions = new ConcurrentHashMap<>(); private final ConcurrentHashMap _sourceLocations = new ConcurrentHashMap<>(); private final AtomicInteger _partitionCounter = new AtomicInteger(0); + private final Object _spillLock = new Object(); private final CloseableQueue>>[] _q; private final AtomicLong _wCtr; private final AtomicBoolean _started; @@ -102,6 +108,18 @@ public OOCMatrixIOHandler() { READER_SIZE, 0L, TimeUnit.MILLISECONDS, + new PriorityBlockingQueue<>()); + _srcReadExec = new ThreadPoolExecutor( + READER_SIZE, + READER_SIZE, + 0L, + TimeUnit.MILLISECONDS, + new ArrayBlockingQueue<>(100000)); + _deleteExec = new ThreadPoolExecutor( + 1, + 1, + 0L, + TimeUnit.MILLISECONDS, new ArrayBlockingQueue<>(100000)); _q = new CloseableQueue[WRITER_SIZE]; _wCtr = new AtomicLong(0); @@ -134,6 +152,10 @@ public void shutdown() { _writeExec.shutdownNow(); _readExec.getQueue().clear(); _readExec.shutdownNow(); + _srcReadExec.getQueue().clear(); + _srcReadExec.shutdownNow(); + _deleteExec.getQueue().clear(); + _deleteExec.shutdownNow(); _spillLocations.clear(); _partitions.clear(); if (started) @@ -158,26 +180,35 @@ public CompletableFuture scheduleEviction(BlockEntry block) { @Override public CompletableFuture scheduleRead(final BlockEntry block) { final CompletableFuture future = new CompletableFuture<>(); + int pinnedPartitionId = pinPartitionForRead(block.getKey()); try { - _readExec.submit(() -> { - try { - long ioStart 
= DMLScript.OOC_LOG_EVENTS ? System.nanoTime() : 0; - loadFromDisk(block); - if (DMLScript.OOC_LOG_EVENTS) - OOCEventLog.onDiskReadEvent(_readCallerId, ioStart, System.nanoTime(), block.getSize()); - future.complete(block); - } catch (Throwable e) { - future.completeExceptionally(e); - } - }); + ReadTask task = new ReadTask(block, future, _readSeq.getAndIncrement(), pinnedPartitionId); + _pendingReads.put(block.getKey(), task); + _readExec.execute(task); } catch (RejectedExecutionException e) { + unpinPartitionForRead(pinnedPartitionId); + _pendingReads.remove(block.getKey()); future.completeExceptionally(e); } return future; } + @Override + public void prioritizeRead(BlockKey key, double priority) { + if (priority == 0) + return; + ReadTask task = _pendingReads.get(key); + if (task == null) + return; + if (_readExec.getQueue().remove(task)) { + task.addPriority(priority); + _readExec.getQueue().offer(task); + } + } + @Override public CompletableFuture scheduleDeletion(BlockEntry block) { + removeSpillLocation(block.getKey().toFileKey()); _sourceLocations.remove(block.getKey()); return CompletableFuture.completedFuture(true); } @@ -261,7 +292,7 @@ private CompletableFuture readBinarySourceParallel(SourceReadR continue; final int fileIdx = i; try { - _readExec.submit(() -> { + _srcReadExec.submit(() -> { try { readSequenceFile(job, files[fileIdx], request, fileIdx, filePositions, completed, stop, budgetHit, bytesRead, byteLimit, budgetLock, descriptors); @@ -288,8 +319,13 @@ private CompletableFuture readBinarySourceParallel(SourceReadR } if(!anyTask) { - tryCloseTarget(request.target, true); - result.complete(new SourceReadResult(bytesRead.get(), true, null, List.of())); + try { + closeTarget(request.target, true); + result.complete(new SourceReadResult(bytesRead.get(), true, null, List.of())); + } + catch(DMLRuntimeException e) { + result.completeExceptionally(e); + } } return result; @@ -304,16 +340,23 @@ private void completeResult(CompletableFuture future, 
AtomicLo return; } - if (budgetHit.get()) { - if (!request.keepOpenOnLimit) - tryCloseTarget(request.target, false); - SourceReadContinuation cont = new SourceReadState(request, files, filePositions, completed); - future.complete(new SourceReadResult(bytesRead.get(), false, cont, new ArrayList<>(descriptors))); - return; - } + try { + if (budgetHit.get()) { + if(!request.keepOpenOnLimit) { + closeTarget(request.target, false); + } + SourceReadContinuation cont = new SourceReadState(request, files, filePositions, completed); + future.complete(new SourceReadResult(bytesRead.get(), false, cont, new ArrayList<>(descriptors))); + + return; + } - tryCloseTarget(request.target, true); - future.complete(new SourceReadResult(bytesRead.get(), true, null, new ArrayList<>(descriptors))); + closeTarget(request.target, true); + future.complete(new SourceReadResult(bytesRead.get(), true, null, new ArrayList<>(descriptors))); + } + catch(DMLRuntimeException e) { + future.completeExceptionally(e); + } } private void readSequenceFile(JobConf job, Path path, SourceReadRequest request, int fileIdx, @@ -354,7 +397,7 @@ else if (bytesRead.get() + blockSize > byteLimit) { SourceBlockDescriptor descriptor = new SourceBlockDescriptor(path.toString(), request.format, outIdx, recordStart, (int)(recordEnd - recordStart), blockSize); - if (request.target instanceof OOCSourceStream src) + if (request.target instanceof SourceOOCStream src) src.enqueue(imv, descriptor); else request.target.enqueue(imv); @@ -377,12 +420,13 @@ else if (bytesRead.get() + blockSize > byteLimit) { } } - private void tryCloseTarget(org.apache.sysds.runtime.instructions.ooc.OOCStream target, boolean close) { - if (close) { + private void closeTarget(org.apache.sysds.runtime.instructions.ooc.OOCStream target, boolean close) { + if(close) { try { target.closeInput(); } - catch(Exception ignored) { + catch(Exception ex) { + throw ex instanceof DMLRuntimeException ? 
(DMLRuntimeException) ex : new DMLRuntimeException(ex); } } } @@ -393,7 +437,12 @@ private void loadFromDisk(BlockEntry block) { SourceBlockDescriptor src = _sourceLocations.get(block.getKey()); if (src != null) { + long ioStart = DMLScript.OOC_STATISTICS ? System.nanoTime() : 0; loadFromSource(block, src); + if (DMLScript.OOC_STATISTICS) { + Statistics.incrementOOCLoadFromDisk(); + Statistics.accumulateOOCLoadFromDiskTime(System.nanoTime() - ioStart); + } return; } @@ -417,10 +466,10 @@ private void loadFromDisk(BlockEntry block) { raf.seek(sloc.offset); DataInput dis = new FastBufferedDataInputStream(Channels.newInputStream(raf.getChannel())); - long ioStart = DMLScript.STATISTICS ? System.nanoTime() : 0; + long ioStart = DMLScript.OOC_STATISTICS ? System.nanoTime() : 0; ix.readFields(dis); // 1. Read Indexes mb.readFields(dis); // 2. Read Block - if (DMLScript.STATISTICS) + if (DMLScript.OOC_STATISTICS) ioDuration = System.nanoTime() - ioStart; } catch (ClosedByInterruptException ignored) { } catch (IOException e) { @@ -429,7 +478,7 @@ private void loadFromDisk(BlockEntry block) { block.setDataUnsafe(new IndexedMatrixValue(ix, mb)); - if (DMLScript.STATISTICS) { + if (DMLScript.OOC_STATISTICS) { Statistics.incrementOOCLoadFromDisk(); Statistics.accumulateOOCLoadFromDiskTime(ioDuration); } @@ -470,6 +519,7 @@ private void evictTask(CloseableQueue PartitionFile partFile = new PartitionFile(filename); _partitions.put(partitionId, partFile); + partFile.incrementRefCount(); // Writer pin; released when partition closes FileOutputStream fos = null; CountableFastBufferedDataOutputStream dos = null; @@ -484,12 +534,12 @@ private void evictTask(CloseableQueue boolean closePartition = false; while((tpl = q.take()) != null) { - long ioStart = DMLScript.STATISTICS || DMLScript.OOC_LOG_EVENTS ? System.nanoTime() : 0; + long ioStart = DMLScript.OOC_STATISTICS || DMLScript.OOC_LOG_EVENTS ? 
System.nanoTime() : 0; BlockEntry entry = tpl._1; CompletableFuture future = tpl._2; long wrote = writeOut(partitionId, entry, future, fos, dos, waitingForFlush); - if(DMLScript.STATISTICS && wrote > 0) { + if(DMLScript.OOC_STATISTICS && wrote > 0) { Statistics.incrementOOCEvictionWrite(); Statistics.accumulateOOCEvictionWriteTime(System.nanoTime() - ioStart); } @@ -507,13 +557,13 @@ private void evictTask(CloseableQueue if (!closePartition && q.close()) { while((tpl = q.take()) != null) { - long ioStart = DMLScript.STATISTICS ? System.nanoTime() : 0; + long ioStart = DMLScript.OOC_STATISTICS ? System.nanoTime() : 0; BlockEntry entry = tpl._1; CompletableFuture future = tpl._2; long wrote = writeOut(partitionId, entry, future, fos, dos, waitingForFlush); byteCtr += wrote; - if(DMLScript.STATISTICS && wrote > 0) { + if(DMLScript.OOC_STATISTICS && wrote > 0) { Statistics.incrementOOCEvictionWrite(); Statistics.accumulateOOCEvictionWriteTime(System.nanoTime() - ioStart); } @@ -533,6 +583,7 @@ private void evictTask(CloseableQueue IOUtilFunctions.closeSilently(fos); if(waitingForFlush != null) flushQueue(Long.MAX_VALUE, waitingForFlush); + releasePartitionWriter(partitionId); } } } @@ -550,6 +601,8 @@ private long writeOut(int partitionId, BlockEntry entry, CompletableFuture // 2. write indexes and block IndexedMatrixValue imv = (IndexedMatrixValue) entry.getDataUnsafe(); // Get data without requiring pin + if(imv == null) + return 0; imv.getIndexes().write(dos); // write Indexes imv.getValue().write(dos); @@ -558,7 +611,7 @@ private long writeOut(int partitionId, BlockEntry entry, CompletableFuture // 3. 
create the spillLocation SpillLocation sloc = new SpillLocation(partitionId, offsetBefore); - _spillLocations.put(key, sloc); + addSpillLocation(key, sloc); flushQueue(fos.getChannel().position(), flushQueue); return offsetAfter - offsetBefore; @@ -574,6 +627,130 @@ private void flushQueue(long offset, ConcurrentLinkedDeque LocalFileUtils.deleteFileIfExists(partFile.filePath, true)); + } + catch(RejectedExecutionException ignored) { + } + } + } + } + + private void releasePartitionWriter(int partitionId) { + synchronized(_spillLock) { + PartitionFile partFile = _partitions.get(partitionId); + if(partFile == null) + return; + int remaining = partFile.decrementRefCount(); + if(remaining == 0 && _partitions.remove(partitionId, partFile)) { + try { + _deleteExec.execute(() -> LocalFileUtils.deleteFileIfExists(partFile.filePath, true)); + } + catch(RejectedExecutionException ignored) { + } + } + } + } + + private int pinPartitionForRead(BlockKey key) { + String fileKey = key.toFileKey(); + synchronized(_spillLock) { + SpillLocation sloc = _spillLocations.get(fileKey); + if(sloc == null) + return -1; + PartitionFile partFile = _partitions.get(sloc.partitionId); + if(partFile == null) + return -1; + partFile.incrementRefCount(); + return sloc.partitionId; + } + } + + private void unpinPartitionForRead(int partitionId) { + if(partitionId < 0) + return; + synchronized(_spillLock) { + PartitionFile partFile = _partitions.get(partitionId); + if(partFile == null) + return; + int remaining = partFile.decrementRefCount(); + if(remaining == 0 && _partitions.remove(partitionId, partFile)) { + try { + _deleteExec.execute(() -> LocalFileUtils.deleteFileIfExists(partFile.filePath, true)); + } + catch(RejectedExecutionException ignored) { + } + } + } + } + + private class ReadTask implements Runnable, Comparable { + private final BlockEntry _block; + private final CompletableFuture _future; + private final long _sequence; + private final int _pinnedPartitionId; + private double 
_priority; + + private ReadTask(BlockEntry block, CompletableFuture future, long sequence, int pinnedPartitionId) { + this._block = block; + this._future = future; + this._sequence = sequence; + this._pinnedPartitionId = pinnedPartitionId; + this._priority = 0; + } + + private void addPriority(double delta) { + _priority += delta; + } + + @Override + public void run() { + _pendingReads.remove(_block.getKey(), this); + try { + long ioStart = DMLScript.OOC_LOG_EVENTS ? System.nanoTime() : 0; + loadFromDisk(_block); + if (DMLScript.OOC_LOG_EVENTS) + OOCEventLog.onDiskReadEvent(_readCallerId, ioStart, System.nanoTime(), _block.getSize()); + _future.complete(_block); + } catch (Throwable e) { + _future.completeExceptionally(e); + } finally { + unpinPartitionForRead(_pinnedPartitionId); + } + } + + @Override + public int compareTo(ReadTask other) { + int byPriority = Double.compare(other._priority, _priority); + if (byPriority != 0) + return byPriority; + return Long.compare(_sequence, other._sequence); + } + } + @@ -590,9 +767,19 @@ private static class SpillLocation { private static class PartitionFile { final String filePath; + private final AtomicInteger refCount; PartitionFile(String filePath) { this.filePath = filePath; + this.refCount = new AtomicInteger(0); + } + + int incrementRefCount() { + return refCount.incrementAndGet(); + } + + int decrementRefCount() { + return refCount.decrementAndGet(); } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stats/OOCEventLog.java b/src/main/java/org/apache/sysds/runtime/ooc/stats/OOCEventLog.java index 0df22c9a851..2279272afa6 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/stats/OOCEventLog.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/stats/OOCEventLog.java @@ -56,6 +56,8 @@ public static int registerCaller(String callerName) { public static void onComputeEvent(int callerId, long startTimestamp, long endTimestamp) { int idx = _logCtr.getAndIncrement(); + if(idx >= _eventTypes.length) + return; 
_eventTypes[idx] = EventType.COMPUTE; _startTimestamps[idx] = startTimestamp; _endTimestamps[idx] = endTimestamp; @@ -65,6 +67,8 @@ public static void onComputeEvent(int callerId, long startTimestamp, long endTim public static void onDiskWriteEvent(int callerId, long startTimestamp, long endTimestamp, long size) { int idx = _logCtr.getAndIncrement(); + if(idx >= _eventTypes.length) + return; _eventTypes[idx] = EventType.DISK_WRITE; _startTimestamps[idx] = startTimestamp; _endTimestamps[idx] = endTimestamp; @@ -75,6 +79,8 @@ public static void onDiskWriteEvent(int callerId, long startTimestamp, long endT public static void onDiskReadEvent(int callerId, long startTimestamp, long endTimestamp, long size) { int idx = _logCtr.getAndIncrement(); + if(idx >= _eventTypes.length) + return; _eventTypes[idx] = EventType.DISK_READ; _startTimestamps[idx] = startTimestamp; _endTimestamps[idx] = endTimestamp; @@ -85,6 +91,8 @@ public static void onDiskReadEvent(int callerId, long startTimestamp, long endTi public static void onCacheSizeChangedEvent(int callerId, long timestamp, long cacheSize, long bytesToEvict) { int idx = _logCtr.getAndIncrement(); + if(idx >= _eventTypes.length) + return; _eventTypes[idx] = EventType.CACHESIZE_CHANGE; _startTimestamps[idx] = timestamp; _endTimestamps[idx] = bytesToEvict; @@ -117,7 +125,7 @@ private static String getFilteredCSV(String header, EventType filter, boolean da StringBuilder sb = new StringBuilder(); sb.append(header); - int maxIdx = _logCtr.get(); + int maxIdx = Math.min(_logCtr.get(), _eventTypes.length); for (int i = 0; i < maxIdx; i++) { if (_eventTypes[i] != filter) continue; diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/FilteredOOCStream.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/FilteredOOCStream.java new file mode 100644 index 00000000000..0af68edd521 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/FilteredOOCStream.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.stream; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import org.apache.sysds.runtime.instructions.ooc.CachingStream; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; +import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage; +import org.apache.sysds.runtime.util.IndexRange; + +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; + +public class FilteredOOCStream implements OOCStream { + private final OOCStream _sourceStream; + private final Function _predicate; + private CacheableData _data; + + public FilteredOOCStream(OOCStream sourceStream, Function predicate) { + _sourceStream = sourceStream; + _predicate = predicate; + } + + @Override + public void enqueue(T t) { + _sourceStream.enqueue(t); + } + + @Override + public synchronized T dequeue() { + T next; + while((next = _sourceStream.dequeue()) != null) { + if(_predicate.apply(next)) + return next; + } + return null; + } + + @Override + public void closeInput() { + 
_sourceStream.closeInput(); + } + + @Override + public void propagateFailure(DMLRuntimeException re) { + _sourceStream.propagateFailure(re); + } + + @Override + public boolean hasStreamCache() { + return _sourceStream.hasStreamCache(); + } + + @Override + public CachingStream getStreamCache() { + return _sourceStream.getStreamCache(); + } + + @Override + public void setSubscriber(Consumer> subscriber) { + _sourceStream.setSubscriber(cb -> { + if(cb.isFailure() || cb.isEos() || _predicate.apply(cb.get())) + subscriber.accept(cb); + }); + } + + @Override + public OOCStream getReadStream() { + return this; + } + + @Override + public OOCStream getWriteStream() { + return _sourceStream.getWriteStream(); + } + + @Override + public boolean isProcessed() { + return _sourceStream.isProcessed(); + } + + @Override + public DataCharacteristics getDataCharacteristics() { + return _data == null ? null : _data.getDataCharacteristics(); + } + + @Override + public CacheableData getData() { + return _data; + } + + @Override + public void setData(CacheableData data) { + _data = data; + } + + @Override + public void messageUpstream(OOCStreamMessage msg) { + _sourceStream.messageUpstream(msg); + } + + @Override + public void messageDownstream(OOCStreamMessage msg) { + _sourceStream.messageDownstream(msg); + } + + @Override + public void setUpstreamMessageRelay(Consumer relay) { + _sourceStream.setUpstreamMessageRelay(relay); + } + + @Override + public void setDownstreamMessageRelay(Consumer relay) { + _sourceStream.setDownstreamMessageRelay(relay); + } + + @Override + public void addUpstreamMessageRelay(Consumer relay) { + _sourceStream.addUpstreamMessageRelay(relay); + } + + @Override + public void addDownstreamMessageRelay(Consumer relay) { + _sourceStream.addDownstreamMessageRelay(relay); + } + + @Override + public void clearUpstreamMessageRelays() { + _sourceStream.clearUpstreamMessageRelays(); + } + + @Override + public void clearDownstreamMessageRelays() { + 
_sourceStream.clearDownstreamMessageRelays(); + } + + @Override + public void setIXTransform(BiFunction transform) { + _sourceStream.setIXTransform(transform); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/OOCSourceStream.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/SourceOOCStream.java similarity index 62% rename from src/main/java/org/apache/sysds/runtime/ooc/stream/OOCSourceStream.java rename to src/main/java/org/apache/sysds/runtime/ooc/stream/SourceOOCStream.java index c48aaa45ab2..df1b415cb96 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/stream/OOCSourceStream.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/SourceOOCStream.java @@ -19,23 +19,31 @@ package org.apache.sysds.runtime.ooc.stream; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; import org.apache.sysds.runtime.ooc.cache.OOCIOHandler; +import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.LockSupport; -public class OOCSourceStream extends SubscribableTaskQueue { +public class SourceOOCStream extends SubscribableTaskQueue { private final ConcurrentHashMap _idx; + private static final long BACKPRESSURE_PARK_NANOS = 1_000_000L; + private static final long MAX_BACKPRESSURE_PARK_NANOS = 2_000_000_000L; - public OOCSourceStream() { + public SourceOOCStream() { this._idx = new ConcurrentHashMap<>(); } public void enqueue(IndexedMatrixValue value, OOCIOHandler.SourceBlockDescriptor descriptor) { if(descriptor == null) throw new IllegalArgumentException("Source descriptor must not be null"); + waitForBackpressure(); 
MatrixIndexes key = new MatrixIndexes(descriptor.indexes); _idx.put(key, descriptor); super.enqueue(value); @@ -49,4 +57,25 @@ public void enqueue(IndexedMatrixValue val) { public OOCIOHandler.SourceBlockDescriptor getDescriptor(MatrixIndexes indexes) { return _idx.get(indexes); } + + private void waitForBackpressure() { + int limit = OOCInstruction.getComputeBackpressureThreshold(); + if(limit <= 0) + return; + long parkNanos = BACKPRESSURE_PARK_NANOS; + while(!OOCCacheManager.canClaimMemory()) { + LockSupport.parkNanos(parkNanos); + if (Thread.interrupted()) + throw new DMLRuntimeException(new InterruptedException()); + if (parkNanos < MAX_BACKPRESSURE_PARK_NANOS) + parkNanos = Math.min(parkNanos * 2, MAX_BACKPRESSURE_PARK_NANOS); + } + } + + @Override + public void messageUpstream(OOCStreamMessage msg) { + if(msg.isCancelled()) + return; + super.messageUpstream(msg); + } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/StreamContext.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/StreamContext.java new file mode 100644 index 00000000000..9c9f2e3fc0e --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/StreamContext.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.stream; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +public class StreamContext { + private Set> _inStreams; + private Set> _outStreams; + private DMLRuntimeException _failure; + + public boolean inStreamsDefined() { + return _inStreams != null; + } + + public boolean outStreamsDefined() { + return _outStreams != null; + } + + public StreamContext addInStream(OOCStream... inStream) { + if(_inStreams == null) + _inStreams = ConcurrentHashMap.newKeySet(); + _inStreams.addAll(List.of(inStream)); + return this; + } + + public StreamContext addOutStream(OOCStream... outStream) { + if(outStream.length == 0 && _outStreams == null) { + _outStreams = Collections.emptySet(); + return this; + } + + if(_outStreams == null || _outStreams.isEmpty()) + _outStreams = ConcurrentHashMap.newKeySet(); + _outStreams.addAll(List.of(outStream)); + return this; + } + + public Collection> inStreams() { + return _inStreams; + } + + public Collection> outStreams() { + return _outStreams; + } + + public void failAll(DMLRuntimeException e) { + if(_failure != null) + return; + _failure = e; + + for(OOCStream stream : _outStreams) { + try { + stream.propagateFailure(e); + } + catch(Throwable ignored) {} + } + + for(OOCStream stream : _inStreams) { + try { + stream.propagateFailure(e); + } + catch(Throwable ignored) {} + } + } + + public void clear() { + _inStreams = null; + _outStreams = null; + } + + public StreamContext copy() { + StreamContext cpy = new StreamContext(); + cpy._inStreams = _inStreams; + cpy._outStreams = _outStreams; + return cpy; + } +} diff --git 
a/src/main/java/org/apache/sysds/runtime/ooc/stream/TaskContext.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/TaskContext.java new file mode 100644 index 00000000000..5b6381d4bec --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/TaskContext.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
/**
 * Thread-bound context that lets code postpone actions until {@link #runDeferred()} is called
 * on the same thread.
 *
 * <p>The context is held in a {@link ThreadLocal}. While no context is bound to the calling
 * thread, {@link #defer(Runnable)} runs the action immediately.
 */
public class TaskContext {
	private static final ThreadLocal<TaskContext> CTX = new ThreadLocal<>();

	/** Number of queued actions above which {@link #defer(Runnable)} emits a warning. */
	private static final int DEFERRED_WARN_THRESHOLD = 3;

	// FIFO of postponed actions; created lazily on the first defer() with a bound context
	private ArrayDeque<Runnable> _deferred;

	/** @return the context bound to the calling thread, or null if none is bound */
	public static TaskContext getContext() {
		return CTX.get();
	}

	/**
	 * Binds a context to the calling thread.
	 *
	 * @param context context to bind
	 * @throws IllegalStateException if the thread already has a context bound
	 */
	public static void setContext(TaskContext context) {
		if(CTX.get() != null)
			throw new IllegalStateException(
				"A TaskContext is already bound to thread '" + Thread.currentThread().getName() + "'");
		CTX.set(context);
	}

	/** Removes the binding for the calling thread, if any. */
	public static void clearContext() {
		CTX.remove();
	}

	/**
	 * Queues an action on the calling thread's context, or runs it right away when the thread
	 * has no context.
	 *
	 * @param deferred action to postpone
	 */
	public static void defer(Runnable deferred) {
		TaskContext ctx = CTX.get();
		if(ctx == null) {
			deferred.run();
			return;
		}
		if(ctx._deferred == null)
			ctx._deferred = new ArrayDeque<>();
		ctx._deferred.add(deferred);
		// diagnostics go to stderr so they do not mix with regular program output on stdout
		if(ctx._deferred.size() > DEFERRED_WARN_THRESHOLD)
			System.err.println("[WARN] Defer size bigger than " + DEFERRED_WARN_THRESHOLD);
	}

	/**
	 * Runs all queued actions of the calling thread's context in insertion order, including any
	 * that are queued while draining.
	 *
	 * @return true if at least one action was run, false if there was no context or nothing queued
	 */
	public static boolean runDeferred() {
		TaskContext ctx = CTX.get();
		if(ctx == null || ctx._deferred == null || ctx._deferred.isEmpty())
			return false;
		Runnable next;
		while((next = ctx._deferred.poll()) != null)
			next.run();
		return true;
	}
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.stream.message; + +import org.apache.sysds.runtime.util.IndexRange; + +import java.util.function.BiFunction; + +public class OOCGetStreamTypeMessage implements OOCStreamMessage { + public static final byte STREAM_TYPE_UNKNOWN = 0; + public static final byte STREAM_TYPE_CACHED = 1; + public static final byte STREAM_TYPE_IN_MEMORY = 2; + + private byte _streamType; + + public OOCGetStreamTypeMessage() { + _streamType = 0; + } + + @Override + public void addIXTransform(BiFunction transform) {} + + public void setUnknownType() { + _streamType = STREAM_TYPE_UNKNOWN; + } + + public void setCachedType() { + _streamType = STREAM_TYPE_CACHED; + } + + public void setInMemoryType() { + _streamType = STREAM_TYPE_IN_MEMORY; + } + + public byte getStreamType() { + return _streamType; + } + + public boolean isRequestable() { + return _streamType == STREAM_TYPE_CACHED || _streamType == STREAM_TYPE_IN_MEMORY; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/message/OOCStreamMessage.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/message/OOCStreamMessage.java new file mode 100644 index 00000000000..26459a74ffa --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/message/OOCStreamMessage.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.stream.message; + +import org.apache.sysds.runtime.util.IndexRange; + +import java.util.function.BiFunction; + +public interface OOCStreamMessage { + default boolean isCancelled() { + return false; + } + + default void cancel() {} + + default OOCStreamMessage split() { + return this; + } + + void addIXTransform(BiFunction transform); +} diff --git a/src/main/java/org/apache/sysds/runtime/ooc/util/OOCUtils.java b/src/main/java/org/apache/sysds/runtime/ooc/util/OOCUtils.java new file mode 100644 index 00000000000..c564748e1da --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/util/OOCUtils.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.util; + +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.util.IndexRange; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +public class OOCUtils { + public static IndexRange getRangeOfTile(MatrixIndexes tileIdx, long blen) { + long rs = 1 + tileIdx.getRowIndex() * blen; + long re = (tileIdx.getRowIndex() + 1) * blen; + long cs = 1 + tileIdx.getColumnIndex() * blen; + long ce = (tileIdx.getColumnIndex() + 1) * blen; + return new IndexRange(rs, re, cs, ce); + } + + public static Collection getTilesOfRange(IndexRange range, long blen) { + long rs = (range.rowStart - 1) / blen + 1; + long re = (range.rowEnd - 1) / blen + 1; + long cs = (range.colStart - 1) / blen + 1; + long ce = (range.colEnd - 1) / blen + 1; + + if(rs == re) { + if(cs == ce) { + return Collections.singleton(new MatrixIndexes(rs, cs)); + } + else { + List list = new ArrayList<>((int)(ce-cs+1)); + for(long i = cs; i <= ce; i++) + list.add(new MatrixIndexes(rs, i)); + return list; + } + } + + List list = new ArrayList<>((int)((re-rs+1)*(ce-cs+1))); + for(long r = rs; r <= re; r++) + for (long c = cs; c <= ce; c++) + list.add(new MatrixIndexes(r, c)); + return list; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/util/OOCJoin.java b/src/main/java/org/apache/sysds/runtime/util/OOCJoin.java deleted file mode 100644 index 81265b8a2d2..00000000000 --- a/src/main/java/org/apache/sysds/runtime/util/OOCJoin.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysds.runtime.util; - -import org.apache.logging.log4j.util.TriConsumer; -import org.apache.sysds.runtime.DMLRuntimeException; - -import java.util.HashMap; -import java.util.Map; - -public class OOCJoin { - private Map left; - private Map right; - private TriConsumer emitter; - - public OOCJoin(TriConsumer emitter) { - this.left = new HashMap<>(); - this.right = new HashMap<>(); - this.emitter = emitter; - } - - public void addLeft(T idx, O item) { - add(true, idx, item); - } - - public void addRight(T idx, O item) { - add(false, idx, item); - } - - public void close() { - synchronized (this) { - if (!left.isEmpty() || !right.isEmpty()) - throw new DMLRuntimeException("There are still unprocessed items in the OOC join"); - } - } - - public void add(boolean isLeft, T idx, O val) { - Map lookup = isLeft ? right : left; - Map store = isLeft ? left : right; - O val2; - - synchronized (this) { - val2 = lookup.remove(idx); - - if (val2 == null) - store.put(idx, val); - } - - if (val2 != null) - emitter.accept(idx, isLeft ? val : val2, isLeft ? 
val2 : val); - } -} diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java index 04a0bd1ab8a..493366d5b18 100644 --- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java +++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java @@ -1482,12 +1482,13 @@ public static IndexedMatrixValue createIndexedMatrixBlock(MatrixBlock mb, DataCh //compute block sizes int maxRow = UtilFunctions.computeBlockSize(mc.getRows(), blockRow+1, mc.getBlocksize()); int maxCol = UtilFunctions.computeBlockSize(mc.getCols(), blockCol+1, mc.getBlocksize()); + MatrixBlock block = null; //copy sub-matrix to block - MatrixBlock block = new MatrixBlock(maxRow, maxCol, mb.isInSparseFormat()); - int row_offset = (int)blockRow*mc.getBlocksize(); - int col_offset = (int)blockCol*mc.getBlocksize(); - block = mb.slice( row_offset, row_offset+maxRow-1, - col_offset, col_offset+maxCol-1, false, block ); + block = new MatrixBlock(maxRow, maxCol, mb.isInSparseFormat()); + int row_offset = (int) blockRow * mc.getBlocksize(); + int col_offset = (int) blockCol * mc.getBlocksize(); + block = mb.slice(row_offset, row_offset + maxRow - 1, col_offset, col_offset + maxCol - 1, false, + block); //create key-value pair return new IndexedMatrixValue(new MatrixIndexes(blockRow+1, blockCol+1), block); } @@ -1495,4 +1496,4 @@ public static IndexedMatrixValue createIndexedMatrixBlock(MatrixBlock mb, DataCh throw new RuntimeException(ex); } } -} \ No newline at end of file +} diff --git a/src/test/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheSchedulerTest.java b/src/test/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheSchedulerTest.java new file mode 100644 index 00000000000..66a46c03269 --- /dev/null +++ b/src/test/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheSchedulerTest.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.cache; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Deque; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +public class OOCLRUCacheSchedulerTest { + private static final long ENTRY_SIZE = 1; + private static final long WAIT_TIMEOUT_SEC = 5; + + private FakeIOHandler _handler; + private OOCLRUCacheScheduler _scheduler; + + @Before + public void setUp() { + _handler = new FakeIOHandler(); + _scheduler = new OOCLRUCacheScheduler(_handler, 0, Long.MAX_VALUE); + } + + @After + public void tearDown() { + if (_scheduler != null) + _scheduler.shutdown(); + if (_handler != null) + _handler.shutdown(); + } + + @Test + public void testImmediateRequestPinsBlock() throws Exception { + FakeIOHandler handler = new FakeIOHandler(); + OOCLRUCacheScheduler scheduler = new OOCLRUCacheScheduler(handler, Long.MAX_VALUE, Long.MAX_VALUE); + 
try { + BlockKey key = new BlockKey(1, 1); + scheduler.put(key, new Object(), ENTRY_SIZE); + Assert.assertEquals(ENTRY_SIZE, scheduler.getCacheSize()); + + BlockEntry fetched = scheduler.request(key).get(WAIT_TIMEOUT_SEC, TimeUnit.SECONDS); + Assert.assertTrue(fetched.isPinned()); + scheduler.unpin(fetched); + Assert.assertEquals(ENTRY_SIZE, scheduler.getCacheSize()); + } + finally { + scheduler.shutdown(); + handler.shutdown(); + } + } + + @Test + public void testDeferredReadSingleBlock() throws Exception { + BlockKey key = new BlockKey(1, 1); + BlockEntry entry = putColdSourceBacked(key); + Assert.assertEquals(0, _scheduler.getCacheSize()); + + CompletableFuture future = _scheduler.request(key); + Assert.assertFalse(future.isDone()); + Assert.assertEquals(1, _handler.getReadCount(key)); + Assert.assertEquals(ENTRY_SIZE, _scheduler.getCacheSize()); + + _handler.completeRead(key); + + BlockEntry fetched = future.get(WAIT_TIMEOUT_SEC, TimeUnit.SECONDS); + Assert.assertTrue(fetched.isPinned()); + Assert.assertEquals(ENTRY_SIZE, _scheduler.getCacheSize()); + _scheduler.unpin(fetched); + Assert.assertEquals(0, _scheduler.getCacheSize()); + } + + @Test + public void testMergeOverlappingRequests() throws Exception { + BlockKey key1 = new BlockKey(1, 1); + BlockKey key2 = new BlockKey(1, 2); + BlockKey key3 = new BlockKey(1, 3); + putColdSourceBacked(key1); + putColdSourceBacked(key2); + putColdSourceBacked(key3); + Assert.assertEquals(0, _scheduler.getCacheSize()); + + CompletableFuture> reqA = _scheduler.request(List.of(key1, key2)); + CompletableFuture> reqB = _scheduler.request(List.of(key1, key3)); + + Assert.assertEquals(1, _handler.getReadCount(key1)); + Assert.assertEquals(1, _handler.getReadCount(key2)); + Assert.assertEquals(1, _handler.getReadCount(key3)); + Assert.assertEquals(ENTRY_SIZE * 3, _scheduler.getCacheSize()); + Assert.assertFalse(reqA.isDone()); + Assert.assertFalse(reqB.isDone()); + + _handler.completeRead(key1); + _handler.completeRead(key2); + + 
List resA = reqA.get(WAIT_TIMEOUT_SEC, TimeUnit.SECONDS); + Assert.assertFalse(reqB.isDone()); + Assert.assertEquals(ENTRY_SIZE * 3, _scheduler.getCacheSize()); + + _handler.completeRead(key3); + List resB = reqB.get(WAIT_TIMEOUT_SEC, TimeUnit.SECONDS); + Assert.assertEquals(ENTRY_SIZE * 3, _scheduler.getCacheSize()); + + resA.forEach(_scheduler::unpin); + resB.forEach(_scheduler::unpin); + Assert.assertEquals(0, _scheduler.getCacheSize()); + } + + @Test + public void testPrioritizeReordersDeferredRequests() throws Exception { + OOCLRUCacheScheduler scheduler = new OOCLRUCacheScheduler(_handler, 0, 0); + try { + BlockKey key1 = new BlockKey(1, 1); + BlockKey key2 = new BlockKey(1, 2); + BlockKey key3 = new BlockKey(1, 3); + putColdSourceBacked(scheduler, key1); + putColdSourceBacked(scheduler, key2); + putColdSourceBacked(scheduler, key3); + + scheduler.request(List.of(key1)); + scheduler.request(List.of(key2)); + scheduler.request(List.of(key3)); + + List before = snapshotDeferredOrder(scheduler); + Assert.assertEquals(List.of(key1, key2, key3), before); + + scheduler.prioritize(key3, 1); + + List after = snapshotDeferredOrder(scheduler); + Assert.assertEquals(List.of(key1, key3, key2), after); + } + finally { + scheduler.shutdown(); + } + } + + private BlockEntry putColdSourceBacked(BlockKey key) { + return putColdSourceBacked(_scheduler, key); + } + + private BlockEntry putColdSourceBacked(OOCLRUCacheScheduler scheduler, BlockKey key) { + OOCIOHandler.SourceBlockDescriptor desc = new OOCIOHandler.SourceBlockDescriptor( + "unused", Types.FileFormat.BINARY, new MatrixIndexes(1, 1), 0, 0, ENTRY_SIZE); + BlockEntry entry = scheduler.putAndPinSourceBacked(key, new Object(), ENTRY_SIZE, desc); + scheduler.unpin(entry); + Assert.assertEquals(BlockState.COLD, entry.getState()); + return entry; + } + + @SuppressWarnings("unchecked") + private static List snapshotDeferredOrder(OOCLRUCacheScheduler scheduler) throws Exception { + Field field = 
OOCLRUCacheScheduler.class.getDeclaredField("_deferredReadRequests"); + field.setAccessible(true); + Deque deque = (Deque) field.get(scheduler); + List order = new ArrayList<>(); + for (Object obj : deque) { + Field entriesField = obj.getClass().getDeclaredField("_entries"); + entriesField.setAccessible(true); + List entries = (List) entriesField.get(obj); + order.add(entries.get(0).getKey()); + } + return order; + } + + private static class FakeIOHandler implements OOCIOHandler { + private final Map> _readFutures = new HashMap<>(); + private final Map _readEntries = new HashMap<>(); + private final Map _readCounts = new HashMap<>(); + + @Override + public void shutdown() { + _readFutures.clear(); + _readEntries.clear(); + _readCounts.clear(); + } + + @Override + public CompletableFuture scheduleEviction(BlockEntry block) { + return CompletableFuture.completedFuture(null); + } + + @Override + public CompletableFuture scheduleRead(BlockEntry block) { + CompletableFuture future = new CompletableFuture<>(); + _readFutures.put(block.getKey(), future); + _readEntries.put(block.getKey(), block); + _readCounts.computeIfAbsent(block.getKey(), k -> new AtomicInteger(0)).incrementAndGet(); + return future; + } + + @Override + public void prioritizeRead(BlockKey key, double priority) {} + + @Override + public CompletableFuture scheduleDeletion(BlockEntry block) { + return CompletableFuture.completedFuture(true); + } + + @Override + public void registerSourceLocation(BlockKey key, SourceBlockDescriptor descriptor) { + } + + @Override + public CompletableFuture scheduleSourceRead(SourceReadRequest request) { + return CompletableFuture.failedFuture(new UnsupportedOperationException()); + } + + @Override + public CompletableFuture continueSourceRead(SourceReadContinuation continuation, long maxBytesInFlight) { + return CompletableFuture.failedFuture(new UnsupportedOperationException()); + } + + public int getReadCount(BlockKey key) { + AtomicInteger ctr = _readCounts.get(key); + 
return ctr == null ? 0 : ctr.get(); + } + + public void completeRead(BlockKey key) { + CompletableFuture future = _readFutures.get(key); + if (future == null) + throw new IllegalStateException("No scheduled read for " + key); + BlockEntry entry = _readEntries.get(key); + if (entry == null) + throw new IllegalStateException("No registered entry for " + key); + entry.setDataUnsafe(new Object()); + future.complete(entry); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixMatrixTest.java index dfa9413bfb0..69dfc0ae537 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixMatrixTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixMatrixTest.java @@ -89,8 +89,8 @@ private void runBinaryMatrixMatrixTest(boolean sparse1, boolean sparse2) { programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME_1), input(INPUT_NAME_2), output(OUTPUT_NAME)}; // 1. Generate the data in-memory as MatrixBlock objects - double[][] X_data = getRandomMatrix(rows, 1, 1, maxVal, sparse1 ? sparsity2 : sparsity1, 7); - double[][] Y_data = getRandomMatrix(rows, 1, 0, 1, sparse2 ? sparsity2 : sparsity1, 8); + double[][] X_data = getRandomMatrix(rows, cols, 1, maxVal, sparse1 ? sparsity2 : sparsity1, 7); + double[][] Y_data = getRandomMatrix(rows, cols, 0, 1, sparse2 ? sparsity2 : sparsity1, 8); // 2. 
Convert the double arrays to MatrixBlock objects MatrixBlock X_mb = DataConverter.convertToMatrixBlock(X_data); diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixScalarTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixScalarTest.java index e84d36e41b0..bd15a4c25cd 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixScalarTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixScalarTest.java @@ -78,7 +78,7 @@ private void runBinaryMatrixScalarTest(boolean sparse) { programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME)}; // 1. Generate the data in-memory as MatrixBlock objects - double[][] X_data = getRandomMatrix(rows, 1, 1, maxVal, sparse ? sparsity2 : sparsity1, 7); + double[][] X_data = getRandomMatrix(rows, cols, 1, maxVal, sparse ? sparsity2 : sparsity1, 7); // 2. Convert the double arrays to MatrixBlock objects MatrixBlock X_mb = DataConverter.convertToMatrixBlock(X_data); diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/JoinAccessPatternTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/JoinAccessPatternTest.java new file mode 100644 index 00000000000..4dc22155a24 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/JoinAccessPatternTest.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; + +public class JoinAccessPatternTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "JoinAccessPattern"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + JoinAccessPatternTest.class.getSimpleName() + "/"; + private final static double eps = 1e-8; + private static final String INPUT_NAME_1 = "X"; + private static final String INPUT_NAME_2 = "Y"; + private static final String OUTPUT_NAME = "res"; + + private final static int rows = 2000; + private final static int cols = 1000; + private final static int maxVal = 7; + private final static double sparsity1 = 1; + private final static double sparsity2 = 0.05; + + @Override + public void setUp() { + 
TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void testBinaryMatrixMatrixDenseDense() { + runBinaryMatrixMatrixTest(false, false); + } + + @Test + public void testBinaryMatrixMatrixDenseSparse() { + runBinaryMatrixMatrixTest(false, true); + } + + @Test + public void testBinaryMatrixMatrixSparseDense() { + runBinaryMatrixMatrixTest(true, false); + } + + @Test + public void testBinaryMatrixMatrixSparseSparse() { + runBinaryMatrixMatrixTest(true, true); + } + + private void runBinaryMatrixMatrixTest(boolean sparse1, boolean sparse2) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME1); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME_1), input(INPUT_NAME_2), output(OUTPUT_NAME)}; + + // 1. Generate the data in-memory as MatrixBlock objects + double[][] X_data = getRandomMatrix(rows, cols, 1, maxVal, sparse1 ? sparsity2 : sparsity1, 7); + double[][] Y_data = getRandomMatrix(rows, cols, 0, 1, sparse2 ? sparsity2 : sparsity1, 8); + + // 2. Convert the double arrays to MatrixBlock objects + MatrixBlock X_mb = DataConverter.convertToMatrixBlock(X_data); + MatrixBlock Y_mb = DataConverter.convertToMatrixBlock(Y_data); + + // 3. Create a binary matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + + // 4. 
Write matrices X and Y to binary SequenceFiles + writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME_1), rows, cols, 1000, X_mb.getNonZeros()); + writer.writeMatrixToHDFS(Y_mb, input(INPUT_NAME_2), rows, cols, 1000, Y_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, X_mb.getNonZeros()), Types.FileFormat.BINARY); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, Y_mb.getNonZeros()), Types.FileFormat.BINARY); + + X_data = null; + Y_data = null; + X_mb = null; + Y_mb = null; + + OOCCacheManager.getCache().updateLimits(50000000, 100000000); + runTest(true, false, null, -1); + + //check that the OOC multiplication instruction was used + Assert.assertTrue("OOC wasn't used for multiplication", + heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.MULT)); + + //compare results + + // rerun without ooc flag + programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME_1), input(INPUT_NAME_2), output(OUTPUT_NAME + "_target")}; + runTest(true, false, null, -1); + + // compare matrices + MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), + Types.FileFormat.BINARY, rows, cols, 1000); + MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"), + Types.FileFormat.BINARY, rows, cols, 1000); + TestUtils.compareMatrices(ret1, ret2, eps); + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java index 72650daf9c2..5b348de2ed9 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java @@ -43,7 +43,7 @@ public class LmCGTest extends AutomatedTestBase { private static final String OUTPUT_NAME = "res"; 
private final static int rows = 10000; - private final static int cols = 500; + private final static int cols = 1500; private final static int maxVal = 2; private final static double sparsity1 = 1; private final static double sparsity2 = 0.05; @@ -65,6 +65,7 @@ public void testLmCGSparse() { runLmCGTest(true); } + // TODO codex resume 019bb84d-bac6-7fd1-bfb8-a149e715e5b5 private void runLmCGTest(boolean sparse) { Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/LoopSubtractTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/LoopSubtractTest.java new file mode 100644 index 00000000000..c72ef538f73 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/LoopSubtractTest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.io.IOException; + +public class LoopSubtractTest extends AutomatedTestBase { + private static final String TEST_NAME = "LoopSubtract"; + private static final String TEST_DIR = "functions/ooc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + LoopSubtractTest.class.getSimpleName() + "/"; + private static final double eps = 1e-8; + private static final String INPUT_NAME_X = "X"; + private static final String INPUT_NAME_Y = "Y"; + private static final String OUTPUT_NAME = "res"; + + private static final int rows = 2000; + private static final int cols = 800; + private static final double sparsity = 0.8; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME); + addTestConfiguration(TEST_NAME, config); + } + + @Test + public void testLoopSubtractOOC() { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", "-oocStats", + "-args", input(INPUT_NAME_X), input(INPUT_NAME_Y), output(OUTPUT_NAME)}; + + double[][] X_data = getRandomMatrix(rows, cols, 1, 7, sparsity, 7); + double[][] Y_data = 
getRandomMatrix(rows, cols, 0, 1, sparsity, 8); + + MatrixBlock X_mb = DataConverter.convertToMatrixBlock(X_data); + MatrixBlock Y_mb = DataConverter.convertToMatrixBlock(Y_data); + + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME_X), rows, cols, 1000, X_mb.getNonZeros()); + writer.writeMatrixToHDFS(Y_mb, input(INPUT_NAME_Y), rows, cols, 1000, Y_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_X + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, X_mb.getNonZeros()), Types.FileFormat.BINARY); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_Y + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, Y_mb.getNonZeros()), Types.FileFormat.BINARY); + + OOCCacheManager.getCache().updateLimits(60000000, 100000000); + runTest(true, false, null, -1); + + programArgs = new String[] {"-explain", "-stats", + "-args", input(INPUT_NAME_X), input(INPUT_NAME_Y), output(OUTPUT_NAME + "_target")}; + runTest(true, false, null, -1); + + MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), + Types.FileFormat.BINARY, rows, cols, 1000); + MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"), + Types.FileFormat.BINARY, rows, cols, 1000); + TestUtils.compareMatrices(ret1, ret2, eps); + } + catch (IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/MatrixMatrixBinaryMultiplicationTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/MatrixMatrixBinaryMultiplicationTest.java new file mode 100644 index 00000000000..23f74519165 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/MatrixMatrixBinaryMultiplicationTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.io.IOException; + +public class MatrixMatrixBinaryMultiplicationTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "MatrixMatrixMultiplication"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + MatrixMatrixBinaryMultiplicationTest.class.getSimpleName() + "/"; + private final static double eps = 1e-9; + private static final String INPUT_NAME = "X"; + private static final String INPUT_NAME2 = "v"; + private static final String OUTPUT_NAME = "res"; + + private final static int rows = 2000; + private final static int cols_wide = 3000; + private final static int cols_skinny = 500; + + private final 
static double sparsity1 = 0.7; + private final static double sparsity2 = 0.1; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void testMVBinaryMultiplication1() { + runMatrixVectorMultiplicationTest(cols_wide, false); + } + + @Test + public void testMVBinaryMultiplication2() { + runMatrixVectorMultiplicationTest(cols_skinny, false); + } + + private void runMatrixVectorMultiplicationTest(int cols, boolean sparse ) + { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try + { + getAndLoadTestConfiguration(TEST_NAME1); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[]{"-explain", "-stats", "-ooc", + "-args", input(INPUT_NAME), input(INPUT_NAME2), output(OUTPUT_NAME)}; + + // 3. Create a binary matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + + // 1. 
Generate the data in-memory as MatrixBlock objects + double[][] A_data = getRandomMatrix(rows, cols, 0, 1, sparse?sparsity2:sparsity1, 10); + MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data); + writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, cols, 1000, A_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY); + + A_data = null; + A_mb = null; + + double[][] x_data = getRandomMatrix(cols, rows, 0, 1, 1.0, 10); + MatrixBlock x_mb = DataConverter.convertToMatrixBlock(x_data); + writer.writeMatrixToHDFS(x_mb, input(INPUT_NAME2), cols, rows, 1000, x_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME2 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(cols, rows, 1000, x_mb.getNonZeros()), Types.FileFormat.BINARY); + + x_data = null; + x_mb = null; + + boolean exceptionExpected = false; + runTest(true, exceptionExpected, null, -1); + + programArgs = new String[]{"-explain", "-stats", + "-args", input(INPUT_NAME), input(INPUT_NAME2), output(OUTPUT_NAME + "_target")}; + runTest(true, exceptionExpected, null, -1); + + // compare matrices + MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), + Types.FileFormat.BINARY, rows, rows, 1000); + MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"), + Types.FileFormat.BINARY, rows, rows, 1000); + TestUtils.compareMatrices(ret1, ret2, eps); + } + catch (IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } + + private static double[][] readMatrix(String fname, Types.FileFormat fmt, long rows, long cols, int brows, int bcols ) + throws IOException + { + MatrixBlock mb = DataConverter.readMatrixFromHDFS(fname, fmt, rows, cols, brows, bcols); + double[][] C = DataConverter.convertToDoubleMatrix(mb); + return C; + } +} diff --git 
a/src/test/java/org/apache/sysds/test/functions/ooc/TSMMPlusTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/TSMMPlusTest.java new file mode 100644 index 00000000000..fb97855443f --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/TSMMPlusTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.io.IOException; + +public class TSMMPlusTest extends AutomatedTestBase { + private static final String TEST_NAME = "TSMMPlus"; + private static final String TEST_DIR = "functions/ooc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + TSMMPlusTest.class.getSimpleName() + "/"; + private static final double eps = 1e-8; + private static final String INPUT_NAME_X = "X"; + private static final String INPUT_NAME_Y = "Y"; + private static final String OUTPUT_NAME = "res"; + + private static final int rows = 10000; + private static final int cols = 500; + private static final double sparsityX = 0.9; + private static final double sparsityY = 0.1; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME); + addTestConfiguration(TEST_NAME, config); + } + + @Test + public void testTSMMPlusOOC() { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", + "-args", input(INPUT_NAME_X), input(INPUT_NAME_Y), output(OUTPUT_NAME)}; + + double[][] X_data = getRandomMatrix(rows, cols, 1, 7, sparsityX, 7); + double[][] Y_data = getRandomMatrix(rows, cols, 0, 1, sparsityY, 8); + + 
MatrixBlock X_mb = DataConverter.convertToMatrixBlock(X_data); + MatrixBlock Y_mb = DataConverter.convertToMatrixBlock(Y_data); + + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME_X), rows, cols, 1000, X_mb.getNonZeros()); + writer.writeMatrixToHDFS(Y_mb, input(INPUT_NAME_Y), rows, cols, 1000, Y_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_X + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, X_mb.getNonZeros()), Types.FileFormat.BINARY); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_Y + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, Y_mb.getNonZeros()), Types.FileFormat.BINARY); + + runTest(true, false, null, -1); + + programArgs = new String[] {"-explain", "-stats", + "-args", input(INPUT_NAME_X), input(INPUT_NAME_Y), output(OUTPUT_NAME + "_target")}; + runTest(true, false, null, -1); + + MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), + Types.FileFormat.BINARY, rows, cols, 1000); + MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"), + Types.FileFormat.BINARY, rows, cols, 1000); + TestUtils.compareMatrices(ret1, ret2, eps); + } + catch (IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/scripts/functions/ooc/JoinAccessPattern.dml b/src/test/scripts/functions/ooc/JoinAccessPattern.dml new file mode 100644 index 00000000000..36a162e5ee8 --- /dev/null +++ b/src/test/scripts/functions/ooc/JoinAccessPattern.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read the input matrix as a stream +X = read($1); +Y = read($2); + +for (i in 1:15) { + X = X * Y / X + X - Y * X; + X = X * 5; +} + +write(X, $3, format="binary"); diff --git a/src/test/scripts/functions/ooc/LoopSubtract.dml b/src/test/scripts/functions/ooc/LoopSubtract.dml new file mode 100644 index 00000000000..4725b99c429 --- /dev/null +++ b/src/test/scripts/functions/ooc/LoopSubtract.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X = read($1); +Y = read($2); + +for (i in 1:20) { + X = X - Y; +} + +write(X, $3, format="binary"); diff --git a/src/test/scripts/functions/ooc/MatrixMatrixMultiplication.dml b/src/test/scripts/functions/ooc/MatrixMatrixMultiplication.dml new file mode 100644 index 00000000000..327fbc084f9 --- /dev/null +++ b/src/test/scripts/functions/ooc/MatrixMatrixMultiplication.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read both input matrices from command line args +X = read($1); +Y = read($2); + +# Operation under test +res = X %*% Y; + +#print(max(res)) +write(res, $3, format="binary") diff --git a/src/test/scripts/functions/ooc/TSMMPlus.dml b/src/test/scripts/functions/ooc/TSMMPlus.dml new file mode 100644 index 00000000000..1a488d2901e --- /dev/null +++ b/src/test/scripts/functions/ooc/TSMMPlus.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = read($1); +Y = read($2); + +A = X / 7 - 5; +res = A + Y; + +write(res, $3, format="binary");