Home / Class / Vectorized&lt;c10::Half&gt; Class — PyTorch Architecture

Vectorized&lt;c10::Half&gt; Class — PyTorch Architecture

Architecture documentation for the Vectorized&lt;c10::Half&gt; class specialization (the page's "v00" is merely a local variable inside its methods) in vec128_half_neon.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h lines 66–381

// Specialization of Vectorized for c10::Half on ARM NEON: eight FP16 lanes
// held in a single float16x8_t register (the inherited member `values`).
//
// Two implementation strategies appear throughout:
//  * With __ARM_FEATURE_FP16_VECTOR_ARITHMETIC: native FP16 NEON intrinsics.
//  * Without it: each half of the register is widened to a float32x4_t, the
//    operation is delegated to Vectorized<float>, and the result is narrowed
//    back to FP16 (see the private map* helpers below).
template <>
class Vectorized<c10::Half> : public Vectorized16<
                                  float16x8_t,
                                  c10::Half,
                                  BlendHalfRegs,
                                  Vectorized<c10::Half>> {
  using Base = Vectorized16<
      float16x8_t,
      c10::Half,
      BlendHalfRegs,
      Vectorized<c10::Half>>;
  friend Base;

 private:
  // We use these private map functions to implement various methods.

  // Applies a unary Vectorized<float> member function to this vector:
  // widen FP16 -> FP32 (low and high halves separately), apply `m`,
  // then narrow back to FP16 and recombine into one register.
  Vectorized<c10::Half> map_with_vec_float_method(
      Vectorized<float> (Vectorized<float>::*m)() const) const {
    float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values));
    float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values));
    Vectorized<float> mv0 = (Vectorized<float>(v00).*m)();
    Vectorized<float> mv1 = (Vectorized<float>(v01).*m)();
    float16x4_t r00 = vcvt_f16_f32(mv0);
    float16x4_t r01 = vcvt_f16_f32(mv1);
    return Vectorized<c10::Half>(vcombine_f16(r00, r01));
  }

  // Binary counterpart of map_with_vec_float_method: widens both operands
  // to FP32, applies the Vectorized<float> member function `m` to the low
  // and high halves, and narrows the FP32 results back to FP16.
  Vectorized<c10::Half> map2_with_vec_float_method(
      const Vectorized<c10::Half>& second,
      Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
          const) const {
    float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values));
    float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values));
    float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values));
    float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values));
    Vectorized<float> mv0 =
        (Vectorized<float>(v00).*m)(Vectorized<float>(second_v00));
    Vectorized<float> mv1 =
        (Vectorized<float>(v01).*m)(Vectorized<float>(second_v01));
    float16x4_t r00 = vcvt_f16_f32(mv0);
    float16x4_t r01 = vcvt_f16_f32(mv1);

    // Pack result into Vectorized<c10::Half>
    return Vectorized<c10::Half>(vcombine_f16(r00, r01));
  }

  // Like map2_with_vec_float_method, but for comparison-style operators
  // whose FP32 result is a per-lane bitmask (all-zeros / all-ones) rather
  // than a numeric value, so the result must be narrowed bitwise, not via
  // floating-point conversion.
  Vectorized<c10::Half> map2_bitmask_with_vec_float_method(
      const Vectorized<c10::Half>& second,
      Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
          const) const {
    float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values));
    float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values));
    float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values));
    float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values));
    Vectorized<float> mv0 =
        (Vectorized<float>(v00).*m)(Vectorized<float>(second_v00));
    Vectorized<float> mv1 =
        (Vectorized<float>(v01).*m)(Vectorized<float>(second_v01));
    // Assume the operator returns a bitmask, not "real" floats, and
    // just narrow the bits. An all-ones FP32 bit pattern is a NaN and
    // would get mangled by a value-preserving FP32->FP16 conversion!
    float16x4_t r00 =
        vreinterpret_f16_u16(vmovn_u32(vreinterpretq_u32_f32(mv0)));
    float16x4_t r01 =
        vreinterpret_f16_u16(vmovn_u32(vreinterpretq_u32_f32(mv1)));

    // Pack result into Vectorized<c10::Half>
    return Vectorized<c10::Half>(vcombine_f16(r00, r01));
  }

 public:
  using Vectorized16::Vectorized16;

  Vectorized() = default;

  // A ctor that accepts c10::Half is needed to fit interface with vec_base.h.
  // A second constructor that takes float16_t is also included.
  Vectorized(c10::Half val) : Vectorized((float16_t)val) {}
  // Broadcast a single FP16 value to all 8 lanes.
  Vectorized(float16_t val) : Vectorized16(vdupq_n_f16(val)) {}
  // Element-wise construction: val0 goes to lane 0, ..., val7 to lane 7.
  Vectorized(
      value_type val0,
      value_type val1,
      value_type val2,
      value_type val3,
      value_type val4,
      value_type val5,
      value_type val6,
      value_type val7)
      : Vectorized16(
            float16x8_t{val0, val1, val2, val3, val4, val5, val6, val7}) {}

  // Per-lane select: where a mask lane is all-ones take the lane from `b`,
  // where all-zeros take it from `a`.
  static Vectorized<c10::Half> blendv(
      const Vectorized<c10::Half>& a,
      const Vectorized<c10::Half>& b,
      const Vectorized<c10::Half>& mask) {
    // Note: using blendv is very awkward because 0xFFFF is one of
    // many NaNs in FP16. It's unfortunate that the mask has type Half
    // (required from vec_base).

    // TODO
    // NB: This requires that each value, i.e., each uint value,
    // of the mask either all be zeros or all be 1s.
    // We perhaps need some kind of an assert?
    // But that will affect performance.

    // NOTE [vbslq_f16]: vbslq_f16 doesn't work on clang without
    // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC. vbslq_u16 generates the
    // same instruction anyway. see https://godbolt.org/z/cY4a55Y7P
    Vectorized<c10::Half> vec(mask.values);
    vec.values = vreinterpretq_f16_u16(vbslq_u16(
        vreinterpretq_u16_f16(vec.values),
        vreinterpretq_u16_f16(b.values),
        vreinterpretq_u16_f16(a.values)));
    return vec;
  }
  // Returns a vector whose first `count` lanes come from `b` and whose
  // remaining lanes come from `a`.
  static Vectorized<c10::Half> set(
      const Vectorized<c10::Half>& a,
      const Vectorized<c10::Half>& b,
      int64_t count = size()) {
    // Build a per-lane select mask: all-ones for the first `count` lanes,
    // zeros elsewhere.
    uint16_t pre_mask[size()] = {0};
    for (int i = 0; i < count; i++) {
      pre_mask[i] = 0xFFFF;
    }
    uint16x8_t mask = vld1q_u16(pre_mask);

    // Using blendv is awkward because 0xFFFF is one of many NaN's in FP16
    // so we directly use vbslq_u16 instead. (See NOTE [vbslq_f16] above.)
    Vectorized<c10::Half> vec(vreinterpretq_f16_u16(vbslq_u16(
        mask,
        vreinterpretq_u16_f16(b.values),
        vreinterpretq_u16_f16(a.values))));

    return vec;
  }
  // Load up to `count` FP16 elements from memory; lanes beyond `count`
  // are zero-filled.
  static Vectorized<c10::Half> loadu(const void* ptr, int64_t count = size()) {
    if (count == size()) {
      return vld1q_f16(reinterpret_cast<const float16_t*>(ptr));
    }
    // Partial load: stage through a zeroed, aligned temporary buffer so the
    // full-register vld1q never reads past the caller's `count` elements.
    __at_align__ float16_t tmp_values[size()];
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0;
    }
    std::memcpy(
        tmp_values,
        reinterpret_cast<const float16_t*>(ptr),
        count * sizeof(float16_t));
    return vld1q_f16(reinterpret_cast<const float16_t*>(tmp_values));
  }
  // Store the first `count` lanes to memory; lanes beyond `count` are not
  // written.
  void store(void* ptr, int64_t count = size()) const {
    if (count == size()) {
      vst1q_f16(reinterpret_cast<float16_t*>(ptr), values);
      return;
    } else {
      // Partial store: spill the whole register to a temporary, then copy
      // only the first `count` elements so we never write past `ptr + count`.
      float16_t tmp_values[size()];
      vst1q_f16(reinterpret_cast<float16_t*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(float16_t));
    }
  }
  // Returns an integer whose bit i is set iff lane i compares equal to zero.
  int zero_mask() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // Per-lane all-ones/all-zeros mask of "lane == 0".
    uint16x8_t is_zero_vec = vceqzq_f16(values);
    // Per-lane shift amounts 0..7, packed as two 64-bit literals
    // (vcreate_s16 fills four 16-bit lanes from one 64-bit value).
    const int16x8_t shift = vcombine_s16(
        vcreate_s16(
            0x0 | (int64_t(0x1) << 16) | (int64_t(0x2) << 32) |
            (int64_t(0x3) << 48)),
        vcreate_s16(
            0x4 | (int64_t(0x5) << 16) | (int64_t(0x6) << 32) |
            (int64_t(0x7) << 48)));
    // Lane i contributes (1 << i) when zero; the horizontal add collapses
    // the lanes into the final scalar bitmask.
    uint16x8_t bits_vec =
        vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift);
    return vaddvq_u16(bits_vec);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // use known working implementation: scalar fallback via store + loop.
    __at_align__ value_type tmp[size()];
    store(tmp);
    int mask = 0;
    for (int i = 0; i < size(); ++i) {
      if (tmp[i] == 0) {
        mask |= (1 << i);
      }
    }
    return mask;
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  }
  // Returns a per-lane bitmask vector: all-ones where the lane is NaN,
  // all-zeros elsewhere.
  Vectorized<c10::Half> isnan() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // NaN != NaN, so (x == x) is false exactly for NaN lanes; invert that.
    return vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, values)));
