Home / Class / convert_bfloat16_float Class — pytorch Architecture

convert_bfloat16_float Class — pytorch Architecture

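Architecture documentation for `convert_bfloat16_float` in vec128_bfloat16_neon.h from the pytorch codebase. `convert_bfloat16_float` is a free function declared as a friend of the `Vectorized<c10::BFloat16>` class shown below.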

Entity Profile

Source Code

aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h lines 141–438

// NEON specialization of Vectorized for c10::BFloat16.
//
// Holds 8 bf16 lanes in a single at_bfloat16x8_t register. Most math is done
// by widening each 4-lane half to float32x4_t, invoking the corresponding
// Vectorized<float> method, and narrowing the result back to bf16 (see the
// map*_with_vec_float_method helpers below).
template <>
class Vectorized<c10::BFloat16> : public Vectorized16<
                                      at_bfloat16x8_t,
                                      c10::BFloat16,
                                      BlendBFloat16Regs,
                                      Vectorized<c10::BFloat16>> {
  using Base = Vectorized16<
      at_bfloat16x8_t,
      c10::BFloat16,
      BlendBFloat16Regs,
      Vectorized<c10::BFloat16>>;
  friend Base;
  // Free conversion functions are friends so they can reach the private
  // members below (presumably the convert_* primitives; bodies not shown).
  friend std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(
      const Vectorized<c10::BFloat16>& a);
  friend Vectorized<c10::BFloat16> convert_float_bfloat16(
      const Vectorized<float>& a,
      const Vectorized<float>& b);

 private:
  // Applies the scalar binary function `f` lane-by-lane to (*this, second)
  // by round-tripping both operands through memory.
  Vectorized<c10::BFloat16> map2(
      const Vectorized<c10::BFloat16>& second,
      c10::BFloat16 (*const f)(c10::BFloat16, c10::BFloat16)) const {
    __at_align__ c10::BFloat16 tmp_first[size()];
    __at_align__ c10::BFloat16 tmp_second[size()];
    store(tmp_first); // store this to tmp_first
    second.store(tmp_second);
    for (const auto i : c10::irange(size())) {
      tmp_first[i] = f(tmp_first[i], tmp_second[i]);
    }
    return loadu(tmp_first);
  }

  // Widens 4 bf16 lanes to 4 fp32 lanes. A bf16 value is the upper 16 bits
  // of the equivalent fp32 bit pattern, so without hardware BF16 support we
  // zero-extend each lane to 32 bits and shift it left by 16.
  static float32x4_t convert_f32_bf16(at_bfloat16x4_t bf16) {
#ifdef __ARM_FEATURE_BF16
    return vcvt_f32_bf16(bf16);
#else
    int32x4_t shift = vdupq_n_s32(16);
    return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(bf16), shift));
#endif // __ARM_FEATURE_BF16
  }

  // Narrows 4 fp32 lanes to 4 bf16 lanes. The software fallback rounds to
  // nearest, ties to even: it adds 0x7FFF plus the low bit of the would-be
  // result before dropping the low 16 bits. Every NaN lane is replaced by the
  // quiet NaN 0x7FC0 so the rounding add cannot turn a NaN into an Inf.
  static at_bfloat16x4_t convert_bf16_f32(const Vectorized<float>& f32) {
#ifdef __ARM_FEATURE_BF16
    return vcvt_bf16_f32(f32);
#else
    static_assert(std::is_same_v<uint16x4_t, at_bfloat16x4_t>);
    uint32x4_t as_uint32 = vreinterpretq_u32_f32(f32);
    uint32x4_t rounding_bias = vaddq_u32(
        vandq_u32(vshrq_n_u32(as_uint32, 16), vdupq_n_u32(1)),
        vdupq_n_u32(0x7FFF));
    at_bfloat16x4_t rounded =
        vshrn_n_u32(vaddq_u32(as_uint32, rounding_bias), 16);
    const auto bf16_nan = vdup_n_u16(0x7FC0);
    return vbsl_u16(
        vmovn_u32(vreinterpretq_u32_f32(f32.isnan())), bf16_nan, rounded);
#endif // __ARM_FEATURE_BF16
  }

  // Evaluates a unary Vectorized<float> member function `m` on both fp32
  // halves of this vector and narrows the results back to bf16.
  Vectorized<c10::BFloat16> map_with_vec_float_method(
      Vectorized<float> (Vectorized<float>::*m)() const) const {
    float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values));
    float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values));
    Vectorized<float> mv0 = (Vectorized<float>(v00).*m)();
    Vectorized<float> mv1 = (Vectorized<float>(v01).*m)();
    at_bfloat16x4_t r00 = convert_bf16_f32(mv0);
    at_bfloat16x4_t r01 = convert_bf16_f32(mv1);
    return Vectorized<c10::BFloat16>(at_vcombine_bf16(r00, r01));
  }

  // Binary counterpart of map_with_vec_float_method: widens both operands,
  // applies `m` per half, and narrows the numeric results back to bf16.
  Vectorized<c10::BFloat16> map2_with_vec_float_method(
      const Vectorized<c10::BFloat16>& second,
      Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
          const) const {
    float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values));
    float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values));
    float32x4_t second_v00 = convert_f32_bf16(at_vget_low_bf16(second.values));
    float32x4_t second_v01 = convert_f32_bf16(at_vget_high_bf16(second.values));
    Vectorized<float> mv0 = (Vectorized<float>(v00).*m)(second_v00);
    Vectorized<float> mv1 = (Vectorized<float>(v01).*m)(second_v01);
    at_bfloat16x4_t r00 = convert_bf16_f32(mv0);
    at_bfloat16x4_t r01 = convert_bf16_f32(mv1);
    return Vectorized<c10::BFloat16>(at_vcombine_bf16(r00, r01));
  }

  // Like map2_with_vec_float_method, but for methods (e.g. comparisons) whose
  // fp32 result is a per-lane bitmask. The mask bits are narrowed directly
  // instead of being numerically converted.
  Vectorized<c10::BFloat16> map2_bitmask_with_vec_float_method(
      const Vectorized<c10::BFloat16>& second,
      Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
          const) const {
    float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values));
    float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values));
    float32x4_t second_v00 = convert_f32_bf16(at_vget_low_bf16(second.values));
    float32x4_t second_v01 = convert_f32_bf16(at_vget_high_bf16(second.values));
    Vectorized<float> mv0 = (Vectorized<float>(v00).*m)(second_v00);
    Vectorized<float> mv1 = (Vectorized<float>(v01).*m)(second_v01);
    // Assume the operator returns a bitmask, not "real" floats, and
    // just narrow the bits. All-ones is a NaN and will get mangled by
    // conversion!
    at_bfloat16x4_t r00 =
        at_vreinterpret_bf16_u16(vmovn_u32(vreinterpretq_u32_f32(mv0)));
    at_bfloat16x4_t r01 =
        at_vreinterpret_bf16_u16(vmovn_u32(vreinterpretq_u32_f32(mv1)));
    return Vectorized<c10::BFloat16>(at_vcombine_bf16(r00, r01));
  }

 public:
  using Vectorized16::Vectorized16;

  Vectorized() = default;

  // Broadcasts a single bf16 value to all lanes.
  Vectorized(c10::BFloat16 val)
      : Vectorized16(at_vdupq_n_bf16(c10::bit_cast<at_bfloat16_t>(val.x))) {}
  // Broadcasts `val` after converting it to bf16 via c10::BFloat16(float).
  Vectorized(float val) : Vectorized(c10::BFloat16(val)) {}
  // Builds the vector from 8 explicit lane values (lane 0 first).
  Vectorized(
      value_type val0,
      value_type val1,
      value_type val2,
      value_type val3,
      value_type val4,
      value_type val5,
      value_type val6,
      value_type val7)
      : Vectorized16(at_bfloat16x8_t{
            c10::bit_cast<at_bfloat16_t>(val0.x),
            c10::bit_cast<at_bfloat16_t>(val1.x),
            c10::bit_cast<at_bfloat16_t>(val2.x),
            c10::bit_cast<at_bfloat16_t>(val3.x),
            c10::bit_cast<at_bfloat16_t>(val4.x),
            c10::bit_cast<at_bfloat16_t>(val5.x),
            c10::bit_cast<at_bfloat16_t>(val6.x),
            c10::bit_cast<at_bfloat16_t>(val7.x)}) {}

  // Bitwise select: each result bit comes from `b` where the corresponding
  // bit of `mask` is 1, and from `a` where it is 0.
  static Vectorized<c10::BFloat16> blendv(
      const Vectorized<c10::BFloat16>& a,
      const Vectorized<c10::BFloat16>& b,
      const Vectorized<c10::BFloat16>& mask) {
    // NOTE: blendv has the same problems as it does for Half; see comments in
    // vec128_half_neon.h.
    Vectorized<c10::BFloat16> vec(mask.values);
    vec.values = at_vreinterpretq_bf16_u16(vbslq_u16(
        at_vreinterpretq_u16_bf16(vec.values),
        at_vreinterpretq_u16_bf16(b.values),
        at_vreinterpretq_u16_bf16(a.values)));
    return vec;
  }
  // Returns a vector whose first `count` lanes come from `b` and whose
  // remaining lanes come from `a`. A count <= 0 yields `a`; a count >= size()
  // yields `b`.
  static Vectorized<c10::BFloat16> set(
      const Vectorized<c10::BFloat16>& a,
      const Vectorized<c10::BFloat16>& b,
      int64_t count = size()) {
    // Iterate over the fixed lane count rather than `count`, so an oversized
    // count can never write past the end of pre_mask.
    uint16_t pre_mask[size()] = {0};
    for (const auto i : c10::irange(size())) {
      if (i < count) {
        pre_mask[i] = 0xFFFF;
      }
    }
    uint16x8_t mask = vld1q_u16(pre_mask);

    Vectorized<c10::BFloat16> vec(at_vreinterpretq_bf16_u16(vbslq_u16(
        mask,
        at_vreinterpretq_u16_bf16(b.values),
        at_vreinterpretq_u16_bf16(a.values))));

    return vec;
  }
  // Loads `count` bf16 values from `ptr`. No alignment is required. For a
  // partial load the remaining lanes are zero-filled.
  static Vectorized<c10::BFloat16> loadu(
      const void* ptr,
      int64_t count = size()) {
    if (count == size()) {
      return at_vld1q_bf16(reinterpret_cast<const at_bfloat16_t*>(ptr));
    }
    __at_align__ at_bfloat16_t tmp_values[size()];
    std::memset(tmp_values, 0, sizeof(tmp_values));
    std::memcpy(
        tmp_values,
        reinterpret_cast<const at_bfloat16_t*>(ptr),
        count * sizeof(at_bfloat16_t));
    return at_vld1q_bf16(reinterpret_cast<const at_bfloat16_t*>(tmp_values));
  }
  // Writes the first `count` lanes to `ptr`. A partial store goes through a
  // stack buffer so that only count elements are written to `ptr`.
  void store(void* ptr, int64_t count = size()) const {
    if (count == size()) {
      at_vst1q_bf16(reinterpret_cast<at_bfloat16_t*>(ptr), values);
      return;
    } else {
      at_bfloat16_t tmp_values[size()];
      at_vst1q_bf16(reinterpret_cast<at_bfloat16_t*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(at_bfloat16_t));
    }
  }
  // Returns a per-lane mask: all-ones bits (0xFFFF) for NaN lanes, zero
  // otherwise.
  Vectorized<c10::BFloat16> isnan() const {
    // NOTE: we could make this faster by doing vectorized checks of
    // exponent/payload bits.
    __at_align__ c10::BFloat16 tmp[size()];
    __at_align__ c10::BFloat16 res[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i])) {
        std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(c10::BFloat16));
      } else {
        std::memset(static_cast<void*>(&res[i]), 0, sizeof(c10::BFloat16));
      }
    }
    return loadu(res);
  }
  // True if any lane is NaN or +/-Inf.
  bool has_inf_nan() const {
    __at_align__ c10::BFloat16 tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i]) || _isinf(tmp[i])) {
        return true;
      }
    }
    return false;
  }
// Defines a const unary member `name()` that forwards to
// Vectorized<float>::name through an fp32 round trip.
#define DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(name)    \
  Vectorized name() const {                                     \
    return map_with_vec_float_method(&Vectorized<float>::name); \
  }

// Defines a const binary comparison `name(other)` whose fp32 bitmask result
// is narrowed bit-for-bit (not numerically converted) back to 16-bit lanes.
#define DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(name) \
  Vectorized name(const Vectorized& other) const {               \
    return map2_bitmask_with_vec_float_method(                   \
        other, &Vectorized<float>::name);                        \
  }

  Vectorized frac() const;
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)

#ifdef __ARM_FEATURE_BF16
  // Flip sign bit
  Vectorized<c10::BFloat16> neg() const {
    return vreinterpretq_bf16_s16(vreinterpretq_s16_bf16(values) ^ (-32768));
  }
  // vrecpeq_f32 yields only a reciprocal *estimate* (no Newton-Raphson
  // refinement step is applied). The estimate is then narrowed to bf16.
  // NOTE(review): the narrow step does not fully hide the estimate's error, so
  // results may differ in the last bf16 bit from the exact reciprocal computed
  // by the non-BF16 path; confirm this is acceptable.
  Vectorized<c10::BFloat16> reciprocal() const {
    auto x = vcvtq_low_f32_bf16(values);
    auto y = vcvtq_high_f32_bf16(values);
    x = vrecpeq_f32(x);
    y = vrecpeq_f32(y);
    return vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(x), y);
  }
  // Clearing the sign bit
  Vectorized<c10::BFloat16> abs() const {
    return vreinterpretq_bf16_u16(vreinterpretq_u16_bf16(values) & 0x7FFF);
  }
#else
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
#endif

// These functions are optimized on clang-21+
#if BF16_ARITHMETIC_SUPPORTED() && (__clang_major__ >= 21)
  Vectorized<c10::BFloat16> operator==(
      const Vectorized<c10::BFloat16>& other) const {
    return values == other.values;
  }

  Vectorized<c10::BFloat16> operator!=(
      const Vectorized<c10::BFloat16>& other) const {
    return values != other.values;
  }

  Vectorized<c10::BFloat16> operator<(
      const Vectorized<c10::BFloat16>& other) const {
    return values < other.values;
  }

  Vectorized<c10::BFloat16> operator<=(
      const Vectorized<c10::BFloat16>& other) const {
    return values <= other.values;
  }

  Vectorized<c10::BFloat16> operator>(
      const Vectorized<c10::BFloat16>& other) const {
    return values > other.values;
  }

  Vectorized<c10::BFloat16> operator>=(
      const Vectorized<c10::BFloat16>& other) const {
    return values >= other.values;
  }
#else
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
#endif

// Undefine the helper macros so they do not leak out of this header.
// DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD is the binary macro
// actually defined above; the ELEMENTWISE #undef is kept as it was.
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
#undef DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD

  // Comparison helpers declared here and defined out of line.
  Vectorized eq(const Vectorized& other) const;
  Vectorized ne(const Vectorized& other) const;
  Vectorized gt(const Vectorized& other) const;
  Vectorized ge(const Vectorized& other) const;
  Vectorized lt(const Vectorized& other) const;
  Vectorized le(const Vectorized& other) const;
}; // Vectorized<c10::BFloat16>

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free