Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
340 changes: 340 additions & 0 deletions src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
Original file line number Diff line number Diff line change
Expand Up @@ -4635,4 +4635,344 @@ private static int eulerTotient(int[] primes, int[] exponents, int[] iExponents,
}
return count;
}

// TENSOR
/**
* Performs prime in-place tensor transposition for arbitrary permutations.
*
* @param in
* Tensor stored as MatrixBlock
* @param shape
* Original shape informtion of tensor
* @param perm
* Permutation of tensor
*/
// (A) If permutation is split-index reducible -> reduce to 2D and use transposeInPlaceDenseBrenner()
// (B) Else -> decompose perm into adjacent swaps and apply each via 1324 primitive from EITHOT algorithm
// (https://dl.acm.org/doi/10.1145/3711871)
// with Brenner's method instead of Catanzaro's algorithm for generalizability to arbitrary large dimensions
// ------------------------------------------------------------
public static boolean transposeInPlaceTensor(MatrixBlock in, int[] shape, int[] perm) {
final int rank = shape.length;

// final shape
final int[] finalShape = new int[rank];
for (int i = 0; i < rank; i++)
finalShape[i] = shape[perm[i]];

// Identity perm -> metadata only
boolean identity = true;
for (int i = 0; i < rank; i++) {
if (perm[i] != i) {
identity = false;
break;
}
}
if (identity) {
restoreMetadata(in, finalShape);
return true;
}

// (A) Split-index reducible
int splitIdx = findSplitIndex(perm);
if (splitIdx != -1) {
int newRows = 1;
for (int i = 0; i < splitIdx; i++)
newRows *= shape[perm[i]];

long newColsL = 1;
for (int i = splitIdx; i < rank; i++)
newColsL *= shape[perm[i]];
int newCols = (int) newColsL;

try {

in.setNumRows(newCols);
in.setNumColumns(newRows);

transposeInPlaceDenseBrenner(in, 1);
} finally {
restoreMetadata(in, finalShape);
}
return true;
}

// (B) General path: usage of 1324 primitv

final double[] tensor = in.getDenseBlockValues();
int[] curShape = Arrays.copyOf(shape, rank);
in.getDenseBlock().setDims(curShape);

// plan adjacent swaps to realize perm
int[] swaps = permutationToAdjacentSwaps(rank, perm);
int swapCount = swaps[0];

for (int s = 1; s <= swapCount; s++) {
int k = swaps[s];
reshape1324(in, tensor, curShape, k);
int tmp = curShape[k];
curShape[k] = curShape[k + 1];
curShape[k + 1] = tmp;

in.getDenseBlock().setDims(curShape);
}

restoreMetadata(in, finalShape);
return true;
}

/**
* Applies a single adjacent-axis swap (k <-> k+1) to a dense tensor **in-place** by reducing it to a rank-4 view
* and calling primitive 1324.
*
* @param in
* MatrixBlock holding the dense tensor buffer (metadata is temporarily modified)
* @param a
* backing dense buffer (row-major, contiguous)
* @param curShape
* current logical tensor shape/order before applying this adjacent swap
* @param k
* adjacent axis index to swap (swaps axis k with axis k+1)
*/
private static void reshape1324(MatrixBlock in, double[] a, int[] curShape, int k) {
int lastDim = curShape.length;

int left = prod(curShape, 0, k);
int A = curShape[k];
int B = curShape[k + 1];
int right = prod(curShape, k + 2, lastDim);

// metadata-only reshape to 4D
in.getDenseBlock().setDims(new int[] { left, A, B, right });

// in-place 1324 on that view
prim1324(a, 0, left, A, B, right);

// caller restores dims to curShape after it swaps curShape[k],curShape[k+1]
}

private static int prod(int[] shape, int start, int end) {
long p = 1;
for (int i = start; i < end; i++) {
p *= shape[i];
}
return (int) p;
}

/**
* Decomposes an arbitrary permutation into a sequence of adjacent swaps.
*
* @param rank
* tensor rank
* @param perm
* target permutation (maps output axis i to input axis perm[i])
*
* @return swap plan array
*/
private static int[] permutationToAdjacentSwaps(int rank, int[] perm) {
int[] order = new int[rank];
// original order of permutation
for (int i = 0; i < rank; i++)
order[i] = i;

int maxSwaps = rank * (rank - 1) / 2;
int[] out = new int[maxSwaps + 1]; // stores swap order
int cnt = 0; // number of swaps needed

for (int targetPos = 0; targetPos < rank; targetPos++) {
int wantedAxis = perm[targetPos];

// index of dimension in current permutation
int curPos = -1;
for (int p = targetPos; p < rank; p++) {
if (order[p] == wantedAxis) {
curPos = p;
break;
}
}
if (curPos < 0)
throw new IllegalArgumentException("Invalid perm");

while (curPos > targetPos) {
int t = order[curPos - 1];
order[curPos - 1] = order[curPos];
order[curPos] = t;

out[++cnt] = curPos - 1;
curPos--;
}
}

out[0] = cnt;
return out;
}

