Home / Class / VECTOR_WIDTH Class — pytorch Architecture

VECTOR_WIDTH Class — pytorch Architecture

Architecture documentation for the VECTOR_WIDTH class in vec_bfloat16.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/cpu/vec/sve/vec_bfloat16.h lines 27–225

// SVE specialization of Vectorized for BFloat16 lanes.
//
// Wraps a single fixed-length SVE bfloat16 vector (`vls_bfloat16_t`) and
// exposes the element-wise API used by the rest of the vectorized code. Member
// functions that only have a declaration here are defined outside of this
// class body.
template <>
class Vectorized<BFloat16> {
 private:
  // The underlying SVE register value holding size() bfloat16 lanes.
  vls_bfloat16_t values;

 public:
  using value_type = BFloat16;
  using size_type = int;

  // Number of BFloat16 lanes held by one vector.
  static constexpr size_type size() {
    return VECTOR_WIDTH / sizeof(BFloat16);
  }

  Vectorized();
  // Wraps an existing SVE bfloat16 vector.
  Vectorized(svbfloat16_t v) : values(v) {}
  Vectorized(int val);
  Vectorized(BFloat16 val);

  // Lane-by-lane constructor; only participates in overload resolution when
  // exactly size() arguments are supplied.
  template <
      typename... Args,
      typename = std::enable_if_t<(sizeof...(Args) == size())>>
  Vectorized(Args... vals) {
    __at_align__ BFloat16 buffer[size()] = {vals...};
    values = svld1_bf16(ptrue, reinterpret_cast<const bfloat16_t*>(buffer));
  }

  // Implicit conversion back to the raw SVE vector.
  operator svbfloat16_t() const {
    return values;
  }

  // Per-lane select: a lane is taken from `b` when the corresponding lane of
  // `mask_` (viewed as a 16-bit integer) equals ALL_S16_TRUE_MASK, otherwise
  // it is taken from `a`.
  static Vectorized<BFloat16> blendv(
      const Vectorized<BFloat16>& a,
      const Vectorized<BFloat16>& b,
      const Vectorized<BFloat16>& mask_) {
    svbool_t mask =
        svcmpeq_s16(ptrue, svreinterpret_s16_bf16(mask_), ALL_S16_TRUE_MASK);
    return svsel_bf16(mask, b, a);
  }

  // Builds the vector {base, base + step, base + 2 * step, ...}.
  template <typename step_t>
  static Vectorized<BFloat16> arange(
      BFloat16 base = 0.f,
      step_t step = static_cast<step_t>(1)) {
    __at_align__ BFloat16 buffer[size()];
    for (int64_t i = 0; i < size(); i++) {
      buffer[i] = base + i * step;
    }
    return svld1_bf16(ptrue, reinterpret_cast<bfloat16_t*>(buffer));
  }

  // Returns a vector whose first `count` lanes come from `b` and whose
  // remaining lanes come from `a`. `count == 0` yields `a`; a `count` that is
  // not below size() yields `b`.
  static Vectorized<BFloat16> set(
      const Vectorized<BFloat16>& a,
      const Vectorized<BFloat16>& b,
      int64_t count = size()) {
    if (count == 0) {
      return a;
    } else if (count < size()) {
      return svsel_bf16(svwhilelt_b16(0ull, count), b, a);
    }
    return b;
  }

  // Loads lanes from `ptr`. A full vector is read when `count == size()`;
  // otherwise the load is predicated on the lanes with index below `count`.
  static Vectorized<BFloat16> loadu(const void* ptr, int64_t count = size()) {
    if (count == size())
      return svld1_bf16(ptrue, reinterpret_cast<const bfloat16_t*>(ptr));
    svbool_t pg = svwhilelt_b16(0ull, count);
    return svld1_bf16(pg, reinterpret_cast<const bfloat16_t*>(ptr));
  }

  // Stores lanes to `ptr`. A full vector is written when `count == size()`;
  // otherwise the store is predicated on the lanes with index below `count`,
  // so memory belonging to the other lanes is left untouched.
  //
  // The store goes straight to `ptr` (the mirror image of loadu) instead of
  // bouncing through a zeroed stack buffer followed by a memcpy of
  // `count * sizeof(bfloat16_t)` bytes: that copy read past the temporary
  // whenever `count` exceeded size(), and cost an extra memset + copy on every
  // call.
  void store(void* ptr, int64_t count = size()) const {
    auto* dst = reinterpret_cast<bfloat16_t*>(ptr);
    if (count == size()) {
      svst1_bf16(ptrue, dst, values);
    } else {
      svst1_bf16(svwhilelt_b16(0ull, count), dst, values);
    }
  }

  // Element access by index is intentionally unavailable.
  const BFloat16& operator[](int idx) const = delete;
  BFloat16& operator[](int idx) = delete;

  // Returns an integer mask in which bit i is 1 when lane i compares equal to
  // zero and 0 otherwise.
  int64_t zero_mask() const {
    int64_t mask = 0;
    __at_align__ int16_t mask_array[size()];

    // NOTE(review): the lanes are compared after being reinterpreted as
    // float16, not bfloat16 — presumably relying on both formats encoding
    // +/-0 identically; confirm the behavior for other bit patterns.
    svbool_t svbool_mask =
        svcmpeq_f16(ptrue, svreinterpret_f16_bf16(values), ZERO_F16);
    // Materialize the predicate as one int16 per lane so it can be scanned.
    svst1_s16(
        ptrue,
        mask_array,
        svsel_s16(svbool_mask, ALL_S16_TRUE_MASK, ALL_S16_FALSE_MASK));
    for (int64_t i = 0; i < size(); ++i) {
      if (mask_array[i])
        mask |= (1ull << i);
    }
    return mask;
  }
  Vectorized<BFloat16> isnan() const;
  bool has_inf_nan() const;

