Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,8 @@ replay_pid*
build

.flattened-pom.xml

.classpath
.factorypath
.project
.settings
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import org.asdfformat.asdf.ndarray.DataType;
import org.asdfformat.asdf.ndarray.DataTypeFamilyType;
import org.asdfformat.asdf.ndarray.DataTypes;
import org.asdfformat.asdf.ndarray.impl.Float16Utils;
import org.asdfformat.asdf.node.AsdfNode;
import org.asdfformat.asdf.util.AsdfCharsets;

Expand All @@ -25,6 +26,7 @@ public class InlineBlockV1_0_0 implements Block {
SIMPLE_VALUE_WRITERS.put(DataTypes.INT32, (b, n) -> b.putInt(n.asInt()));
SIMPLE_VALUE_WRITERS.put(DataTypes.UINT64, (b, n) -> b.putLong(n.asLong()));
SIMPLE_VALUE_WRITERS.put(DataTypes.INT64, (b, n) -> b.putLong(n.asLong()));
SIMPLE_VALUE_WRITERS.put(DataTypes.FLOAT16, (b, n) -> b.putShort(Float16Utils.floatToFloat16(n.asFloat())));
SIMPLE_VALUE_WRITERS.put(DataTypes.FLOAT32, (b, n) -> b.putFloat(n.asFloat()));
SIMPLE_VALUE_WRITERS.put(DataTypes.FLOAT64, (b, n) -> b.putDouble(n.asDouble()));
SIMPLE_VALUE_WRITERS.put(DataTypes.BOOL8, (b, n) -> b.put((byte)(n.asBoolean() ? 1 : 0)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ public class DataTypes {
new HashSet<>(Arrays.asList(Long.TYPE, BigInteger.class))
);

public static final DataType FLOAT16 = new SimpleDataTypeImpl(
DataTypeFamilyType.FLOAT,
2,
new HashSet<>(Arrays.asList(Float.TYPE, Double.TYPE, BigDecimal.class))
);

public static final DataType FLOAT32 = new SimpleDataTypeImpl(
DataTypeFamilyType.FLOAT,
4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import java.util.function.BiConsumer;
import java.util.function.Function;

import static org.asdfformat.asdf.ndarray.impl.Float16Utils.float16ToFloat;

public class BigDecimalNdArrayImpl extends NdArrayBase<BigDecimalNdArray> implements BigDecimalNdArray {
public BigDecimalNdArrayImpl(final DataType dataType, final int[] shape, final ByteOrder byteOrder, final int[] strides, final int offset, final Block block) {
super(dataType, shape, byteOrder, strides, offset, block);
Expand Down Expand Up @@ -46,6 +48,8 @@ public BigDecimal get(final int... indices) {
} else {
throw new RuntimeException("Unhandled datatype: " + dataType);
}
} else if (dataType.equals(DataTypes.FLOAT16)) {
return BigDecimal.valueOf(float16ToFloat(byteBuffer.getShort()));
} else if (dataType.equals(DataTypes.FLOAT32)) {
return BigDecimal.valueOf(byteBuffer.getFloat());
} else if (dataType.equals(DataTypes.FLOAT64)) {
Expand Down Expand Up @@ -84,6 +88,12 @@ public <ARRAY> ARRAY toArray(final ARRAY array) {
arr[index + i] = valueCreator.apply(buffer);
}
};
} else if (dataType.equals(DataTypes.FLOAT16)) {
setter = (byteBuffer, arr, index, length) -> {
for (int i = 0; i < length; i++) {
arr[index + i] = BigDecimal.valueOf(float16ToFloat(byteBuffer.getShort()));
}
};
} else if (dataType.equals(DataTypes.FLOAT32)) {
setter = (byteBuffer, arr, index, length) -> {
for (int i = 0; i < length; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import static org.asdfformat.asdf.ndarray.impl.Float16Utils.float16ToFloat;

public class DoubleNdArrayImpl extends NdArrayBase<DoubleNdArray> implements DoubleNdArray {
public DoubleNdArrayImpl(final DataType dataType, final int[] shape, final ByteOrder byteOrder, final int[] strides, final int offset, final Block block) {
super(dataType, shape, byteOrder, strides, offset, block);
Expand All @@ -34,6 +36,8 @@ public double get(int... indices) {
return byteBuffer.getDouble();
} else if (dataType.equals(DataTypes.FLOAT32)) {
return byteBuffer.getFloat();
} else if (dataType.equals(DataTypes.FLOAT16)) {
return float16ToFloat(byteBuffer.getShort());
} else {
throw new RuntimeException("Unhandled datatype: " + dataType);
}
Expand All @@ -52,6 +56,12 @@ public <ARRAY> ARRAY toArray(final ARRAY array) {
arr[index + i] = floatArr[i];
}
};
} else if (dataType.equals(DataTypes.FLOAT16)) {
setter = (byteBuffer, arr, index, length) -> {
for (int i = 0; i < length; i++) {
arr[index + i] = float16ToFloat(byteBuffer.getShort());
}
};
} else {
throw new RuntimeException("Unhandled datatype: " + dataType);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package org.asdfformat.asdf.ndarray.impl;

public class Float16Utils {
public static float float16ToFloat(final short bits) {
final int halfBits = bits & 0xFFFF;
final int sign = (halfBits >>> 15) & 0x1;
final int exponent = (halfBits >>> 10) & 0x1F;
final int mantissa = halfBits & 0x3FF;

final int floatBits;
if (exponent == 0) {
if (mantissa == 0) {
floatBits = sign << 31;
} else {
// Subnormal: normalize by shifting mantissa until the leading 1 is in bit 10
int m = mantissa;
int e = -14 + 127;
while ((m & 0x400) == 0) {
m <<= 1;
e--;
}
m &= 0x3FF;
floatBits = (sign << 31) | (e << 23) | (m << 13);
}
} else if (exponent == 31) {
// Inf or NaN: rebased exponent to float32's 255
floatBits = (sign << 31) | (0xFF << 23) | (mantissa << 13);
} else {
// Normal: rebase exponent from bias-15 to bias-127
final int floatExponent = exponent - 15 + 127;
floatBits = (sign << 31) | (floatExponent << 23) | (mantissa << 13);
}

return Float.intBitsToFloat(floatBits);
}

public static short floatToFloat16(final float value) {
final int floatBits = Float.floatToIntBits(value);
final int sign = (floatBits >>> 31) & 0x1;
final int exponent = (floatBits >>> 23) & 0xFF;
final int mantissa = floatBits & 0x7FFFFF;

final int halfBits;
if (exponent == 0) {
halfBits = sign << 15;
} else if (exponent == 0xFF) {
if (mantissa == 0) {
halfBits = (sign << 15) | (0x1F << 10);
} else {
final int halfMantissa = mantissa >>> 13;
halfBits = (sign << 15) | (0x1F << 10) | (halfMantissa != 0 ? halfMantissa : 0x1);
}
} else {
final int halfExponent = exponent - 127 + 15;
if (halfExponent >= 31) {
halfBits = (sign << 15) | (0x1F << 10);
} else if (halfExponent <= 0) {
if (halfExponent < -10) {
halfBits = sign << 15;
} else {
final int shift = 1 - halfExponent + 13;
final int m = (mantissa | 0x800000) >>> shift;
final int roundBit = ((mantissa | 0x800000) >>> (shift - 1)) & 0x1;
final int stickyBit = ((mantissa | 0x800000) & ((1 << (shift - 1)) - 1)) != 0 ? 1 : 0;
final int rounded = m + (roundBit & (stickyBit | (m & 1)));
halfBits = (sign << 15) | rounded;
}
} else {
final int truncated = (mantissa >>> 13);
final int roundBit = (mantissa >>> 12) & 0x1;
final int stickyBit = (mantissa & 0xFFF) != 0 ? 1 : 0;
final int rounded = truncated + (roundBit & (stickyBit | (truncated & 1)));
halfBits = (sign << 15) | ((halfExponent << 10) + rounded);
}
}

return (short) halfBits;
}

private Float16Utils() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import static org.asdfformat.asdf.ndarray.impl.Float16Utils.float16ToFloat;

public class FloatNdArrayImpl extends NdArrayBase<FloatNdArray> implements FloatNdArray {
public FloatNdArrayImpl(final DataType dataType, final int[] shape, final ByteOrder byteOrder, final int[] strides, final int offset, final Block block) {
super(dataType, shape, byteOrder, strides, offset, block);
Expand All @@ -32,14 +34,27 @@ public float get(int... indices) {
final ByteBuffer byteBuffer = getByteBufferAt(indices);
if (dataType.equals(DataTypes.FLOAT32)) {
return byteBuffer.getFloat();
} else if (dataType.equals(DataTypes.FLOAT16)) {
return float16ToFloat(byteBuffer.getShort());
} else {
throw new RuntimeException("Unhandled datatype: " + dataType);
}
}

@Override
public <ARRAY> ARRAY toArray(final ARRAY array) {
final ArraySetter<float[]> setter = (byteBuffer, arr, index, length) -> byteBuffer.asFloatBuffer().get(arr, index, length);
final ArraySetter<float[]> setter;
if (dataType.equals(DataTypes.FLOAT32)) {
setter = (byteBuffer, arr, index, length) -> byteBuffer.asFloatBuffer().get(arr, index, length);
} else if (dataType.equals(DataTypes.FLOAT16)) {
setter = (byteBuffer, arr, index, length) -> {
for (int i = 0; i < length; i++) {
arr[index + i] = float16ToFloat(byteBuffer.getShort());
}
};
} else {
throw new RuntimeException("Unhandled datatype: " + dataType);
}
return toArray(array, Float.TYPE, setter);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class NdArrayHandler_1_x implements NdArrayHandler {
SIMPLE_DATA_TYPES.put("uint8", DataTypes.UINT8);
SIMPLE_DATA_TYPES.put("uint16", DataTypes.UINT16);
SIMPLE_DATA_TYPES.put("uint32", DataTypes.UINT32);
SIMPLE_DATA_TYPES.put("float16", DataTypes.FLOAT16);
SIMPLE_DATA_TYPES.put("float32", DataTypes.FLOAT32);
SIMPLE_DATA_TYPES.put("float64", DataTypes.FLOAT64);
SIMPLE_DATA_TYPES.put("complex64", DataTypes.COMPLEX64);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package org.asdfformat.asdf.ndarray;

import org.asdfformat.asdf.Asdf;
import org.asdfformat.asdf.AsdfFile;
import org.asdfformat.asdf.standard.AsdfStandardType;
import org.asdfformat.asdf.testing.CoreReferenceFileType;
import org.asdfformat.asdf.testing.ReferenceFileUtils;
import org.asdfformat.asdf.util.Version;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import java.io.IOException;
import java.math.BigDecimal;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.stream.Stream;

import static org.asdfformat.asdf.testing.TestCategories.REFERENCE_TESTS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

@Tag(REFERENCE_TESTS)
public class NdArrayFloat16ReferenceTest {
private static final Version FLOAT16_MIN_VERSION = new Version(1, 6, 0);

private static final CoreReferenceFileType[] FILE_TYPES = {
CoreReferenceFileType.NDARRAY_FLOAT16_1D_BLOCK_BIG,
CoreReferenceFileType.NDARRAY_FLOAT16_1D_BLOCK_LITTLE,
CoreReferenceFileType.NDARRAY_FLOAT16_1D_INLINE,
};

private static Stream<Arguments> float16Args() {
return Arrays.stream(FILE_TYPES)
.flatMap(fileType -> Arrays.stream(AsdfStandardType.values())
.filter(std -> std.getVersion().compareTo(FLOAT16_MIN_VERSION) >= 0)
.map(std -> Arguments.of(fileType, std)));
}

@ParameterizedTest
@MethodSource("float16Args")
public void testFloat1d(final CoreReferenceFileType coreTestFileType, final AsdfStandardType asdfStandardType) throws IOException {
final Path path = ReferenceFileUtils.getPath(coreTestFileType, asdfStandardType.getVersion());

try (final AsdfFile asdfFile = Asdf.open(path)) {
final FloatNdArray floatNdArray = asdfFile.getTree().get("arr").asNdArray().asFloatNdArray();

assertEquals(-65504.0f, floatNdArray.get(0));
assertEquals(65504.0f, floatNdArray.get(1));
assertEquals(5.9604645E-8f, floatNdArray.get(2));
assertEquals(0.0f, floatNdArray.get(3));
assertTrue(Float.isNaN(floatNdArray.get(4)));
assertEquals(Float.POSITIVE_INFINITY, floatNdArray.get(5));
assertEquals(Float.NEGATIVE_INFINITY, floatNdArray.get(6));
assertEquals(3.140625f, floatNdArray.get(7));
assertEquals(-3.140625f, floatNdArray.get(8));

final float[] arr = floatNdArray.toArray(new float[9]);
assertEquals(-65504.0f, arr[0]);
assertEquals(65504.0f, arr[1]);
assertEquals(5.9604645E-8f, arr[2]);
assertEquals(0.0f, arr[3]);
assertTrue(Float.isNaN(arr[4]));
assertEquals(Float.POSITIVE_INFINITY, arr[5]);
assertEquals(Float.NEGATIVE_INFINITY, arr[6]);
assertEquals(3.140625f, arr[7]);
assertEquals(-3.140625f, arr[8]);
}
}

@ParameterizedTest
@MethodSource("float16Args")
public void testDouble1d(final CoreReferenceFileType coreTestFileType, final AsdfStandardType asdfStandardType) throws IOException {
final Path path = ReferenceFileUtils.getPath(coreTestFileType, asdfStandardType.getVersion());

try (final AsdfFile asdfFile = Asdf.open(path)) {
final DoubleNdArray doubleNdArray = asdfFile.getTree().get("arr").asNdArray().asDoubleNdArray();

assertEquals(-65504.0, doubleNdArray.get(0));
assertEquals(65504.0, doubleNdArray.get(1));
assertEquals(5.960464477539063E-8, doubleNdArray.get(2));
assertEquals(0.0, doubleNdArray.get(3));
assertTrue(Double.isNaN(doubleNdArray.get(4)));
assertEquals(Double.POSITIVE_INFINITY, doubleNdArray.get(5));
assertEquals(Double.NEGATIVE_INFINITY, doubleNdArray.get(6));
assertEquals(3.140625, doubleNdArray.get(7));
assertEquals(-3.140625, doubleNdArray.get(8));

final double[] arr = doubleNdArray.toArray(new double[9]);
assertEquals(-65504.0, arr[0]);
assertEquals(65504.0, arr[1]);
assertEquals(5.960464477539063E-8, arr[2]);
assertEquals(0.0, arr[3]);
assertTrue(Double.isNaN(arr[4]));
assertEquals(Double.POSITIVE_INFINITY, arr[5]);
assertEquals(Double.NEGATIVE_INFINITY, arr[6]);
assertEquals(3.140625, arr[7]);
assertEquals(-3.140625, arr[8]);
}
}

@ParameterizedTest
@MethodSource("float16Args")
public void testBigDecimal1d(final CoreReferenceFileType coreTestFileType, final AsdfStandardType asdfStandardType) throws IOException {
final Path path = ReferenceFileUtils.getPath(coreTestFileType, asdfStandardType.getVersion());

try (final AsdfFile asdfFile = Asdf.open(path)) {
final BigDecimalNdArray bigDecimalNdArray = asdfFile.getTree().get("arr").asNdArray().asBigDecimalNdArray();

assertEquals(BigDecimal.valueOf(-65504.0), bigDecimalNdArray.get(0));
assertEquals(BigDecimal.valueOf(65504.0), bigDecimalNdArray.get(1));
assertEquals(BigDecimal.valueOf(5.960464477539063E-8), bigDecimalNdArray.get(2));
assertEquals(BigDecimal.valueOf(0.0), bigDecimalNdArray.get(3));
assertEquals(BigDecimal.valueOf(3.140625), bigDecimalNdArray.get(7));
assertEquals(BigDecimal.valueOf(-3.140625), bigDecimalNdArray.get(8));

}
}
}
Loading
Loading