Home / Class/ __may_alias__ Class — pytorch Architecture

__may_alias__ Class — pytorch Architecture

Architecture documentation for the __may_alias__ class in vec256_complex_float_vsx.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h lines 18–769

template <>
class Vectorized<ComplexFlt> {
 private:
  union {
    struct {
      vfloat32 _vec0;
      vfloat32 _vec1;
    };
    struct {
      vbool32 _vecb0;
      vbool32 _vecb1;
    };

  } __attribute__((__may_alias__));

 public:
  using value_type = ComplexFlt;
  using vec_internal_type = vfloat32;
  using vec_internal_mask_type = vbool32;
  using size_type = int;

  static constexpr size_type size() {
    return 4;
  }
  Vectorized() {}

  C10_ALWAYS_INLINE Vectorized(vfloat32 v) : _vec0{v}, _vec1{v} {}
  C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
  C10_ALWAYS_INLINE Vectorized(vfloat32 v1, vfloat32 v2)
      : _vec0{v1}, _vec1{v2} {}
  C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2)
      : _vecb0{v1}, _vecb1{v2} {}

  Vectorized(ComplexFlt val) {
    float real_value = val.real();
    float imag_value = val.imag();
    _vec0 = vfloat32{real_value, imag_value, real_value, imag_value};
    _vec1 = vfloat32{real_value, imag_value, real_value, imag_value};
  }

  Vectorized(
      ComplexFlt val1,
      ComplexFlt val2,
      ComplexFlt val3,
      ComplexFlt val4) {
    _vec0 = vfloat32{val1.real(), val1.imag(), val2.real(), val2.imag()};
    _vec1 = vfloat32{val3.real(), val3.imag(), val4.real(), val4.imag()};
  }

  C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
    return _vec0;
  }
  C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
    return _vec1;
  }

  template <uint64_t mask>
  static std::enable_if_t<blendChoiceComplex(mask) == 0, Vectorized<ComplexFlt>>
      C10_ALWAYS_INLINE
      blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
    return a;
  }

  template <uint64_t mask>
  static std::enable_if_t<blendChoiceComplex(mask) == 1, Vectorized<ComplexFlt>>
      C10_ALWAYS_INLINE
      blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
    return b;
  }

  template <uint64_t mask>
  static std::enable_if_t<blendChoiceComplex(mask) == 2, Vectorized<ComplexFlt>>
      C10_ALWAYS_INLINE
      blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
    return {b._vec0, a._vec1};
  }

  template <uint64_t mask>
  static std::enable_if_t<blendChoiceComplex(mask) == 3, Vectorized<ComplexFlt>>
      C10_ALWAYS_INLINE
      blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
    return {a._vec0, b._vec1};
  }

  template <uint64_t mask>
  static std::enable_if_t<blendChoiceComplex(mask) == 4, Vectorized<ComplexFlt>>
      C10_ALWAYS_INLINE
      blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
    const vbool32 mask_1st = VsxComplexMask1(mask);
    return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1};
  }

  template <uint64_t mask>
  static std::enable_if_t<blendChoiceComplex(mask) == 5, Vectorized<ComplexFlt>>
      C10_ALWAYS_INLINE
      blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
    const vbool32 mask_1st = VsxComplexMask1(mask);
    return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1};
  }

  template <uint64_t mask>
  static std::enable_if_t<blendChoiceComplex(mask) == 6, Vectorized<ComplexFlt>>
      C10_ALWAYS_INLINE
      blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
    const vbool32 mask_2nd = VsxComplexMask2(mask);
    // generated masks
    return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
  }

  template <uint64_t mask>
  static std::enable_if_t<blendChoiceComplex(mask) == 7, Vectorized<ComplexFlt>>
      C10_ALWAYS_INLINE
      blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
    const vbool32 mask_2nd = VsxComplexMask2(mask);
    // generated masks
    return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
  }

  template <uint64_t mask>
  static std::enable_if_t<blendChoiceComplex(mask) == 8, Vectorized<ComplexFlt>>
      C10_ALWAYS_INLINE
      blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
    const vbool32 mask_1st = VsxComplexMask1(mask);
    const vbool32 mask_2nd = VsxComplexMask2(mask);
    return {
        (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st),
        (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
  }

  template <int64_t mask>
  static Vectorized<ComplexFlt> C10_ALWAYS_INLINE
  el_blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
    const vbool32 mask_1st = VsxMask1(mask);
    const vbool32 mask_2nd = VsxMask2(mask);
    return {
        (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st),
        (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
  }

  static Vectorized<ComplexFlt> blendv(
      const Vectorized<ComplexFlt>& a,
      const Vectorized<ComplexFlt>& b,
      const Vectorized<ComplexFlt>& mask) {
    // convert std::complex<V> index mask to V index mask: xy -> xxyy
    auto mask_complex = Vectorized<ComplexFlt>(
        vec_mergeh(mask._vec0, mask._vec0), vec_mergeh(mask._vec1, mask._vec1));
    return {
        vec_sel(
            a._vec0, b._vec0, reinterpret_cast<vbool32>(mask_complex._vec0)),
        vec_sel(
            a._vec1, b._vec1, reinterpret_cast<vbool32>(mask_complex._vec1)),
    };
  }

  static Vectorized<ComplexFlt> elwise_blendv(
      const Vectorized<ComplexFlt>& a,
      const Vectorized<ComplexFlt>& b,
      const Vectorized<ComplexFlt>& mask) {
    return {
        vec_sel(a._vec0, b._vec0, reinterpret_cast<vbool32>(mask._vec0)),
        vec_sel(a._vec1, b._vec1, reinterpret_cast<vbool32>(mask._vec1)),
    };
  }

  template <typename step_t>
  static Vectorized<ComplexFlt> arange(
      ComplexFlt base = 0.,
      step_t step = static_cast<step_t>(1)) {
    return Vectorized<ComplexFlt>(
        base,
        base + step,
        base + ComplexFlt(2) * step,
        base + ComplexFlt(3) * step);
  }
  static Vectorized<ComplexFlt> set(
      const Vectorized<ComplexFlt>& a,
      const Vectorized<ComplexFlt>& b,
      int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
    }
    return b;
  }

  static Vectorized<value_type> C10_ALWAYS_INLINE
  loadu(const void* ptr, int count = size()) {
    if (count == size()) {
      return {
          vec_vsx_ld(offset0, reinterpret_cast<const float*>(ptr)),
          vec_vsx_ld(offset16, reinterpret_cast<const float*>(ptr))};
    }

    __at_align__ value_type tmp_values[size()] = {};
    std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));

    return {
        vec_vsx_ld(offset0, reinterpret_cast<const float*>(tmp_values)),
        vec_vsx_ld(offset16, reinterpret_cast<const float*>(tmp_values))};
  }

  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
    if (count == size()) {
      vec_vsx_st(_vec0, offset0, reinterpret_cast<float*>(ptr));
      vec_vsx_st(_vec1, offset16, reinterpret_cast<float*>(ptr));
    } else if (count > 0) {
      __at_align__ value_type tmp_values[size()];
      vec_vsx_st(_vec0, offset0, reinterpret_cast<float*>(tmp_values));
      vec_vsx_st(_vec1, offset16, reinterpret_cast<float*>(tmp_values));
      std::memcpy(
          ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
    }
  }

  const ComplexFlt& operator[](int idx) const = delete;
  ComplexFlt& operator[](int idx) = delete;

  Vectorized<ComplexFlt> map(ComplexFlt (*const f)(ComplexFlt)) const {
    __at_align__ ComplexFlt tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }

  Vectorized<ComplexFlt> map(ComplexFlt (*const f)(const ComplexFlt&)) const {
    __at_align__ ComplexFlt tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }

  static Vectorized<ComplexFlt> horizontal_add(
      Vectorized<ComplexFlt>& first,
      Vectorized<ComplexFlt>& second) {
    // Operates on individual floats, see _mm_hadd_ps
    // {f0+f1, s0+s1, f2+f3, s2+s3, ...}
    // i.e. it sums the re and im of each value and interleaves first and
    // second: {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...}
    return el_mergee(first, second) + el_mergeo(first, second);
  }

  static Vectorized<ComplexFlt> horizontal_sub_permD8(
      Vectorized<ComplexFlt>& first,
      Vectorized<ComplexFlt>& second) {
    // we will simulate it differently with 6 instructions total
    // lets permute second so that we can add it getting horizontal sums
    auto first_perm = first.el_swapped(); // 2perm
    auto second_perm = second.el_swapped(); // 2perm
    // sum
    auto first_ret = first - first_perm; // 2sub
    auto second_ret = second - second_perm; // 2 sub
    // now lets choose evens
    return el_mergee(first_ret, second_ret); // 2 mergee's
  }

  Vectorized<ComplexFlt> abs_2_() const {
    auto a = (*this).elwise_mult(*this);
    auto permuted = a.el_swapped();
    a = a + permuted;
    return a.el_mergee();
  }

  Vectorized<ComplexFlt> abs_() const {
    auto vi = el_mergeo();
    auto vr = el_mergee();
    return {
        Sleef_hypotf4_u05vsx(vr._vec0, vi._vec0),
        Sleef_hypotf4_u05vsx(vr._vec1, vi._vec1)};
  }

  Vectorized<ComplexFlt> abs() const {
    return abs_() & real_mask;
  }

  Vectorized<ComplexFlt> real_() const {
    return *this & real_mask;
  }
  Vectorized<ComplexFlt> real() const {
    return *this & real_mask;
  }
  Vectorized<ComplexFlt> imag_() const {
    return *this & imag_mask;
  }
  Vectorized<ComplexFlt> imag() const {
    // we can use swap_mask or sldwi
    auto ret = imag_();
    return {
        vec_sldw(ret._vec0, ret._vec0, 3), vec_sldw(ret._vec1, ret._vec1, 3)};
  }

  Vectorized<ComplexFlt> conj_() const {
    return *this ^ isign_mask;
  }
  Vectorized<ComplexFlt> conj() const {
    return *this ^ isign_mask;
  }

  Vectorized<ComplexFlt> log() const {
    // Most trigonomic ops use the log() op to improve complex number
    // performance.
    return map(std::log);
  }

  Vectorized<ComplexFlt> log2() const {
    // log2eB_inv
    auto ret = log();
    return ret.elwise_mult(log2e_inv);
  }
  Vectorized<ComplexFlt> log10() const {
    auto ret = log();
    return ret.elwise_mult(log10e_inv);
  }

  Vectorized<ComplexFlt> log1p() const {
    return map(std::log1p);
  }

  Vectorized<ComplexFlt> el_swapped() const {
    vfloat32 v0 = vec_perm(_vec0, _vec0, swap_mask);
    vfloat32 v1 = vec_perm(_vec1, _vec1, swap_mask);
    return {v0, v1};
  }

  Vectorized<ComplexFlt> el_mergee() const {
    // as mergee phased in , we can use vec_perm with mask
    return {vec_mergee(_vecb0, _vecb0), vec_mergee(_vecb1, _vecb1)};
  }

  Vectorized<ComplexFlt> el_mergeo() const {
    // as mergeo phased in , we can use vec_perm with mask
    return {vec_mergeo(_vecb0, _vecb0), vec_mergeo(_vecb1, _vecb1)};
  }

  Vectorized<ComplexFlt> el_madd(
      const Vectorized<ComplexFlt>& multiplier,
      const Vectorized<ComplexFlt>& val) const {
    return {
        vec_madd(_vec0, multiplier._vec0, val._vec0),
        vec_madd(_vec1, multiplier._vec1, val._vec1)};
  }

  static Vectorized<ComplexFlt> el_mergee(
      const Vectorized<ComplexFlt>& first,
      const Vectorized<ComplexFlt>& second) {
    return {
        vec_mergee(first._vecb0, second._vecb0),
        vec_mergee(first._vecb1, second._vecb1)};
  }

  static Vectorized<ComplexFlt> el_mergeo(
      const Vectorized<ComplexFlt>& first,
      const Vectorized<ComplexFlt>& second) {
    return {
        vec_mergeo(first._vecb0, second._vecb0),
        vec_mergeo(first._vecb1, second._vecb1)};
  }

  Vectorized<ComplexFlt> angle_() const {
    // angle = atan2(b/a)
    // auto b_a = _mm256_permute_ps(values, 0xB1); // b        a
    // return Sleef_atan2f8_u10(values, b_a); // 90-angle angle
    Vectorized<ComplexFlt> ret;
    for (int i = 0; i < 4; i += 2) {
      ret._vec0[i] = std::atan2(_vec0[i + 1], _vec0[i]);
      ret._vec1[i] = std::atan2(_vec1[i + 1], _vec1[i]);
    }
    return ret;
  }

  Vectorized<ComplexFlt> angle() const {
    return angle_() & real_mask;
  }

  Vectorized<ComplexFlt> sin() const {
    return map(std::sin);
  }
  Vectorized<ComplexFlt> sinh() const {
    return map(std::sinh);
  }
  Vectorized<ComplexFlt> cos() const {
    return map(std::cos);
  }
  Vectorized<ComplexFlt> cosh() const {
    return map(std::cosh);
  }
  Vectorized<ComplexFlt> ceil() const {
    return {vec_ceil(_vec0), vec_ceil(_vec1)};
  }
  Vectorized<ComplexFlt> floor() const {
    return {vec_floor(_vec0), vec_floor(_vec1)};
  }
  Vectorized<ComplexFlt> neg() const {
    auto z = Vectorized<ComplexFlt>(zero);
    return z - *this;
  }
  Vectorized<ComplexFlt> round() const {
    return {vec_round(_vec0), vec_round(_vec1)};
  }
  Vectorized<ComplexFlt> tan() const {
    return map(std::tan);
  }
  Vectorized<ComplexFlt> tanh() const {
    return map(std::tanh);
  }
  Vectorized<ComplexFlt> trunc() const {
    return {vec_trunc(_vec0), vec_trunc(_vec1)};
  }

  Vectorized<ComplexFlt> elwise_sqrt() const {
    return {vec_sqrt(_vec0), vec_sqrt(_vec1)};
  }

  Vectorized<ComplexFlt> sqrt() const {
    return map(std::sqrt);
  }

  Vectorized<ComplexFlt> reciprocal() const {
    // re + im*i = (a + bi)  / (c + di)
    // re = (ac + bd)/abs_2() = c/abs_2()
    // im = (bc - ad)/abs_2() = d/abs_2()
    auto c_d = *this ^ isign_mask; // c       -d
    auto abs = abs_2_();
    return c_d.elwise_div(abs);
  }

  Vectorized<ComplexFlt> rsqrt() const {
    return sqrt().reciprocal();
  }

  Vectorized<ComplexFlt> pow(const Vectorized<ComplexFlt>& exp) const {
    __at_align__ ComplexFlt x_tmp[size()];
    __at_align__ ComplexFlt y_tmp[size()];
    store(x_tmp);
    exp.store(y_tmp);
    for (const auto i : c10::irange(size())) {
      x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
    }
    return loadu(x_tmp);
  }

  Vectorized<ComplexFlt> atan() const {
    // atan(x) = i/2 * ln((i + z)/(i - z))
    auto ione = Vectorized(imag_one);
    auto sum = ione + *this;
    auto sub = ione - *this;
    auto ln = (sum / sub).log(); // ln((i + z)/(i - z))
    return ln * imag_half; // i/2*ln()
  }
  Vectorized<ComplexFlt> atanh() const {
    return map(std::atanh);
  }

  Vectorized<ComplexFlt> acos() const {
    // acos(x) = pi/2 - asin(x)
    return Vectorized(pi_2) - asin();
  }

  Vectorized<ComplexFlt> inline operator*(
      const Vectorized<ComplexFlt>& b) const {
    //(a + bi)  * (c + di) = (ac - bd) + (ad + bc)i

#if 1
    // this is more vsx friendly than simulating horizontal from x86

    auto vi = b.el_mergeo();
    auto vr = b.el_mergee();
    vi = vi ^ rsign_mask;
    auto ret = elwise_mult(vr);
    auto vx_swapped = el_swapped();
    ret = vx_swapped.elwise_mult(vi) + ret;
    return ret;

#else

    auto ac_bd = elwise_mult(b);
    auto d_c = b.el_swapped();
    d_c = d_c ^ isign_mask;
    auto ad_bc = elwise_mult(d_c);
    auto ret = horizontal_sub_permD8(ac_bd, ad_bc);
    return ret;
#endif
  }

  Vectorized<ComplexFlt> inline operator/(
      const Vectorized<ComplexFlt>& b) const {
#if 1
    __at_align__ c10::complex<float>
        tmp1[Vectorized<c10::complex<float>>::size()];
    __at_align__ c10::complex<float>
        tmp2[Vectorized<c10::complex<float>>::size()];
    __at_align__ c10::complex<float>
        out[Vectorized<c10::complex<float>>::size()];
    this->store(tmp1);
    b.store(tmp2);

    for (const auto i : c10::irange(Vectorized<c10::complex<float>>::size())) {
      out[i] = tmp1[i] / tmp2[i];
    }
    return loadu(out);
#else
    auto fabs_cd = Vectorized{
        vec_andc(b._vec0, sign_mask), vec_andc(b._vec1, sign_mask)}; // |c| |d|
    auto fabs_dc = fabs_cd.el_swapped(); // |d|            |c|
    auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|)
    auto a2 = elwise_div(scale); // a/sc           b/sc
    auto b2 = b.elwise_div(scale); // c/sc           d/sc
    auto acbd2 = a2.elwise_mult(b2); // ac/sc^2        bd/s
    auto dc2 = b2.el_swapped(); // d/sc           c/sc
    dc2 = dc2 ^ rsign_mask; // -d/sc          c/sc
    auto adbc2 = a2.elwise_mult(dc2); // -ad/sc^2       bc/sc^2
    auto ret = horizontal_add(acbd2, adbc2); // (ac+bd)/sc^2   (bc-ad)/sc^2
    auto denom2 = b2.abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
    ret = ret.elwise_div(denom2);
    return ret;