/**
* Primitive {@code 1324}: swaps dimensions 2 and 3 while keeping dimensions 1 and 4 fixed:
*
* @param a
* dense buffer
* @param offset
* base offset into {a} (usually 0)
* @param d1
* first dimension (number of slices)
* @param d2
* second dimension (matrix rows)
* @param d3
* third dimension (matrix cols)
* @param d4
* fourth dimension (block length per matrix cell)
*/
private static void prim1324(double[] a, int offset, int d1, int d2, int d3, int d4) {
for (int i1 = 0; i1 < d1; i1++) {
int slice = d2 * d3 * d4;
int base = offset + i1 * slice;
transposeBlocksInPlace(a, base, d2, d3, d4);
}
}

/**
* In-place transpose of an matrix where each element is a contiguous block of length {blk}. Performs a cycle-walk
* permutation over block positions induced by transpose. For each unvisited start position, we rotate blocks along
* its cycle using one temporary block buffer.
*
* @param tensor
* backing dense buffer
* @param base
* offset of the (m*n*blk) region
* @param d2
* number of rows in the block-matrix
* @param d3
* number of columns in the block-matrix
* @param blk
* block length (number of doubles per cell), d4
*/
private static void transposeBlocksInPlace(double[] tensor, int base, int d2, int d3, int blk) {
int numBlocks = d2 * d3;
boolean[] visited = new boolean[numBlocks];
double[] tmp = new double[blk]; // buffer for one block

for (int start = 0; start < numBlocks; start++) {
if (visited[start])
continue;

int next = transposeBlockIndex(start, d2, d3);

// no movement
if (next == start) {
visited[start] = true;
continue;
}

// save start
System.arraycopy(tensor, base + start * blk, tmp, 0, blk);

// cycle-following
int cur = start;
while (true) {
visited[cur] = true;
int prev = inverseTransposeBlockIndex(cur, d2, d3);
if (prev == start)
break;

System.arraycopy(tensor, base + prev * blk, tensor, base + cur * blk, blk);
cur = prev;
}

System.arraycopy(tmp, 0, tensor, base + cur * blk, blk);
visited[cur] = true;
}
}

/**
* Finds the target index of a current block
*
* @param block_idx
* index of block
* @param m
* number of rows
* @param n
* number of columns
*
* @return new block idx
*/
private static int transposeBlockIndex(int block_idx, int m, int n) {
int i = block_idx / n;
int j = block_idx % n;
return j * m + i;
}

/**
* Finds the idx of the element which moves to the current block index duing permutation
*
* @param curr_block_idx
* index of current block
* @param m
* number of rows
* @param n
* number of columns
*
* @return new block idx
*/
private static int inverseTransposeBlockIndex(int curr_block_idx, int m, int n) {
int i = curr_block_idx % m;
int j = curr_block_idx / m;
return i * n + j;
}

/**
* Finds a split index for a tensor permutation that allows reduction of the permutation to a 2D matrix transpose.
* @param perm permutation of tensor axes
* @return split index {i} if reducible, otherwise {-1}
*/
public static int findSplitIndex(int[] perm) {
if (perm == null || perm.length < 2)
return -1;
int n = perm.length;

for (int i = 1; i < n; i++) {
boolean contiguousFirst = isContiguousRange(perm, 0, i);
boolean contiguousSecond = isContiguousRange(perm, i, n);

if (contiguousFirst && contiguousSecond) {
if (isSorted(perm, 0, i) && isSorted(perm, i, n)) {
return i;
}
}
}
return -1;
}

private static boolean isSorted(int[] perm, int start, int end) {
for (int i = start; i < end - 1; i++)
if (perm[i] > perm[i + 1])
return false;
return true;
}

private static boolean isContiguousRange(int[] perm, int start, int end) {
int min = perm[start], max = perm[start];
for (int i = start + 1; i < end; i++) {
if (perm[i] < min)
min = perm[i];
if (perm[i] > max)
max = perm[i];
}
return (max - min + 1) == (end - start);
}

/**
* Restores SystemDS matrix/tensor metadata after an in-place tensor permutation.
* @param in matrix/tensor block whose metadata is restored
* @param finalShape final tensor shape after permutation
*/
private static void restoreMetadata(MatrixBlock in, int[] finalShape) {
in.setNumRows(finalShape[0]);
long totalRemaining = 1;
for (int i = 1; i < finalShape.length; i++)
totalRemaining *= finalShape[i];
in.setNumColumns((int) totalRemaining);
if (in.getDenseBlock() != null)
in.getDenseBlock().setDims(finalShape);
}
}
17 changes: 17 additions & 0 deletions src/test/java/org/apache/sysds/test/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -3990,4 +3990,21 @@ public int numBlocks() {
return 2;
}
}

public static void compareTensorValues(MatrixBlock actual, MatrixBlock expected, double epsilon) {
double[] a = actual.getDenseBlockValues();
double[] e = expected.getDenseBlockValues();

if (a.length != e.length) {
throw new AssertionError("Length mismatch: expected " + e.length + " values, but got " + a.length);
}

for (int i = 0; i < e.length; i++) {
double diff = Math.abs(e[i] - a[i]);
if (diff > epsilon) {
throw new AssertionError(
"Mismatch at linear index " + i + ": expected=" + e[i] + ", actual=" + a[i] + ", diff=" + diff);
}
}
}
}
Loading