Vectorized&lt;c10::BFloat16&gt; Class (NEON) — pytorch Architecture
Architecture documentation for the Vectorized&lt;c10::BFloat16&gt; class specialization (which befriends the convert_bfloat16_float function) in vec128_bfloat16_neon.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h lines 141–438
// Vectorized<c10::BFloat16>: 128-bit NEON specialization holding 8 bfloat16
// lanes in an at_bfloat16x8_t register. Since NEON has no general bf16
// arithmetic, most operations widen each 4-lane half to float32x4_t, compute
// in fp32, and narrow back to bf16.
template <>
class Vectorized<c10::BFloat16> : public Vectorized16<
                                      at_bfloat16x8_t,
                                      c10::BFloat16,
                                      BlendBFloat16Regs,
                                      Vectorized<c10::BFloat16>> {
  using Base = Vectorized16<
      at_bfloat16x8_t,
      c10::BFloat16,
      BlendBFloat16Regs,
      Vectorized<c10::BFloat16>>;
  friend Base;
  // The free-function bf16<->f32 converters need direct access to `values`.
  friend std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(
      const Vectorized<c10::BFloat16>& a);
  friend Vectorized<c10::BFloat16> convert_float_bfloat16(
      const Vectorized<float>& a,
      const Vectorized<float>& b);

 private:
  // Scalar fallback for binary elementwise ops: spill both operands to
  // aligned stack buffers, apply f lane by lane, and reload the result.
  Vectorized<c10::BFloat16> map2(
      const Vectorized<c10::BFloat16>& second,
      c10::BFloat16 (*const f)(c10::BFloat16, c10::BFloat16)) const {
    __at_align__ c10::BFloat16 tmp_first[size()];
    __at_align__ c10::BFloat16 tmp_second[size()];
    store(tmp_first); // store this to tmp_first
    second.store(tmp_second);
    for (const auto i : c10::irange(size())) {
      tmp_first[i] = f(tmp_first[i], tmp_second[i]);
    }
    return loadu(tmp_first);
  }

  // Widen 4 bf16 lanes to fp32. Without hardware BF16 support this is a
  // lossless bit shift: a bf16 pattern placed in the top 16 bits of a u32
  // is exactly the corresponding fp32 pattern.
  static float32x4_t convert_f32_bf16(at_bfloat16x4_t bf16) {
#ifdef __ARM_FEATURE_BF16
    return vcvt_f32_bf16(bf16);
#else
    int32x4_t shift = vdupq_n_s32(16);
    return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(bf16), shift));