#endif
  }

  Vectorized<ComplexFlt> asin() const {
    // asin(x)
    // = -i*ln(iz + sqrt(1 -z^2))
    // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
    // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))

#if 1
    auto conj = conj_();
    auto b_a = conj.el_swapped();
    auto ab = conj.elwise_mult(b_a);
    auto im = ab + ab;
    auto val_2 = (*this).elwise_mult(*this);
    auto val_2_swapped = val_2.el_swapped();
    auto re = horizontal_sub_permD8(val_2, val_2_swapped);
    re = Vectorized<ComplexFlt>(one) - re;
    auto root = el_blend<0xAA>(re, im).sqrt();
    auto ln = (b_a + root).log();
    return ln.el_swapped().conj();
#else
    return map(std::asin);
#endif
  }

  Vectorized<ComplexFlt> exp() const {
    return map(std::exp);
  }
  Vectorized<ComplexFlt> exp2() const {
    return map(exp2_impl);
  }
  Vectorized<ComplexFlt> expm1() const {
    return map(std::expm1);
  }

  Vectorized<ComplexFlt> eq(const Vectorized<ComplexFlt>& other) const {
    auto eq = (*this == other); // compares real and imag individually
    // If both real numbers and imag numbers are equal, then the complex numbers
    // are equal
    return (eq.real() & eq.imag()) & one;
  }
  Vectorized<ComplexFlt> ne(const Vectorized<ComplexFlt>& other) const {
    auto ne = (*this != other); // compares real and imag individually
    // If either real numbers or imag numbers are not equal, then the complex
    // numbers are not equal
    return (ne.real() | ne.imag()) & one;
  }

  Vectorized<ComplexFlt> sgn() const {
    return map(at::native::sgn_impl);
  }

  Vectorized<ComplexFlt> operator<(const Vectorized<ComplexFlt>& other) const {
    TORCH_CHECK(false, "not supported for complex numbers");
  }

  Vectorized<ComplexFlt> operator<=(const Vectorized<ComplexFlt>& other) const {
    TORCH_CHECK(false, "not supported for complex numbers");
  }

  Vectorized<ComplexFlt> operator>(const Vectorized<ComplexFlt>& other) const {
    TORCH_CHECK(false, "not supported for complex numbers");
  }

  Vectorized<ComplexFlt> operator>=(const Vectorized<ComplexFlt>& other) const {
    TORCH_CHECK(false, "not supported for complex numbers");
  }

  DEFINE_MEMBER_OP(operator==, ComplexFlt, vec_cmpeq)
  DEFINE_MEMBER_OP(operator!=, ComplexFlt, vec_cmpne)

  DEFINE_MEMBER_OP(operator+, ComplexFlt, vec_add)
  DEFINE_MEMBER_OP(operator-, ComplexFlt, vec_sub)
  DEFINE_MEMBER_OP(operator&, ComplexFlt, vec_and)
  DEFINE_MEMBER_OP(operator|, ComplexFlt, vec_or)
  DEFINE_MEMBER_OP(operator^, ComplexFlt, vec_xor)
  // elementwise helpers
  DEFINE_MEMBER_OP(elwise_mult, ComplexFlt, vec_mul)
  DEFINE_MEMBER_OP(elwise_div, ComplexFlt, vec_div)
  DEFINE_MEMBER_OP(elwise_gt, ComplexFlt, vec_cmpgt)
  DEFINE_MEMBER_OP(elwise_ge, ComplexFlt, vec_cmpge)
  DEFINE_MEMBER_OP(elwise_lt, ComplexFlt, vec_cmplt)
  DEFINE_MEMBER_OP(elwise_le, ComplexFlt, vec_cmple)
  DEFINE_MEMBER_OP(elwise_max, ComplexFlt, vec_max)
};