#else
    // NOTE: we could make this faster by doing vectorized checks of
    // exponent/payload bits.
    __at_align__ c10::Half tmp[size()];
    __at_align__ c10::Half res[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i])) {
        std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(c10::Half));
      } else {
        std::memset(static_cast<void*>(&res[i]), 0, sizeof(c10::Half));
      }
    }
    return loadu(res);
#endif
  }
  // True if any lane is NaN or +/-infinity (scalar check over all lanes).
  bool has_inf_nan() const {
    __at_align__ c10::Half tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i]) || _isinf(tmp[i])) {
        return true;
      }
    }
    return false;
  }
  // Lane-wise absolute value.
  Vectorized<c10::Half> abs() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(vabsq_f16(values));
#else
    return map_with_vec_float_method(&Vectorized<float>::abs);
#endif
  }
  // Declared here; defined out of line elsewhere in this header.
  Vectorized<c10::Half> frac() const;
  // Lane-wise negation.
  Vectorized<c10::Half> neg() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(vnegq_f16(values));
#else
    return map_with_vec_float_method(&Vectorized<float>::neg);
#endif
  }
  // Lane-wise truncation (round toward zero to an integral value).
  Vectorized<c10::Half> trunc() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(vrndq_f16(values));
#else
    return map_with_vec_float_method(&Vectorized<float>::trunc);