#endif // __ARM_FEATURE_BF16
  }

  // Narrow 4 fp32 lanes to bf16 using round-to-nearest-even, with NaNs
  // canonicalized to 0x7FC0 (rounding a NaN payload could otherwise
  // produce an infinity).
  static at_bfloat16x4_t convert_bf16_f32(const Vectorized<float>& f32) {
#ifdef __ARM_FEATURE_BF16
    return vcvt_bf16_f32(f32);
#else
    static_assert(std::is_same_v<uint16x4_t, at_bfloat16x4_t>);
    uint32x4_t as_uint32 = vreinterpretq_u32_f32(f32);
    // Round-to-nearest-even bias: 0x7FFF plus the LSB of the truncated
    // mantissa, so exact ties round toward an even bf16 result.
    uint32x4_t rounding_bias = vaddq_u32(
        vandq_u32(vshrq_n_u32(as_uint32, 16), vdupq_n_u32(1)),
        vdupq_n_u32(0x7FFF));
    at_bfloat16x4_t rounded =
        vshrn_n_u32(vaddq_u32(as_uint32, rounding_bias), 16);
    // Select the canonical quiet NaN wherever the input was NaN.
    const auto bf16_nan = vdup_n_u16(0x7FC0);
    return vbsl_u16(
        vmovn_u32(vreinterpretq_u32_f32(f32.isnan())), bf16_nan, rounded);
#endif // __ARM_FEATURE_BF16
  }

  // Apply a unary Vectorized<float> member function to both fp32 halves and
  // re-narrow the results to bf16.
  Vectorized<c10::BFloat16> map_with_vec_float_method(
      Vectorized<float> (Vectorized<float>::*m)() const) const {
    float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values));
    float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values));
    Vectorized<float> mv0 = (Vectorized<float>(v00).*m)();
    Vectorized<float> mv1 = (Vectorized<float>(v01).*m)();
    at_bfloat16x4_t r00 = convert_bf16_f32(mv0);
    at_bfloat16x4_t r01 = convert_bf16_f32(mv1);
    return Vectorized<c10::BFloat16>(at_vcombine_bf16(r00, r01));
  }

  // Binary counterpart of map_with_vec_float_method: widen both operands,
  // apply the fp32 member function half by half, narrow back to bf16.
  Vectorized<c10::BFloat16> map2_with_vec_float_method(
      const Vectorized<c10::BFloat16>& second,
      Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
          const) const {
    float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values));
    float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values));
    float32x4_t second_v00 = convert_f32_bf16(at_vget_low_bf16(second.values));
    float32x4_t second_v01 = convert_f32_bf16(at_vget_high_bf16(second.values));
    Vectorized<float> mv0 = (Vectorized<float>(v00).*m)(second_v00);
    Vectorized<float> mv1 = (Vectorized<float>(v01).*m)(second_v01);
    at_bfloat16x4_t r00 = convert_bf16_f32(mv0);
    at_bfloat16x4_t r01 = convert_bf16_f32(mv1);
    return Vectorized<c10::BFloat16>(at_vcombine_bf16(r00, r01));
  }

  // Like map2_with_vec_float_method, but for predicates whose fp32 result is
  // a per-lane bitmask (all-ones/all-zeros) rather than numeric floats, so
  // the result bits are narrowed directly instead of float-converted.
  Vectorized<c10::BFloat16> map2_bitmask_with_vec_float_method(
      const Vectorized<c10::BFloat16>& second,
      Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
          const) const {
    float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values));
    float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values));
    float32x4_t second_v00 = convert_f32_bf16(at_vget_low_bf16(second.values));
    float32x4_t second_v01 = convert_f32_bf16(at_vget_high_bf16(second.values));
    Vectorized<float> mv0 = (Vectorized<float>(v00).*m)(second_v00);
    Vectorized<float> mv1 = (Vectorized<float>(v01).*m)(second_v01);
    // Assume the operator returns a bitmask, not "real" floats, and
    // just narrow the bits. All-ones is a NaN and will get mangled by
    // conversion!
    at_bfloat16x4_t r00 =
        at_vreinterpret_bf16_u16(vmovn_u32(vreinterpretq_u32_f32(mv0)));
    at_bfloat16x4_t r01 =
        at_vreinterpret_bf16_u16(vmovn_u32(vreinterpretq_u32_f32(mv1)));
    return Vectorized<c10::BFloat16>(at_vcombine_bf16(r00, r01));
  }

 public:
  using Vectorized16::Vectorized16;

  Vectorized() = default;

  // Broadcast constructors: splat one value across all 8 lanes. The bit_cast
  // moves the raw 16-bit pattern into the NEON element type.
  Vectorized(c10::BFloat16 val)
      : Vectorized16(at_vdupq_n_bf16(c10::bit_cast<at_bfloat16_t>(val.x))) {}
  Vectorized(float val) : Vectorized(c10::BFloat16(val)) {}

  // Lane-wise constructor: one value per lane, low to high.
  Vectorized(
      value_type val0,
      value_type val1,
      value_type val2,
      value_type val3,
      value_type val4,
      value_type val5,
      value_type val6,
      value_type val7)
      : Vectorized16(at_bfloat16x8_t{
            c10::bit_cast<at_bfloat16_t>(val0.x),
            c10::bit_cast<at_bfloat16_t>(val1.x),
            c10::bit_cast<at_bfloat16_t>(val2.x),
            c10::bit_cast<at_bfloat16_t>(val3.x),
            c10::bit_cast<at_bfloat16_t>(val4.x),
            c10::bit_cast<at_bfloat16_t>(val5.x),
            c10::bit_cast<at_bfloat16_t>(val6.x),
            c10::bit_cast<at_bfloat16_t>(val7.x)}) {}

  // Per-lane select: where mask bits are set take b, otherwise a.
  static Vectorized<c10::BFloat16> blendv(
      const Vectorized<c10::BFloat16>& a,
      const Vectorized<c10::BFloat16>& b,
      const Vectorized<c10::BFloat16>& mask) {
    // NOTE: blendv has the same problems as it does for Half; see comments in
    // vec128_half_neon.h.
    Vectorized<c10::BFloat16> vec(mask.values);
    vec.values = at_vreinterpretq_bf16_u16(vbslq_u16(
        at_vreinterpretq_u16_bf16(vec.values),
        at_vreinterpretq_u16_bf16(b.values),
        at_vreinterpretq_u16_bf16(a.values)));
    return vec;
  }

  // Take the first `count` lanes from b and the rest from a.
  static Vectorized<c10::BFloat16> set(
      const Vectorized<c10::BFloat16>& a,
      const Vectorized<c10::BFloat16>& b,
      int64_t count = size()) {
    uint16_t pre_mask[size()] = {0};
    for (int i = 0; i < count; i++) {
      pre_mask[i] = 0xFFFF;
    }
    uint16x8_t mask = vld1q_u16(pre_mask);
    Vectorized<c10::BFloat16> vec(at_vreinterpretq_bf16_u16(vbslq_u16(
        mask,
        at_vreinterpretq_u16_bf16(b.values),
        at_vreinterpretq_u16_bf16(a.values))));
    return vec;
  }

  // Load up to `count` bf16 elements from (possibly unaligned) memory;
  // remaining lanes are zero-filled via the staging buffer.
  static Vectorized<c10::BFloat16> loadu(
      const void* ptr,
      int64_t count = size()) {
    if (count == size()) {
      return at_vld1q_bf16(reinterpret_cast<const at_bfloat16_t*>(ptr));
    }
    __at_align__ at_bfloat16_t tmp_values[size()];
    std::memset(tmp_values, 0, sizeof(tmp_values));
    std::memcpy(
        tmp_values,
        reinterpret_cast<const at_bfloat16_t*>(ptr),
        count * sizeof(at_bfloat16_t));
    return at_vld1q_bf16(reinterpret_cast<const at_bfloat16_t*>(tmp_values));
  }

  // Store the first `count` lanes to memory; partial stores go through a
  // stack buffer to avoid writing past the caller's range.
  void store(void* ptr, int64_t count = size()) const {
    if (count == size()) {
      at_vst1q_bf16(reinterpret_cast<at_bfloat16_t*>(ptr), values);
      return;
    } else {
      at_bfloat16_t tmp_values[size()];
      at_vst1q_bf16(reinterpret_cast<at_bfloat16_t*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(at_bfloat16_t));
    }
  }

  // Per-lane NaN test; returns an all-ones (0xFFFF) bitmask lane for NaN
  // inputs and zero otherwise, computed scalar-wise.
  Vectorized<c10::BFloat16> isnan() const {
    // NOTE: we could make this faster by doing vectorized checks of
    // exponent/payload bits.
    __at_align__ c10::BFloat16 tmp[size()];
    __at_align__ c10::BFloat16 res[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i])) {
        std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(c10::BFloat16));
      } else {
        std::memset(static_cast<void*>(&res[i]), 0, sizeof(c10::BFloat16));
      }
    }
    return loadu(res);
  }

  // True if any lane is NaN or +/-infinity.
  bool has_inf_nan() const {
    __at_align__ c10::BFloat16 tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i]) || _isinf(tmp[i])) {
        return true;
      }
    }
    return false;
  }