template <>
Vectorized<ComplexFlt> inline maximum(
    const Vectorized<ComplexFlt>& a,
    const Vectorized<ComplexFlt>& b) {
  auto abs_a = a.abs_2_();
  auto abs_b = b.abs_2_();
  // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ);
  // auto max = _mm256_blendv_ps(a, b, mask);
  auto mask = abs_a.elwise_lt(abs_b);
  auto max = Vectorized<ComplexFlt>::elwise_blendv(a, b, mask);

  return max;
  // Exploit the fact that all-ones is a NaN.
  // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
  // return _mm256_or_ps(max, isnan);
}

template <>
Vectorized<ComplexFlt> inline minimum(
    const Vectorized<ComplexFlt>& a,
    const Vectorized<ComplexFlt>& b) {
  auto abs_a = a.abs_2_();
  auto abs_b = b.abs_2_();
  // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ);
  // auto min = _mm256_blendv_ps(a, b, mask);
  auto mask = abs_a.elwise_gt(abs_b);
  auto min = Vectorized<ComplexFlt>::elwise_blendv(a, b, mask);
  return min;
  // Exploit the fact that all-ones is a NaN.
  // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
  // return _mm256_or_ps(min, isnan);
}

template <>
Vectorized<ComplexFlt> C10_ALWAYS_INLINE
operator+(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
  return Vectorized<ComplexFlt>{
      vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
}

template <>
Vectorized<ComplexFlt> C10_ALWAYS_INLINE
operator-(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
  return Vectorized<ComplexFlt>{
      vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
}

template <>
Vectorized<ComplexFlt> C10_ALWAYS_INLINE
operator&(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
  return Vectorized<ComplexFlt>{
      vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
}

template <>
Vectorized<ComplexFlt> C10_ALWAYS_INLINE
operator|(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
  return Vectorized<ComplexFlt>{
      vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
}

template <>
Vectorized<ComplexFlt> C10_ALWAYS_INLINE
operator^(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
  return Vectorized<ComplexFlt>{
      vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
}

template <>
Vectorized<ComplexFlt> C10_ALWAYS_INLINE
operator*(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
  // (a + ib) * (c + id) = (ac - bd) + i(ad + bc)
  // Split into real and imaginary parts
  auto a_real = a.el_mergee(); // real part of a
  auto a_imag = a.el_mergeo(); // imag part of a
  auto b_real = b.el_mergee(); // real part of b
  auto b_imag = b.el_mergeo(); // imag part of b

  auto b_imag_neg = b_imag ^ rsign_mask;
  // Compute components
  auto ac = a_real.elwise_mult(b_real); // real * real
  auto bd = a_imag.elwise_mult(b_imag_neg); // imag * imag
  auto ad = a_real.elwise_mult(b_imag); // real * imag
  auto bc = a_imag.elwise_mult(b_real); // imag * real

  // Real = ac - bd (fix the negative bd part)
  auto real = ac + bd; // Real part calculation
  auto imag = ad + bc; // Imaginary part calculation

  // Step 1: Extract from real and imag
  __vector float r0 = real.vec0(); // {r0, r1, r2, r3}
  __vector float i0 = imag.vec0(); // {i0, i1, i2, i3}

  __vector float r1 = real.vec1(); // imag[0..3]
  __vector float i1 = imag.vec1(); // imag[4..7]

  __vector unsigned char perm_lo = {
      0,
      1,
      2,
      3, // r0
      16,
      17,
      18,
      19, //
      8,
      9,
      10,
      11, // r1
      24,
      25,
      26,
      27};
  __vector float v0 =
      vec_perm(r0, i0, perm_lo); // Interleave r0 and i0, r1 and i1
  __vector float v1 = vec_perm(r1, i1, perm_lo);
  Vectorized<ComplexFlt> result(v0, v1);
  return result;
}

template <>
Vectorized<ComplexFlt> C10_ALWAYS_INLINE
operator/(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
  // Take absolute values of real and imaginary parts of b
  __at_align__ c10::complex<float>
      tmp1[Vectorized<c10::complex<float>>::size()];
  __at_align__ c10::complex<float>
      tmp2[Vectorized<c10::complex<float>>::size()];
  __at_align__ c10::complex<float> out[Vectorized<c10::complex<float>>::size()];
  a.store(tmp1);
  b.store(tmp2);
  for (const auto i :
       c10::irange(Vectorized<c10::complex<float>>::
                       size())) { //{Vectorized<c10::complex<float>>::size()))
                                  //{
    out[i] = tmp1[i] / tmp2[i];
  }
  return Vectorized<ComplexFlt>::loadu(out);
}

} // namespace CPU_CAPABILITY

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free