#endif
  }
  // Lane-wise square root.
  Vectorized<c10::Half> sqrt() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(vsqrtq_f16(values));
#else
    return map_with_vec_float_method(&Vectorized<float>::sqrt);
#endif
  }
  // Lane-wise reciprocal, computed as 1/x with a full division (not the
  // NEON reciprocal-estimate instruction).
  Vectorized<c10::Half> reciprocal() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    auto ones = vdupq_n_f16(1.0f);
    return Vectorized<c10::Half>(vdivq_f16(ones, values));
#else
    return map_with_vec_float_method(&Vectorized<float>::reciprocal);
#endif
  }
  // The comparison operators below all return per-lane bitmasks
  // (all-ones where true, all-zeros where false), reinterpreted as FP16.
  Vectorized<c10::Half> operator==(const Vectorized<c10::Half>& other) const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(
        vreinterpretq_f16_u16(vceqq_f16(values, other.values)));
#else
    return map2_bitmask_with_vec_float_method(
        other, &Vectorized<float>::operator==);
#endif
  }

  Vectorized<c10::Half> operator!=(const Vectorized<c10::Half>& other) const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // Implemented as NOT(a == b); NaN lanes therefore compare not-equal.
    return Vectorized<c10::Half>(
        vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, other.values))));
#else
    return map2_bitmask_with_vec_float_method(
        other, &Vectorized<float>::operator!=);
#endif
  }

  Vectorized<c10::Half> operator<(const Vectorized<c10::Half>& other) const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(
        vreinterpretq_f16_u16(vcltq_f16(values, other.values)));
#else
    return map2_bitmask_with_vec_float_method(
        other, &Vectorized<float>::operator<);
#endif
  }

  Vectorized<c10::Half> operator<=(const Vectorized<c10::Half>& other) const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(
        vreinterpretq_f16_u16(vcleq_f16(values, other.values)));
#else
    return map2_bitmask_with_vec_float_method(
        other, &Vectorized<float>::operator<=);
#endif
  }

  Vectorized<c10::Half> operator>(const Vectorized<c10::Half>& other) const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(
        vreinterpretq_f16_u16(vcgtq_f16(values, other.values)));
#else
    return map2_bitmask_with_vec_float_method(
        other, &Vectorized<float>::operator>);
#endif
  }

  Vectorized<c10::Half> operator>=(const Vectorized<c10::Half>& other) const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(
        vreinterpretq_f16_u16(vcgeq_f16(values, other.values)));
#else
    return map2_bitmask_with_vec_float_method(
        other, &Vectorized<float>::operator>=);
#endif
  }

  // Comparison helpers required by the vec_base interface; declared here
  // and defined out of line elsewhere in this file.
  Vectorized<c10::Half> eq(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> ne(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> gt(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> ge(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> lt(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> le(const Vectorized<c10::Half>& other) const;
}; // Vectorized<Half>

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free