Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 28 additions & 17 deletions bf16.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
#ifndef BF16_h
#define BF16_h

/* The fixed-width integer types are needed by C and C++ consumers alike,
 * so <stdint.h> is included unconditionally; only the linkage
 * specification below is C++-specific. */
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

/**
* Converts brain16 to float32.
*
Expand Down Expand Up @@ -42,29 +49,33 @@
* @see IEEE 754-2008
*/
static inline float from_brain(uint16_t h) {
  /* Reinterpret through a union: the 16 input bits become the upper half
   * of the float32 bit pattern and the lower 16 bits are left as zero. */
  union {
    float f;
    uint32_t i;
  } u;
  u.i = (uint32_t)h << 16;
  return u.f;
}

/**
 * Converts float32 to brain16.
 *
 * The result is the upper 16 bits of the float32 bit pattern, rounded to
 * nearest with ties going to the even 16-bit value. NaN inputs keep their
 * upper 16 bits and additionally get bit 6 set, which forces a quiet NaN.
 *
 * @param s float32 value to convert
 * @return brain16 bit pattern
 */
static inline uint16_t to_brain(float s) {
  union {
    float f;
    uint32_t i;
  } u;
  u.f = s;
  if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
    /* Truncate instead of rounding so the payload cannot carry into the
     * exponent/sign bits; OR-ing in 64 forces the NaN to be quiet. */
    return (uint16_t)((u.i >> 16) | 64);
  }
  /* Adding 0x7fff plus the lowest kept bit rounds half-way cases to even
   * before the low 16 bits are dropped. */
  return (uint16_t)((u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16);
}

#ifdef __cplusplus
}
#endif

#endif
110 changes: 58 additions & 52 deletions fp16.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,83 +8,89 @@
*
* The original code is MIT licensed. */

#include <stdint.h>
#include <math.h>
#include <stdint.h>

/* Reinterprets the 32-bit pattern `w` as a float (no numeric conversion). */
static inline float fp32_from_bits(uint32_t w) {
  union {
    uint32_t as_bits;
    float as_value;
  } fp32;
  fp32.as_bits = w;
  return fp32.as_value;
}

/* Returns the raw 32-bit pattern of the float `f` (no numeric conversion). */
static inline uint32_t fp32_to_bits(float f) {
  union {
    float as_value;
    uint32_t as_bits;
  } fp32;
  fp32.as_value = f;
  return fp32.as_bits;
}

/*
 * Converts the 16-bit half-precision bit pattern `h` to a float.
 *
 * Two candidate results are computed without branching on the input class,
 * one valid when the half exponent field is non-zero and one valid when it
 * is zero, and the correct one is selected at the end.
 */
float from_half(uint16_t h) {
  /* Place the half in the upper 16 bits so its sign lines up with the
   * float sign bit; doubling then shifts the sign out of the word. */
  const uint32_t w = (uint32_t)h << 16;
  const uint32_t sign = w & UINT32_C(0x80000000);
  const uint32_t two_w = w + w;

  /* Non-zero exponent field: move exponent and mantissa into float
   * position, add 0xE0 to the exponent, then multiply by 2^-112. An
   * all-ones half exponent thereby becomes an all-ones float exponent. */
  const uint32_t exp_offset = UINT32_C(0xE0) << 23;
#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \
    (defined(__GNUC__) && !defined(__STRICT_ANSI__))
  const float exp_scale = 0x1.0p-112f;
#else
  /* Same value spelled as a bit pattern for compilers without hexadecimal
   * floating-point literals. */
  const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
#endif
  const float normalized_value =
      fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;

  /* Zero exponent field: put the mantissa bits into a float whose exponent
   * field is 126 (i.e. 0.5 plus the mantissa) and subtract 0.5 again. */
  const uint32_t magic_mask = UINT32_C(126) << 23;
  const float magic_bias = 0.5f;
  const float denormalized_value =
      fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;

  /* two_w < 2^27 holds exactly when the five exponent bits are all zero. */
  const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
  const uint32_t magnitude = (two_w < denormalized_cutoff)
                                 ? fp32_to_bits(denormalized_value)
                                 : fp32_to_bits(normalized_value);
  return fp32_from_bits(sign | magnitude);
}

/*
 * Converts the float `f` to a 16-bit half-precision bit pattern.
 *
 * NaN inputs produce 0x7E00 combined with the input's sign bit. All other
 * inputs are converted by adding a bias value so that the exponent and
 * mantissa of the result can be read from fixed bit positions of the sum.
 */
uint16_t to_half(float f) {
#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \
    (defined(__GNUC__) && !defined(__STRICT_ANSI__))
  const float scale_to_inf = 0x1.0p+112f;
  const float scale_to_zero = 0x1.0p-110f;
#else
  /* Same values spelled as bit patterns for compilers without hexadecimal
   * floating-point literals. */
  const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
  const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
#endif
  /* The two multiplications are kept separate: the first one overflows to
   * infinity for magnitudes that are too large before the second scales
   * the value back down. */
  float base = (fabsf(f) * scale_to_inf) * scale_to_zero;

  const uint32_t w = fp32_to_bits(f);
  const uint32_t shl1_w = w + w; /* float bits with the sign shifted out */
  const uint32_t sign = w & UINT32_C(0x80000000);

  /* Exponent field of the input, clamped from below to 0x71000000. */
  uint32_t bias = shl1_w & UINT32_C(0xFF000000);
  if (bias < UINT32_C(0x71000000)) {
    bias = UINT32_C(0x71000000);
  }

  /* After adding the bias, the half exponent sits in bits 23..27 of the
   * sum and the half mantissa in its low bits. */
  base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
  const uint32_t bits = fp32_to_bits(base);
  const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
  const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
  const uint32_t nonsign = exp_bits + mantissa_bits;

  /* shl1_w > 0xFF000000: exponent all ones and mantissa non-zero => NaN. */
  const uint32_t magnitude =
      (shl1_w > UINT32_C(0xFF000000)) ? UINT32_C(0x7E00) : nonsign;
  return (uint16_t)((sign >> 16) | magnitude);
}

#ifdef TEST_MAIN
#include <stdio.h>
/* Smoke test: round-trips one value through to_half()/from_half() and
 * prints the original next to the round-tripped result. */
int main(void) {
  const float f = 1.2345f;
  const uint16_t half = to_half(f);
  const float f2 = from_half(half);
  printf("%f %f\n", f, f2);
  return 0;
}
#endif
11 changes: 11 additions & 0 deletions fp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@

#ifndef FP16_h
#define FP16_h

#ifdef __cplusplus
#include <cstdint>
extern "C" {
#endif

/* NOTE(review): uint16_t must be visible to C callers as well; the visible
 * part of this header only includes <cstdint> for C++ -- confirm that
 * <stdint.h> is included earlier in the file. */

/* Converts the 16-bit half-precision bit pattern `h` to a float. */
float from_half(uint16_t h);
/* Converts the float `f` to a 16-bit half-precision bit pattern. */
uint16_t to_half(float f);

#ifdef __cplusplus
}
#endif

#endif
Loading