// Forward a unary op to the fp32 implementation via widen/compute/narrow.
#define DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(name) \
  Vectorized name() const {                                  \
    return map_with_vec_float_method(&Vectorized<float>::name); \
  }

// Forward a comparison to fp32; the result is a bitmask, so narrow bits
// directly (see map2_bitmask_with_vec_float_method).
#define DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(name) \
  Vectorized name(const Vectorized& other) const {               \
    return map2_bitmask_with_vec_float_method(                   \
        other, &Vectorized<float>::name);                        \
  }

  Vectorized frac() const;
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)

#ifdef __ARM_FEATURE_BF16
  // Flip sign bit
  Vectorized<c10::BFloat16> neg() const {
    return vreinterpretq_bf16_s16(vreinterpretq_s16_bf16(values) ^ (-32768));
  }
  // Fast reciprocal is fine because we are truncating results
  Vectorized<c10::BFloat16> reciprocal() const {
    auto x = vcvtq_low_f32_bf16(values);
    auto y = vcvtq_high_f32_bf16(values);
    x = vrecpeq_f32(x);
    y = vrecpeq_f32(y);
    return vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(x), y);
  }
  // Clearing the sign bit
  Vectorized<c10::BFloat16> abs() const {
    return vreinterpretq_bf16_u16(vreinterpretq_u16_bf16(values) & 0x7FFF);
  }
#else
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
#endif

  // These functions are optimized on clang-21+
#if BF16_ARITHMETIC_SUPPORTED() && (__clang_major__ >= 21)
  Vectorized<c10::BFloat16> operator==(
      const Vectorized<c10::BFloat16>& other) const {
    return values == other.values;
  }
  Vectorized<c10::BFloat16> operator!=(
      const Vectorized<c10::BFloat16>& other) const {
    return values != other.values;
  }
  Vectorized<c10::BFloat16> operator<(
      const Vectorized<c10::BFloat16>& other) const {
    return values < other.values;
  }
  Vectorized<c10::BFloat16> operator<=(
      const Vectorized<c10::BFloat16>& other) const {
    return values <= other.values;
  }
  Vectorized<c10::BFloat16> operator>(
      const Vectorized<c10::BFloat16>& other) const {
    return values > other.values;
  }
  Vectorized<c10::BFloat16> operator>=(
      const Vectorized<c10::BFloat16>& other) const {
    return values >= other.values;
  }
#else
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
#endif

#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
// Fix: the comparison macro defined above was never undefined (the previous
// line undefs a name this class never defines), so it leaked past the class
// into the rest of the translation unit. Undefine it explicitly here.
#undef DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD

  Vectorized eq(const Vectorized& other) const;
  Vectorized ne(const Vectorized& other) const;
  Vectorized gt(const Vectorized& other) const;
  Vectorized ge(const Vectorized& other) const;
  Vectorized lt(const Vectorized& other) const;
  Vectorized le(const Vectorized& other) const;
}; // Vectorized<c10::BFloat16>
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free