Vectorized&lt;float&gt; Class — PyTorch Architecture
Architecture documentation for the ARM NEON `Vectorized&lt;float&gt;` specialization in vec128_float_neon.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h lines 77–459
// ARM NEON (AArch64) specialization of Vectorized for float.
// Wraps a single 128-bit float32x4_t register holding 4 packed
// single-precision lanes. Elementwise transcendental functions use SLEEF
// vector routines when available and otherwise fall back to scalar std::
// calls applied lane-by-lane via map()/map2().
template <>
class Vectorized<float> {
private:
// The underlying 4-lane SIMD register.
float32x4_t values;
public:
using value_type = float;
using size_type = int;
// Number of float lanes held by this vector.
static constexpr size_type size() {
return 4;
}
// Default constructor: all four lanes zero.
Vectorized() {
values = vmovq_n_f32(0);
}
// Wrap an existing NEON register (implicit; callers rely on this).
Vectorized(float32x4_t v) : values(v) {}
// Broadcast a single scalar to all four lanes.
Vectorized(float val) : values{vdupq_n_f32(val)} {}
// Construct from four explicit lane values; val0 goes to lane 0.
Vectorized(float val0, float val1, float val2, float val3)
: values{val0, val1, val2, val3} {}
// Construct from a 4-element array, element i -> lane i.
Vectorized(float (&arr)[4]) : Vectorized(arr[0], arr[1], arr[2], arr[3]) {}
// Implicit conversion back to the raw NEON register.
operator float32x4_t() const {
return values;
}
// Compile-time blend: bit i of `mask` chooses lane i from b (bit set)
// or a (bit clear). Per-lane selection is delegated to BlendRegs,
// which is defined elsewhere in this header — presumably
// BlendRegs<i, flag>::impl copies lane i of b when flag is true.
template <int64_t mask>
static Vectorized<float> blend(
const Vectorized<float>& a,
const Vectorized<float>& b) {
Vectorized<float> vec;
vec.values = BlendRegs < 0,
(mask & 0x01) != 0 > ::impl(a.values, b.values, vec.values);
vec.values = BlendRegs < 1,
(mask & 0x02) != 0 > ::impl(a.values, b.values, vec.values);
vec.values = BlendRegs < 2,
(mask & 0x04) != 0 > ::impl(a.values, b.values, vec.values);
vec.values = BlendRegs < 3,
(mask & 0x08) != 0 > ::impl(a.values, b.values, vec.values);
return vec;
}
// Runtime blend: vbslq_f32 is a bitwise select — result bits come from b
// where the corresponding mask bit is 1, and from a where it is 0.
static Vectorized<float> blendv(
const Vectorized<float>& a,
const Vectorized<float>& b,
const Vectorized<float>& mask) {
// TODO
// NB: This requires that each value, i.e., each uint value,
// of the mask either all be zeros or all be 1s.
// We perhaps need some kind of an assert?
// But that will affect performance.
Vectorized<float> vec(mask.values);
vec.values =
vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values);
return vec;
}
// Returns {base, base + step, base + 2*step, base + 3*step}, computed as
// a single fused multiply-add of the lane indices {0,1,2,3} with step.
// (fmadd is defined elsewhere in this header.)
template <typename step_t>
static Vectorized<float> arange(
float base = 0.f,
step_t step = static_cast<step_t>(1)) {
const Vectorized<float> base_vec(base);
const Vectorized<float> step_vec(step);
const Vectorized<float> step_sizes(0, 1, 2, 3);
return fmadd(step_sizes, step_vec, base_vec);
}
// Returns a vector whose first `count` lanes come from b and whose
// remaining lanes come from a. Any count outside 0..3 (including the
// default, size() == 4) returns b unchanged.
static Vectorized<float> set(
const Vectorized<float>& a,
const Vectorized<float>& b,
int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1: {
Vectorized<float> vec;
// All-ones bits in lane 0 only: select b's lane 0, a's lanes 1-3.
static uint32x4_t mask_low = {0xFFFFFFFF, 0x0, 0x0, 0x0};
vec.values = vreinterpretq_f32_u32(mask_low);
vec.values =
vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values);
return vec;
}
case 2: {
Vectorized<float> vec;
// Select b's lanes 0-1, a's lanes 2-3.
static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0};
vec.values = vreinterpretq_f32_u32(mask_low);
vec.values =
vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values);
return vec;
}
case 3: {
Vectorized<float> vec;
// Select b's lanes 0-2, a's lane 3.
static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0};
vec.values = vreinterpretq_f32_u32(mask_low);
vec.values =
vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values);
return vec;
}
}
return b;
}
// Load `count` floats from ptr (which need not be aligned). When fewer
// than 4 are requested, the remaining lanes are zero-filled by staging
// through an aligned temporary buffer.
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
if (count == size()) {
return vld1q_f32(reinterpret_cast<const float*>(ptr));
} else {
__at_align__ float tmp_values[size()];
for (const auto i : c10::irange(size())) {
tmp_values[i] = 0.0;
}
std::memcpy(
tmp_values,
reinterpret_cast<const float*>(ptr),
count * sizeof(float));
return vld1q_f32(reinterpret_cast<const float*>(tmp_values));
}
}
// Store the first `count` lanes to ptr; a partial store goes through a
// stack temporary so no bytes past count * sizeof(float) are written.
void store(void* ptr, int64_t count = size()) const {
if (count == size()) {
vst1q_f32(reinterpret_cast<float*>(ptr), values);
} else {
float tmp_values[size()];
vst1q_f32(reinterpret_cast<float*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(float));
}
}
// Very slow implementation of indexing.
// Only required because vec256_qint refers to this.
// Once we specialize that implementation for ARM
// this should be removed. TODO (kimishpatel)
float operator[](int idx) const {
__at_align__ float tmp[size()];
store(tmp);
return tmp[idx];
}
// Non-const overload; same spill-to-stack implementation. Note it
// returns by value, not a reference, so it cannot be used to mutate.
float operator[](int idx) {
__at_align__ float tmp[size()];
store(tmp);
return tmp[idx];
}
// Returns an integer bitmask with bit i set iff lane i equals 0.0f:
// vceqzq yields all-ones per zero lane, which is masked to 1, shifted
// left by the lane index {0,1,2,3}, and horizontally summed.
int zero_mask() const {
uint32x4_t is_zero_vec = vceqzq_f32(values);
const int32x4_t shift = vcombine_s32(
vcreate_s32(0x0 | (int64_t(0x1) << 32)),
vcreate_s32(0x2 | (int64_t(0x3) << 32)));
uint32x4_t bits_vec =
vshlq_u32(vandq_u32(is_zero_vec, vdupq_n_u32(1)), shift);
return vaddvq_u32(bits_vec);
}
// Per-lane NaN test: NaN != NaN, so negating (values == values) yields
// all-ones bits in NaN lanes, zeros elsewhere.
Vectorized<float> isnan() const {
return vreinterpretq_f32_u32(vmvnq_u32(vceqq_f32(values, values)));
}
// True if any lane is NaN or +/-infinity (scalar lane-by-lane check).
bool has_inf_nan() const {
__at_align__ float tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
if (_isnan(tmp[i]) || _isinf(tmp[i])) {
return true;
}
}
return false;
}
// Apply a scalar float->float function to each lane by spilling to the
// stack — the slow fallback when no vector implementation exists.
Vectorized<float> map(float (*const f)(float)) const {
__at_align__ float tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
// Binary analogue of map(): lane i becomes f(this[i], second[i]).
Vectorized<float> map2(
const Vectorized<float>& second,
float (*const f)(float, float)) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_second[size()];
store(tmp);
second.store(tmp_second);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i], tmp_second[i]);
}
return loadu(tmp);
}
// Per-lane absolute value.
Vectorized<float> abs() const {
return Vectorized<float>(vabsq_f32(values));
}
// Real-valued angle(): pi for negative lanes, 0 for non-negative ones,
// with NaN lanes propagated unchanged via the final blendv.
Vectorized<float> angle() const {
auto zero = Vectorized<float>(0);
auto pi = Vectorized<float>(c10::pi<float>);
auto tmp = blendv(zero, pi, *this < zero);
return blendv(tmp, *this, isnan());
}
// Complex-interface shims for a real type: real() is identity, ...
Vectorized<float> real() const {
return *this;
}
// ... imag() is all zeros, ...
Vectorized<float> imag() const {
return Vectorized<float>(0.f);
}
// ... and conjugation is identity.
Vectorized<float> conj() const {
return *this;
}
// Defines `name()` to call the given SLEEF vector routine when SLEEF is
// in use, otherwise falls back to map(std::name). USE_SLEEF is defined
// elsewhere — presumably it expands to its first argument when SLEEF is
// available and its second otherwise.
#define DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \
name, sleef_name) \
Vectorized<float> name() const { \
return USE_SLEEF(Vectorized<float>(sleef_name(values)), map(std::name)); \
}
// Shorthand for the common case: SLEEF's 4-lane float variant with
// 1.0-ULP error bound (the "_u10" suffix).
#define DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(name) \
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \
name, Sleef_##name##f4_u10)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(acos)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(acosh)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(asin)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(asinh)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(atan)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(atanh)
// Binary counterpart: `name(arg)` uses the SLEEF routine or map2.
#define DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \
name, sleef_name) \
Vectorized<float> name(const Vectorized<float>& arg) const { \
return USE_SLEEF( \
Vectorized<float>(sleef_name(values, arg.values)), \
map2(arg, std::name)); \
}
#define DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(name) \
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \
name, Sleef_##name##f4_u10)
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(atan2)
// copysign is exact, so SLEEF provides it without an error-bound suffix.
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
copysign,
Sleef_copysignf4)
// Defined out of line, elsewhere in this file.
Vectorized<float> erf() const;
// erfc uses SLEEF's wider 1.5-ULP variant.
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
erfc,
Sleef_erfcf4_u15)
// Inverse error function via ATen's scalar helper, one lane at a time.
Vectorized<float> erfinv() const {
return map(calc_erfinv);
}
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
// Implementation copied from Arm Optimized Routine
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
// Fast exp with ~2.0 ULP error bound (hence "_u20"); falls back to the
// accurate exp() when any |lane| exceeds the bound where the polynomial
// scheme would overflow/underflow.
inline Vectorized<float> vexpq_f32_u20() const {
// bail out to sleef if it's a special case:
// i.e. there's an input s.t. |input| > 87.3....
const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
uint32x4_t cmp = vcagtq_f32(values, special_bound);
if (vpaddd_u64(vreinterpretq_u64_u32(cmp)) != 0) {
return exp();
}
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f);
// ln2 split into high/low parts for an exact-ish argument reduction,
// packed into one register together with two polynomial coefficients
// so they can be addressed with vfmsq/vfmaq_laneq.
const float ln2_hi = 0x1.62e4p-1f;
const float ln2_lo = 0x1.7f7d1cp-20f;
const float c0 = 0x1.0e4020p-7f;
const float c2 = 0x1.555e66p-3f;
const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2};
const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000);
const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f);
const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f);
const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f);
/* exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)]
x = ln2*n + r, with r in [-ln2/2, ln2/2]. */
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2));
float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0);
r = vfmsq_laneq_f32(r, n, ln2_c02, 1);
// Build 2^n by shifting n into the float exponent field.
uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23);
float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias));
// Evaluate poly(r) with a split (Estrin-style) scheme, then combine:
// scale + poly * scale == 2^n * (1 + poly(r)).
float32x4_t r2 = vmulq_f32(r, r);
float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2);
float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3);
q = vfmaq_f32(q, p, r2);
p = vmulq_f32(c4, r);
float32x4_t poly = vfmaq_f32(p, q, r2);
return vfmaq_f32(scale, poly, scale);
}
Vectorized<float> exp_u20() const {
return vexpq_f32_u20();
}
Vectorized<float> fexp_u20() const {
return exp_u20();
}
// fmod and nextafter are exact; hypot uses SLEEF's 0.5-ULP variant.
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
fmod,
Sleef_fmodf4)
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
hypot,
Sleef_hypotf4_u05)
// Bessel/special functions: no vector routine available, so these go
// lane-by-lane through ATen's scalar helpers.
Vectorized<float> i0() const {
return map(calc_i0);
}
Vectorized<float> i0e() const {
return map(calc_i0e);
}
Vectorized<float> digamma() const {
return map(calc_digamma);
}
Vectorized<float> igamma(const Vectorized<float>& x) const {
return map2(x, calc_igamma);
}
Vectorized<float> igammac(const Vectorized<float>& x) const {
return map2(x, calc_igammac);
}
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log10)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log1p)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log2)
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
nextafter,
Sleef_nextafterf4)
// Defined out of line, elsewhere in this file.
Vectorized<float> frac() const;
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(sin)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(sinh)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(cos)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(cosh)
Vectorized<float> ceil() const {
return map(at::native::ceil_impl);
}
Vectorized<float> floor() const {
return map(at::native::floor_impl);
}
// Per-lane negation.
Vectorized<float> neg() const {
return Vectorized<float>(vnegq_f32(values));
}
Vectorized<float> round() const {
// We do not use std::round because we would like to round midway numbers to
// the nearest even integer.
return map(at::native::round_impl);
}
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(tan)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(tanh)
// Round toward zero.
Vectorized<float> trunc() const {
return Vectorized<float>(vrndq_f32(values));
}
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(lgamma)
Vectorized<float> sqrt() const {
return Vectorized<float>(vsqrtq_f32(values));
}
// Full-precision 1/x via division (not the vrecpe estimate).
Vectorized<float> reciprocal() const {
return Vectorized<float>(vdivq_f32(vdupq_n_f32(1.0f), values));
}
Vectorized<float> rsqrt() const {
return this->sqrt().reciprocal();
}
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(pow)
// Comparison operators return a lane mask encoded as floats: all-ones
// bits where the comparison holds, zero bits elsewhere — the format
// blendv() expects.
Vectorized<float> operator==(const Vectorized<float>& other) const {
return Vectorized<float>(
vreinterpretq_f32_u32(vceqq_f32(values, other.values)));
}
Vectorized<float> operator!=(const Vectorized<float>& other) const {
float32x4_t r0 =
vreinterpretq_f32_u32(vmvnq_u32(vceqq_f32(values, other.values)));
return Vectorized<float>(r0);
}
Vectorized<float> operator<(const Vectorized<float>& other) const {
return Vectorized<float>(
vreinterpretq_f32_u32(vcltq_f32(values, other.values)));
}
Vectorized<float> operator<=(const Vectorized<float>& other) const {
return Vectorized<float>(
vreinterpretq_f32_u32(vcleq_f32(values, other.values)));
}
Vectorized<float> operator>(const Vectorized<float>& other) const {
return Vectorized<float>(
vreinterpretq_f32_u32(vcgtq_f32(values, other.values)));
}
Vectorized<float> operator>=(const Vectorized<float>& other) const {
return Vectorized<float>(
vreinterpretq_f32_u32(vcgeq_f32(values, other.values)));
}
// Named comparisons; declared here, defined out of line elsewhere in
// the file (by convention these return 0.0/1.0 per lane rather than a
// bitmask — confirm against the out-of-line definitions).
Vectorized<float> eq(const Vectorized<float>& other) const;
Vectorized<float> ne(const Vectorized<float>& other) const;
Vectorized<float> gt(const Vectorized<float>& other) const;
Vectorized<float> ge(const Vectorized<float>& other) const;
Vectorized<float> lt(const Vectorized<float>& other) const;
Vectorized<float> le(const Vectorized<float>& other) const;
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free