diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java index ffd7b17a20c..f8dbf1e2c6d 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java @@ -4635,4 +4635,344 @@ private static int eulerTotient(int[] primes, int[] exponents, int[] iExponents, } return count; } + + // TENSOR + /** + * Performs prime in-place tensor transposition for arbitrary permutations. + * + * @param in + * Tensor stored as MatrixBlock + * @param shape + * Original shape informtion of tensor + * @param perm + * Permutation of tensor + */ + // (A) If permutation is split-index reducible -> reduce to 2D and use transposeInPlaceDenseBrenner() + // (B) Else -> decompose perm into adjacent swaps and apply each via 1324 primitive from EITHOT algorithm + // (https://dl.acm.org/doi/10.1145/3711871) + // with Brenner's method instead of Catanzaro's algorithm for generalizability to arbitrary large dimensions + // ------------------------------------------------------------ + public static boolean transposeInPlaceTensor(MatrixBlock in, int[] shape, int[] perm) { + final int rank = shape.length; + + // final shape + final int[] finalShape = new int[rank]; + for (int i = 0; i < rank; i++) + finalShape[i] = shape[perm[i]]; + + // Identity perm -> metadata only + boolean identity = true; + for (int i = 0; i < rank; i++) { + if (perm[i] != i) { + identity = false; + break; + } + } + if (identity) { + restoreMetadata(in, finalShape); + return true; + } + + // (A) Split-index reducible + int splitIdx = findSplitIndex(perm); + if (splitIdx != -1) { + int newRows = 1; + for (int i = 0; i < splitIdx; i++) + newRows *= shape[perm[i]]; + + long newColsL = 1; + for (int i = splitIdx; i < rank; i++) + newColsL *= shape[perm[i]]; + int newCols = (int) newColsL; + + try { + + in.setNumRows(newCols); + in.setNumColumns(newRows); + + transposeInPlaceDenseBrenner(in, 1); + } finally { + restoreMetadata(in, finalShape); + } + return true; + } + + // (B) General path: usage of 1324 primitv + + final double[] tensor = in.getDenseBlockValues(); + int[] curShape = Arrays.copyOf(shape, rank); + in.getDenseBlock().setDims(curShape); + + // plan adjacent swaps to realize perm + int[] swaps = permutationToAdjacentSwaps(rank, perm); + int swapCount = swaps[0]; + + for (int s = 1; s <= swapCount; s++) { + int k = swaps[s]; + reshape1324(in, tensor, curShape, k); + int tmp = curShape[k]; + curShape[k] = curShape[k + 1]; + curShape[k + 1] = tmp; + + in.getDenseBlock().setDims(curShape); + } + + restoreMetadata(in, finalShape); + return true; + } + + /** + * Applies a single adjacent-axis swap (k <-> k+1) to a dense tensor **in-place** by reducing it to a rank-4 view + * and calling primitive 1324. + * + * @param in + * MatrixBlock holding the dense tensor buffer (metadata is temporarily modified) + * @param a + * backing dense buffer (row-major, contiguous) + * @param curShape + * current logical tensor shape/order before applying this adjacent swap + * @param k + * adjacent axis index to swap (swaps axis k with axis k+1) + */ + private static void reshape1324(MatrixBlock in, double[] a, int[] curShape, int k) { + int lastDim = curShape.length; + + int left = prod(curShape, 0, k); + int A = curShape[k]; + int B = curShape[k + 1]; + int right = prod(curShape, k + 2, lastDim); + + // metadata-only reshape to 4D + in.getDenseBlock().setDims(new int[] { left, A, B, right }); + + // in-place 1324 on that view + prim1324(a, 0, left, A, B, right); + + // caller restores dims to curShape after it swaps curShape[k],curShape[k+1] + } + + private static int prod(int[] shape, int start, int end) { + long p = 1; + for (int i = start; i < end; i++) { + p *= shape[i]; + } + return (int) p; + } + + /** + * Decomposes an arbitrary permutation into a sequence of adjacent swaps. + * + * @param rank + * tensor rank + * @param perm + * target permutation (maps output axis i to input axis perm[i]) + * + * @return swap plan array + */ + private static int[] permutationToAdjacentSwaps(int rank, int[] perm) { + int[] order = new int[rank]; + // original order of permutation + for (int i = 0; i < rank; i++) + order[i] = i; + + int maxSwaps = rank * (rank - 1) / 2; + int[] out = new int[maxSwaps + 1]; // stores swap order + int cnt = 0; // number of swaps needed + + for (int targetPos = 0; targetPos < rank; targetPos++) { + int wantedAxis = perm[targetPos]; + + // index of dimension in current permutation + int curPos = -1; + for (int p = targetPos; p < rank; p++) { + if (order[p] == wantedAxis) { + curPos = p; + break; + } + } + if (curPos < 0) + throw new IllegalArgumentException("Invalid perm"); + + while (curPos > targetPos) { + int t = order[curPos - 1]; + order[curPos - 1] = order[curPos]; + order[curPos] = t; + + out[++cnt] = curPos - 1; + curPos--; + } + } + + out[0] = cnt; + return out; + } + + /** + * Primitive {@code 1324}: swaps dimensions 2 and 3 while keeping dimensions 1 and 4 fixed: + * + * @param a + * dense buffer + * @param offset + * base offset into {a} (usually 0) + * @param d1 + * first dimension (number of slices) + * @param d2 + * second dimension (matrix rows) + * @param d3 + * third dimension (matrix cols) + * @param d4 + * fourth dimension (block length per matrix cell) + */ + private static void prim1324(double[] a, int offset, int d1, int d2, int d3, int d4) { + for (int i1 = 0; i1 < d1; i1++) { + int slice = d2 * d3 * d4; + int base = offset + i1 * slice; + transposeBlocksInPlace(a, base, d2, d3, d4); + } + } + + /** + * In-place transpose of an matrix where each element is a contiguous block of length {blk}. Performs a cycle-walk + * permutation over block positions induced by transpose. For each unvisited start position, we rotate blocks along + * its cycle using one temporary block buffer. + * + * @param tensor + * backing dense buffer + * @param base + * offset of the (m*n*blk) region + * @param d2 + * number of rows in the block-matrix + * @param d3 + * number of columns in the block-matrix + * @param blk + * block length (number of doubles per cell), d4 + */ + private static void transposeBlocksInPlace(double[] tensor, int base, int d2, int d3, int blk) { + int numBlocks = d2 * d3; + boolean[] visited = new boolean[numBlocks]; + double[] tmp = new double[blk]; // buffer for one block + + for (int start = 0; start < numBlocks; start++) { + if (visited[start]) + continue; + + int next = transposeBlockIndex(start, d2, d3); + + // no movement + if (next == start) { + visited[start] = true; + continue; + } + + // save start + System.arraycopy(tensor, base + start * blk, tmp, 0, blk); + + // cycle-following + int cur = start; + while (true) { + visited[cur] = true; + int prev = inverseTransposeBlockIndex(cur, d2, d3); + if (prev == start) + break; + + System.arraycopy(tensor, base + prev * blk, tensor, base + cur * blk, blk); + cur = prev; + } + + System.arraycopy(tmp, 0, tensor, base + cur * blk, blk); + visited[cur] = true; + } + } + + /** + * Finds the target index of a current block + * + * @param block_idx + * index of block + * @param m + * number of rows + * @param n + * number of columns + * + * @return new block idx + */ + private static int transposeBlockIndex(int block_idx, int m, int n) { + int i = block_idx / n; + int j = block_idx % n; + return j * m + i; + } + + /** + * Finds the idx of the element which moves to the current block index duing permutation + * + * @param curr_block_idx + * index of current block + * @param m + * number of rows + * @param n + * number of columns + * + * @return new block idx + */ + private static int inverseTransposeBlockIndex(int curr_block_idx, int m, int n) { + int i = curr_block_idx % m; + int j = curr_block_idx / m; + return i * n + j; + } + + /** + * Finds a split index for a tensor permutation that allows reduction of the permutation to a 2D matrix transpose. + * @param perm permutation of tensor axes + * @return split index {i} if reducible, otherwise {-1} + */ + public static int findSplitIndex(int[] perm) { + if (perm == null || perm.length < 2) + return -1; + int n = perm.length; + + for (int i = 1; i < n; i++) { + boolean contiguousFirst = isContiguousRange(perm, 0, i); + boolean contiguousSecond = isContiguousRange(perm, i, n); + + if (contiguousFirst && contiguousSecond) { + if (isSorted(perm, 0, i) && isSorted(perm, i, n)) { + return i; + } + } + } + return -1; + } + + private static boolean isSorted(int[] perm, int start, int end) { + for (int i = start; i < end - 1; i++) + if (perm[i] > perm[i + 1]) + return false; + return true; + } + + private static boolean isContiguousRange(int[] perm, int start, int end) { + int min = perm[start], max = perm[start]; + for (int i = start + 1; i < end; i++) { + if (perm[i] < min) + min = perm[i]; + if (perm[i] > max) + max = perm[i]; + } + return (max - min + 1) == (end - start); + } + + /** + * Restores SystemDS matrix/tensor metadata after an in-place tensor permutation. + * @param in matrix/tensor block whose metadata is restored + * @param finalShape final tensor shape after permutation + */ + private static void restoreMetadata(MatrixBlock in, int[] finalShape) { + in.setNumRows(finalShape[0]); + long totalRemaining = 1; + for (int i = 1; i < finalShape.length; i++) + totalRemaining *= finalShape[i]; + in.setNumColumns((int) totalRemaining); + if (in.getDenseBlock() != null) + in.getDenseBlock().setDims(finalShape); + } } diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index e470dd82539..c944279c15d 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -3990,4 +3990,21 @@ public int numBlocks() { return 2; } } + + public static void compareTensorValues(MatrixBlock actual, MatrixBlock expected, double epsilon) { + double[] a = actual.getDenseBlockValues(); + double[] e = expected.getDenseBlockValues(); + + if (a.length != e.length) { + throw new AssertionError("Length mismatch: expected " + e.length + " values, but got " + a.length); + } + + for (int i = 0; i < e.length; i++) { + double diff = Math.abs(e[i] - a[i]); + if (diff > epsilon) { + throw new AssertionError( + "Mismatch at linear index " + i + ": expected=" + e[i] + ", actual=" + a[i] + ", diff=" + diff); + } + } + } } diff --git a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java index 7b575cf37cc..86f72a43138 100644 --- a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java +++ b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java @@ -19,16 +19,19 @@ package org.apache.sysds.test.component.matrix.libMatrixReorg; +import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.test.TestUtils; import org.junit.Test; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; public class TransposeInPlaceBrennerTest { + @Test public void transposeInPlaceDenseBrennerOnePrime() { // 3*4-1 = 11 @@ -87,7 +90,7 @@ public void transposeInPlaceDenseBrennerSevenPrimesThreeExpos() { public void transposeInPlaceDenseBrennerEightPrimes() { // 347*27953-1 = 2*3*5*7*11*13*17*19 testTransposeInPlaceDense(347, 27953, 0.86); - } + } @Test public void transposeInPlaceDenseBrennerNinePrimes() { @@ -102,9 +105,598 @@ public void transposeInPlaceDenseBrennerNinePrimes() { private void testTransposeInPlaceDense(int rows, int cols, double sparsity) { MatrixBlock X = MatrixBlock.randOperations(rows, cols, sparsity); MatrixBlock tX = LibMatrixReorg.transpose(X); + LibMatrixReorg.transposeInPlaceDenseBrenner(X, 1); + + + TestUtils.compareMatrices(X, tX, 0); + } + + // Tests for tensor permutations + @Test + public void testTensorPermuteSplit_3D() { + int[] shape = {50,2,10}; + int[] perm = {1,2,0}; + testTransposeInPlaceTensor(shape, perm); + } + + @Test + public void testTensorPermuteSplit_8D() { + int[] shape = {3,2,1,3,2,3,1,2}; + int[] perm = {4,5,6,7,0,1,2,3}; + testTransposeInPlaceTensor(shape, perm); + } + + @Test + public void testTensorPermuteSplit_4D() { + int[] shape = {3,2,5,3}; + int[] perm = {2,3,0,1}; + testTransposeInPlaceTensor(shape, perm); + } + + @Test + public void testTensorPermuteSplit_2D_21() { + int[] shape = {4, 10}; + int[] perm = {1,0}; + testTransposeInPlaceTensor(shape, perm); + } + + //Test for primitives + @Test + public void testTensorPermute_3D_213() { + int[] shape = {4, 2, 7}; + int[] perm = {1,0,2}; + testTransposeInPlaceTensor(shape, perm); + } + + @Test + public void testTensorPermute_3D_132() { + int[] shape = {3, 4, 2}; + int[] perm = {0, 2, 1}; + testTransposeInPlaceTensor(shape, perm); + } + + @Test + public void testTensorPermute_4D_1324() { + int[] shape = {3, 2, 2, 3}; + int[] perm = {0, 2, 1, 3}; + testTransposeInPlaceTensor(shape, perm); + } + + @Test + public void testTensorPermute_4Db_1324() { + int[] shape = {3, 4, 5, 6}; + int[] perm = {0, 2, 1, 3}; + testTransposeInPlaceTensor(shape, perm); + } + + @Test + public void testTensorPermuteSplit_5D() { + int[] shape = {2, 3, 4, 5, 6}; + int[] perm = {2, 3, 4, 0, 1}; + testTransposeInPlaceTensor(shape, perm); + } + + @Test + public void testTensorPermuteSplit_6D() { + int[] shape = {4, 3, 2, 5, 8, 2}; + int[] perm = {3, 4, 5, 0, 1, 2}; + testTransposeInPlaceTensor(shape, perm); + } + + @Test + public void testTensorPermuteSplit_5D_MiddleSwap() { + int[] shape = {2, 6, 2, 4, 5}; + int[] perm = {4, 3, 2, 1, 0}; + testTransposeInPlaceTensor(shape, perm); +} + + @Test + public void testTensorPermute_5D_MiddleSwap_Complex() { + int[] shape = {2, 2, 3, 4, 2}; + int[] perm = {0, 2, 1, 3, 4}; + testTransposeInPlaceTensor(shape, perm); +} + +@Test + public void testTensorPermute_7Db() { + int[] shape = {20, 30, 15, 5, 2, 5, 2}; + int[] perm = {0, 6, 1, 5, 4, 2, 3}; + testTransposeInPlaceTensor(shape, perm); +} + +@Test + public void testTensorPermute_7D() { + int[] shape = {2, 3, 5, 5, 2, 3, 2}; + int[] perm = {0, 6, 1, 5, 4, 2, 3}; + testTransposeInPlaceTensor(shape, perm); +} + + @Test + public void testTensorPermuteSplit_Max2() { + int[] shape = {1000, 300, 100}; + int[] perm = {2, 0, 1}; + testTransposeInPlaceTensor(shape, perm); + } + + @Test + public void testTensorPermuteSplit_Max3() { + int[] shape = {8000, 4000, 2}; + int[] perm = {2, 0, 1}; + testTransposeInPlaceTensor(shape, perm); +} + + @Test + public void testTensorPermute_3D_allCases() { + int[] shape = {2, 3, 2}; + int[] perm1 = {0,1, 2}; + int[] perm2 = {0,2, 1}; + int[] perm3 = {1,0,2}; + int[] perm4 = {1,2,0}; + int[] perm5 = {2,0,1}; + int[] perm6 = {2,1,0}; + testTransposeInPlaceTensor(shape, perm1); + testTransposeInPlaceTensor(shape, perm2); + testTransposeInPlaceTensor(shape, perm3); + testTransposeInPlaceTensor(shape, perm4); + testTransposeInPlaceTensor(shape, perm5); + testTransposeInPlaceTensor(shape, perm6); + +} + @Test + public void testTensorPermuteSplit_4Db_213() { + int[] shape = {2, 3, 4}; + int[] perm = {1, 0, 2}; + testTransposeInPlaceTensor(shape, perm); +} + @Test + public void testTensorPermuteSplit_4Db_132() { + int[] shape = {2, 3, 4}; + int[] perm = {0, 2, 1}; + testTransposeInPlaceTensor(shape, perm); +} + + // Edge case tests + + // 1. Square matrices + @Test + public void transposeInPlaceDenseSquare5x5() { + testTransposeInPlaceDense(5, 5, 0.8); + } + + @Test + public void transposeInPlaceDenseSquare100x100() { + testTransposeInPlaceDense(100, 100, 0.7); + } + + @Test + public void testTensorPermute_3D_SquareDims() { + int[] shape = {4, 4, 4}; + int[] perm = {2, 0, 1}; + testTransposeInPlaceTensor(shape, perm); + } + + // 2. Vectors (1×N and N×1) + @Test + public void transposeInPlaceDenseRowVector() { + testTransposeInPlaceDense(1, 50, 0.9); + } + + @Test + public void transposeInPlaceDenseColVector() { + testTransposeInPlaceDense(50, 1, 0.9); + } + + @Test + public void testTensorPermute_VectorLike() { + int[] shape = {1, 20}; + int[] perm = {1, 0}; + testTransposeInPlaceTensor(shape, perm); + } + + // 3. Single element + @Test + public void transposeInPlaceDenseSingleElement() { + testTransposeInPlaceDense(1, 1, 1.0); + } + + @Test + public void testTensorPermute_SingleElement() { + int[] shape = {1, 1, 1}; + int[] perm = {2, 1, 0}; + testTransposeInPlaceTensor(shape, perm); + } + + // 4. Prime dimensions + @Test + public void transposeInPlaceDensePrime7x11() { + testTransposeInPlaceDense(7, 11, 0.75); + } + + @Test + public void transposeInPlaceDensePrime13x17() { + testTransposeInPlaceDense(13, 17, 0.82); + } + + @Test + public void testTensorPermute_AllPrimeDims() { + int[] shape = {3, 5, 7}; + int[] perm = {1, 2, 0}; + testTransposeInPlaceTensor(shape, perm); + } + + // 5. Power of 2 dimensions (common in computing, just to be sure) + @Test + public void transposeInPlaceDensePowerOf2_64x128() { + testTransposeInPlaceDense(64, 128, 0.6); + } + + @Test + public void transposeInPlaceDensePowerOf2_32x64() { + testTransposeInPlaceDense(32, 64, 0.85); + } + @Test + public void testTensorPermute_PowerOf2Dims() { + int[] shape = {8, 16, 4}; + int[] perm = {2, 1, 0}; + testTransposeInPlaceTensor(shape, perm); + } + + // 7. Consecutive transpose (should return to original) + @Test + public void transposeInPlaceDenseConsecutiveTwice() { + MatrixBlock X = MatrixBlock.randOperations(7, 13, 0.75); + MatrixBlock original = new MatrixBlock(X); + + LibMatrixReorg.transposeInPlaceDenseBrenner(X, 1); LibMatrixReorg.transposeInPlaceDenseBrenner(X, 1); + + TestUtils.compareMatrices(X, original, 0); + } - TestUtils.compareMatrices(X, tX, 0); + @Test + public void testTensorPermute_ConsecutiveTwice() { + int[] shape = {3, 4, 5}; + int[] perm = {1, 2, 0}; + + MatrixBlock matrix = createDenseTensor(shape); + MatrixBlock original = new MatrixBlock(matrix); + + LibMatrixReorg.transposeInPlaceTensor(matrix, shape, perm); + // Apply reverse permutation to get back + int[] reversePerm = new int[perm.length]; + for (int i = 0; i < perm.length; i++) { + reversePerm[perm[i]] = i; + } + int[] newShape = new int[shape.length]; + for (int i = 0; i < perm.length; i++) { + newShape[i] = shape[perm[i]]; + } + LibMatrixReorg.transposeInPlaceTensor(matrix, newShape, reversePerm); + + TestUtils.compareMatrices(matrix, original, 0); + } + + // 8.tensors with dimension=1 + @Test + public void testTensorPermute_WithDim1_case1() { + int[] shape = {1, 5, 3}; + int[] perm = {2, 0, 1}; + testTransposeInPlaceTensor(shape, perm); + } + + @Test + public void testTensorPermute_WithDim1_case2() { + int[] shape = {4, 1, 2, 1}; + int[] perm = {2, 3, 0, 1}; + testTransposeInPlaceTensor(shape, perm); + } + + @Test + public void testTensorPermute_WithDim1_case3() { + int[] shape = {3, 1, 4}; + int[] perm = {1, 2, 0}; + testTransposeInPlaceTensor(shape, perm); + } + + // 9. Invalid permutations (negative tests) + // NOTE: more detailed error handling can be added in the future, currently these are just checking for exceptions + @Test + public void testTensorPermute_InvalidPerm_OutOfRange() { + int[] shape = {2, 3, 4}; + int[] perm = {0, 1, 3}; // 3 is out of range for 3D tensor + + MatrixBlock matrix = createDenseTensor(shape); + + assertThrows(Exception.class, + () -> LibMatrixReorg.transposeInPlaceTensor(matrix, shape, perm)); + } + + + @Test + public void testTensorPermute_InvalidPerm_WrongLength() { + int[] shape = {2, 3, 4}; + int[] perm = {0, 1}; // only 2 elements but 3d tensor + + MatrixBlock matrix = createDenseTensor(shape); + + assertThrows(Exception.class, + () -> LibMatrixReorg.transposeInPlaceTensor(matrix, shape, perm)); + } + + @Test + public void testTensorPermute_InvalidPerm_Negative() { + int[] shape = {2, 3, 4}; + int[] perm = {-1, 1, 2}; // negtive index + + MatrixBlock matrix = createDenseTensor(shape); + + assertThrows(Exception.class, + () -> LibMatrixReorg.transposeInPlaceTensor(matrix, shape, perm)); + } + + // 10. Null/empty inputs + @Test + public void testTensorPermute_EmptyShape() { + int[] shape = {}; + int[] perm = {}; + + assertThrows(Exception.class, + () -> createDenseTensor(shape)); + } + + @Test + public void testTensorPermute_NullMatrix() { + int[] shape = {2, 3}; + int[] perm = {1, 0}; + + assertThrows(Exception.class, + () -> LibMatrixReorg.transposeInPlaceTensor(null, shape, perm)); + } + + + //Filling matrices + private static MatrixBlock createDenseTensor(int[] shape) { + long size = 1; + for (int s : shape) + size *= s; + + if (size > Integer.MAX_VALUE) + throw new IllegalArgumentException("Tensor too large: " + size); + + int rows = shape[0]; + long colsL = size / rows; + int cols = (int) colsL; + + MatrixBlock matrix = new MatrixBlock(rows, cols, false); + matrix.allocateDenseBlock(); + + double[] values = matrix.getDenseBlockValues(); + for (int i = 0; i < values.length; i++) + values[i] = i; + + if (matrix.getDenseBlock() != null) + matrix.getDenseBlock().setDims(shape); + + return matrix; +} + + + private void testTransposeInPlaceTensor(int[] shape, int[] perm) { + + MatrixBlock matrix =createDenseTensor(shape); + MatrixBlock expected = permutationOutOfPlace(matrix, shape, perm); + LibMatrixReorg.transposeInPlaceTensor(matrix, shape, perm); + TestUtils.compareMatrices(matrix, expected, 0); + TestUtils.compareTensorValues(matrix, expected, 0); + + } + + //returns the expected matrix (found out-of-place) for comparision + private MatrixBlock permutationOutOfPlace(MatrixBlock in, int[] shape, int[] perm) { + int[] newShape = new int[shape.length]; + for(int i=0; i= 0; i--) { + originalCoords[i] = index % shape[i]; + index /= shape[i]; + } + } + + private int getIndex(int[] coords, int[] shape) { + int index = 0; + int multiplier = 1; + for (int i = shape.length - 1; i >= 0; i--) { + index += coords[i] * multiplier; + multiplier *= shape[i]; + } + return index; + } + + //Test for correct meta-data after permutation + @Test + public void testTensorPermuteSplitShape_6D() { + int[] shape = {2, 3, 4, 5, 6, 7}; + int[] perm = {1, 2, 3, 4, 5, 0}; + + long size = 1; + for(int s : shape) { + size *= s; } + + MatrixBlock X = new MatrixBlock((int) size, 1, false); + X.allocateDenseBlock(); + LibMatrixReorg.transposeInPlaceTensor(X, shape, perm); + testTransposeInPlaceTensorShape(X, shape, perm); +} + + @Test + public void testTensorPermuteSplitShape_6D_Max() { + int[] shape = {1000, 500, 20, 2, 2, 2}; + int[] perm = {1, 2, 3, 4, 5, 0}; + + long size = 1; + for(int s : shape) { + size *= s; + } + + MatrixBlock X = new MatrixBlock((int) size, 1, false); + X.allocateDenseBlock(); + LibMatrixReorg.transposeInPlaceTensor(X, shape, perm); + testTransposeInPlaceTensorShape(X, shape, perm); +} + + @Test + public void testTensorPermuteSplitShape_4D() { + int[] shape = {100, 22, 70, 90}; + int[] perm = {1, 2, 3, 0}; + + long size = 1; + for(int s : shape) { + size *= s; + } + + MatrixBlock X = new MatrixBlock((int) size, 1, false); + X.allocateDenseBlock(); + LibMatrixReorg.transposeInPlaceTensor(X, shape, perm); + testTransposeInPlaceTensorShape(X, shape, perm); +} + + + @Test + public void testTensorPermuteSplitShape_8D() { + int[] shape = {10, 22, 7, 9, 30, 6, 4, 7}; + int[] perm = { 3, 4, 5, 6, 7, 0, 1, 2}; + + long size = 1; + for(int s : shape) { + size *= s; + } + + MatrixBlock X = new MatrixBlock((int) size, 1, false); + X.allocateDenseBlock(); + LibMatrixReorg.transposeInPlaceTensor(X, shape, perm); + testTransposeInPlaceTensorShape(X, shape, perm); +} + +@Test + public void testTensorPermuteSplitShape_5D_middle() { + int[] shape = {10, 8, 5, 4, 2}; + int[] perm = {0, 2, 1, 3, 4}; + + long size = 1; + for(int s : shape) { + size *= s; + } + + MatrixBlock X = new MatrixBlock((int) size, 1, false); + X.allocateDenseBlock(); + LibMatrixReorg.transposeInPlaceTensor(X, shape, perm); + testTransposeInPlaceTensorShape(X, shape, perm); +} + + @Test + public void testTensorPermuteSplitShape_5D() { + int[] shape = {2,3,5,2,8}; + int[] perm = {3,4,0,1,2}; + + long size = 1; + for(int s : shape) { + size *= s; } + + MatrixBlock X = new MatrixBlock((int) size, 1, false); + X.allocateDenseBlock(); + LibMatrixReorg.transposeInPlaceTensor(X, shape, perm); + testTransposeInPlaceTensorShape(X, shape, perm); + } + + @Test + public void testTensorPermuteSplitShape_2D() { + int[] shape = {2,3}; + int[] perm = {1,0}; + + long size = 1; + for(int s : shape) { + size *= s; + } + + MatrixBlock X = new MatrixBlock((int) size, 1, false); + X.allocateDenseBlock(); + LibMatrixReorg.transposeInPlaceTensor(X, shape, perm); + testTransposeInPlaceTensorShape(X, shape, perm); +} + private void testTransposeInPlaceTensorShape(MatrixBlock transposed_X, int[] originalShape, int[] perm){ + int[] expectedShape = new int[originalShape.length]; + for(int i = 0; i < perm.length; i++) { + expectedShape[i] = originalShape[perm[i]]; + } + int expectedRows = expectedShape[0]; + long expectedCols = 1; + for(int i = 1; i < expectedShape.length; i++) { + expectedCols *= expectedShape[i]; + } + + // MatrixBlock shape-match + assertEquals("Matrix Rows mismatch", expectedRows, transposed_X.getNumRows()); + assertEquals("Matrix Columns mismatch", (int)expectedCols, transposed_X.getNumColumns()); + + // DenseBlock shape-match + int[] transposedShape = new int[originalShape.length]; + DenseBlock dense_X = transposed_X.getDenseBlock(); + if(dense_X != null){ + //Comparison of each dimension + for (int i = 0; i < expectedShape.length; i++) { + transposedShape[i] = dense_X.getDim(i); + assertEquals("Dimension " + i + " mismatch", expectedShape[i], dense_X.getDim(i)); + } + int currentExpectedSuffix = expectedShape[expectedShape.length - 1]; + //Comparison of suffixes + for (int i = expectedShape.length - 1; i >= 1; i--) { + assertEquals("Suffix product at dim " + i + " mismatch", currentExpectedSuffix, dense_X.getCumODims(i - 1)); + if(i > 1) { + currentExpectedSuffix *= expectedShape[i - 1]; + } + } + } + + +} }