  // Applies the scalar function `f` to every lane by round-tripping through a
  // temporary array.
  Vectorized<BFloat16> map(BFloat16 (*f)(BFloat16)) const {
    __at_align__ BFloat16 tmp[size()];
    store(tmp);
    for (int64_t i = 0; i < size(); ++i) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }

  // Absolute value: clears the top (sign) bit of every 16-bit lane.
  Vectorized<BFloat16> abs() const {
    auto mask = svdup_n_u16(0x7FFF);
    auto vals = svreinterpret_u16_bf16(values);
    vals = svand_u16_x(ptrue, vals, mask);
    return svreinterpret_bf16_u16(vals);
  }
  Vectorized<BFloat16> angle() const;
  // Returns the vector unchanged.
  Vectorized<BFloat16> real() const {
    return values;
  }
  // Returns a vector with every lane set to 0.
  Vectorized<BFloat16> imag() const {
    return Vectorized<BFloat16>(0.f);
  }
  // Returns the vector unchanged.
  Vectorized<BFloat16> conj() const {
    return values;
  }
  Vectorized<BFloat16> acos() const;
  Vectorized<BFloat16> acosh() const;
  Vectorized<BFloat16> asin() const;
  Vectorized<BFloat16> atan() const;
  Vectorized<BFloat16> atanh() const;
  Vectorized<BFloat16> atan2(const Vectorized<BFloat16>& b) const;
  Vectorized<BFloat16> copysign(const Vectorized<BFloat16>& sign) const;
  Vectorized<BFloat16> erf() const;
  Vectorized<BFloat16> erfc() const;
  Vectorized<BFloat16> erfinv() const;
  Vectorized<BFloat16> exp() const;
  Vectorized<BFloat16> exp2() const;
  Vectorized<BFloat16> expm1() const;
  // Forwards to exp().
  Vectorized<BFloat16> exp_u20() const {
    return exp();
  }
  // Forwards to exp().
  Vectorized<BFloat16> fexp_u20() const {
    return exp();
  }
  Vectorized<BFloat16> fmod(const Vectorized<BFloat16>& q) const;
  Vectorized<BFloat16> hypot(const Vectorized<BFloat16>& b) const;
  Vectorized<BFloat16> i0() const;
  Vectorized<BFloat16> i0e() const;
  Vectorized<BFloat16> digamma() const;
  Vectorized<BFloat16> igamma(const Vectorized<BFloat16>& x) const;
  Vectorized<BFloat16> igammac(const Vectorized<BFloat16>& x) const;
  Vectorized<BFloat16> nextafter(const Vectorized<BFloat16>& b) const;
  Vectorized<BFloat16> log() const;
  Vectorized<BFloat16> log2() const;
  Vectorized<BFloat16> log10() const;
  Vectorized<BFloat16> log1p() const;
  Vectorized<BFloat16> frac() const;
  Vectorized<BFloat16> sin() const;
  Vectorized<BFloat16> sinh() const;
  Vectorized<BFloat16> cos() const;
  Vectorized<BFloat16> cosh() const;
  Vectorized<BFloat16> ceil() const;
  Vectorized<BFloat16> floor() const;

  // Negation: flips the top (sign) bit of every 16-bit lane.
  Vectorized<BFloat16> neg() const {
    auto mask = svdup_n_u16(0x8000);
    auto vals = svreinterpret_u16_bf16(values);
    vals = sveor_u16_x(ptrue, vals, mask);
    return svreinterpret_bf16_u16(vals);
  }
  Vectorized<BFloat16> round() const;
  Vectorized<BFloat16> tan() const;
  Vectorized<BFloat16> tanh() const;
  Vectorized<BFloat16> trunc() const;
  Vectorized<BFloat16> lgamma() const;
  Vectorized<BFloat16> sqrt() const;
  Vectorized<BFloat16> reciprocal() const;
  Vectorized<BFloat16> rsqrt() const;
  Vectorized<BFloat16> pow(const Vectorized<BFloat16>& b) const;
  // Comparison using the _CMP_**_OQ predicate.
  //   `O`: get false if an operand is NaN
  //   `Q`: do not raise if an operand is NaN
  Vectorized<BFloat16> operator==(const Vectorized<BFloat16>& other) const;

  Vectorized<BFloat16> operator!=(const Vectorized<BFloat16>& other) const;

  Vectorized<BFloat16> operator<(const Vectorized<BFloat16>& other) const;

  Vectorized<BFloat16> operator<=(const Vectorized<BFloat16>& other) const;

  Vectorized<BFloat16> operator>(const Vectorized<BFloat16>& other) const;

  Vectorized<BFloat16> operator>=(const Vectorized<BFloat16>& other) const;

  // Named comparison variants; declared here, defined elsewhere.
  Vectorized<BFloat16> eq(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> ne(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> gt(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> ge(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> lt(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> le(const Vectorized<BFloat16>& other) const;
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free