Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 1 addition & 37 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,43 +162,7 @@ uint16_t f8_e4m3_to_f16(uint8_t f8) {
}

/// Converts an 8-bit E5M2 float (1 sign bit, 5 exponent bits, 2 mantissa bits)
/// into the bit pattern of a 16-bit half-precision float.
///
/// The two formats share the same sign position, the same 5-bit exponent
/// field, and the same exponent bias (15); f16 simply carries 8 additional
/// low-order mantissa bits. Placing the input byte in the upper half of the
/// result and zero-filling the lower half is therefore an exact conversion
/// for every input class:
///   - zero:       sign bit only, everything else stays 0
///   - subnormal:  exponent field stays 0, mantissa lands in the top 2 bits
///   - normal:     exponent field is copied unchanged (no re-biasing needed)
///   - Inf / NaN:  exponent field stays all-ones, mantissa payload preserved
///
/// @param fp8  Raw E5M2 bits.
/// @return     Raw f16 bits representing the same value.
uint16_t f8_e5m2_to_f16(uint8_t fp8) {
    // The inner cast widens before shifting; the outer cast makes the
    // narrowing from the promoted int result back to 16 bits explicit.
    return static_cast<uint16_t>(static_cast<uint16_t>(fp8) << 8);
}

void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
Expand Down
Loading