Vectorized&lt;c10::Half&gt; Class — PyTorch Architecture
Architecture documentation for the Vectorized&lt;c10::Half&gt; NEON specialization in vec128_half_neon.h from the PyTorch codebase. (The name "v00" is only a local float32x4_t variable inside the class's private helper methods, not the name of the documented entity.)
Entity Profile
Source Code
aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h lines 66–381
// NEON (AArch64) specialization of Vectorized for c10::Half: eight fp16
// lanes held in a single float16x8_t register (`values`, inherited from
// Vectorized16). When the target lacks native fp16 vector arithmetic
// (__ARM_FEATURE_FP16_VECTOR_ARITHMETIC not defined), operations fall back
// to widening each half of the register to float32 and reusing the
// Vectorized<float> implementation via the private map helpers below.
template <>
class Vectorized<c10::Half> : public Vectorized16<
                                  float16x8_t,
                                  c10::Half,
                                  BlendHalfRegs,
                                  Vectorized<c10::Half>> {
  using Base = Vectorized16<
      float16x8_t,
      c10::Half,
      BlendHalfRegs,
      Vectorized<c10::Half>>;
  friend Base;

 private:
  // We use these private map functions to implement various methods.

  // Applies a unary Vectorized<float> member function to this vector:
  // widen the low and high fp16 halves to two float32x4 vectors, apply
  // `m` to each, then narrow the float results back to fp16 and recombine.
  Vectorized<c10::Half> map_with_vec_float_method(
      Vectorized<float> (Vectorized<float>::*m)() const) const {
    float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values));
    float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values));
    Vectorized<float> mv0 = (Vectorized<float>(v00).*m)();
    Vectorized<float> mv1 = (Vectorized<float>(v01).*m)();
    float16x4_t r00 = vcvt_f16_f32(mv0);
    float16x4_t r01 = vcvt_f16_f32(mv1);
    return Vectorized<c10::Half>(vcombine_f16(r00, r01));
  }

  // Binary counterpart of map_with_vec_float_method: widens both operands
  // to float32, applies `m` pairwise on the corresponding low/high halves,
  // and narrows the float results back to fp16.
  Vectorized<c10::Half> map2_with_vec_float_method(
      const Vectorized<c10::Half>& second,
      Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
          const) const {
    float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values));
    float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values));
    float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values));
    float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values));
    Vectorized<float> mv0 =
        (Vectorized<float>(v00).*m)(Vectorized<float>(second_v00));
    Vectorized<float> mv1 =
        (Vectorized<float>(v01).*m)(Vectorized<float>(second_v01));
    float16x4_t r00 = vcvt_f16_f32(mv0);
    float16x4_t r01 = vcvt_f16_f32(mv1);
    // Pack result into Vectorized<c10::Half>
    return Vectorized<c10::Half>(vcombine_f16(r00, r01));
  }

  // Like map2_with_vec_float_method, but for comparison-style operators
  // whose float result is a per-lane bitmask (all-zeros / all-ones) rather
  // than real floating-point values, so the 32-bit lanes are narrowed to
  // 16 bits via integer reinterprets instead of a float16 conversion.
  Vectorized<c10::Half> map2_bitmask_with_vec_float_method(
      const Vectorized<c10::Half>& second,
      Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
          const) const {
    float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values));
    float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values));
    float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values));
    float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values));
    Vectorized<float> mv0 =
        (Vectorized<float>(v00).*m)(Vectorized<float>(second_v00));
    Vectorized<float> mv1 =
        (Vectorized<float>(v01).*m)(Vectorized<float>(second_v01));
    // Assume the operator returns a bitmask, not "real" floats, and
    // just narrow the bits. All-ones is a NaN and will get mangled by
    // conversion!
    float16x4_t r00 =
        vreinterpret_f16_u16(vmovn_u32(vreinterpretq_u32_f32(mv0)));
    float16x4_t r01 =
        vreinterpret_f16_u16(vmovn_u32(vreinterpretq_u32_f32(mv1)));
    // Pack result into Vectorized<c10::Half>
    return Vectorized<c10::Half>(vcombine_f16(r00, r01));
  }

 public:
  // Inherit the base-class constructors (e.g. from a raw float16x8_t).
  using Vectorized16::Vectorized16;

  Vectorized() = default;

  // A ctor that accepts c10::Half is needed to fit interface with vec_base.h
  // A second constructor that takes float16_t is also included
  Vectorized(c10::Half val) : Vectorized((float16_t)val) {}
  // Broadcast `val` into all eight lanes.
  Vectorized(float16_t val) : Vectorized16(vdupq_n_f16(val)) {}
  // Element-wise constructor: lane i is initialized from val<i>.
  Vectorized(
      value_type val0,
      value_type val1,
      value_type val2,
      value_type val3,
      value_type val4,
      value_type val5,
      value_type val6,
      value_type val7)
      : Vectorized16(
            float16x8_t{val0, val1, val2, val3, val4, val5, val6, val7}) {}

  // Per-lane select: for each 16-bit lane, returns the lane from `b` where
  // the corresponding lane of `mask` is all-ones and the lane from `a`
  // where it is all-zeros.
  static Vectorized<c10::Half> blendv(
      const Vectorized<c10::Half>& a,
      const Vectorized<c10::Half>& b,
      const Vectorized<c10::Half>& mask) {
    // Note: using blendv is very awkward because 0xFFFF is one of
    // many NaN's in FP16 It's unfortunate that the mask has type Half
    // (required from vec_base)
    // TODO
    // NB: This requires that each value, i.e., each uint value,
    // of the mask either all be zeros or all be 1s.
    // We perhaps need some kind of an assert?
    // But that will affect performance.
    // NOTE [vbslq_f16]: vbslq_f16 doesn't work on clang without
    // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC. vbslq_u16 generates the
    // same instruction anyway. see https://godbolt.org/z/cY4a55Y7P
    Vectorized<c10::Half> vec(mask.values);
    vec.values = vreinterpretq_f16_u16(vbslq_u16(
        vreinterpretq_u16_f16(vec.values),
        vreinterpretq_u16_f16(b.values),
        vreinterpretq_u16_f16(a.values)));
    return vec;
  }

  // Returns a vector whose first `count` lanes come from `b` and whose
  // remaining lanes come from `a` (a prefix blend built from a runtime
  // count).
  static Vectorized<c10::Half> set(
      const Vectorized<c10::Half>& a,
      const Vectorized<c10::Half>& b,
      int64_t count = size()) {
    // Build an all-ones mask for the first `count` 16-bit lanes.
    uint16_t pre_mask[size()] = {0};
    for (int i = 0; i < count; i++) {
      pre_mask[i] = 0xFFFF;
    }
    uint16x8_t mask = vld1q_u16(pre_mask);
    // Using blendv is awkward because 0xFFFF is one of many NaN's in FP16
    // so we directly use vbslq_u16 instead. (See NOTE [vbslq_f16] above.)
    Vectorized<c10::Half> vec(vreinterpretq_f16_u16(vbslq_u16(
        mask,
        vreinterpretq_u16_f16(b.values),
        vreinterpretq_u16_f16(a.values))));
    return vec;
  }

  // Load `count` halves from (possibly unaligned) `ptr`; when count <
  // size(), the remaining lanes are zero-filled via a stack staging
  // buffer.
  static Vectorized<c10::Half> loadu(const void* ptr, int64_t count = size()) {
    if (count == size()) {
      return vld1q_f16(reinterpret_cast<const float16_t*>(ptr));
    }
    __at_align__ float16_t tmp_values[size()];
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0;
    }
    std::memcpy(
        tmp_values,
        reinterpret_cast<const float16_t*>(ptr),
        count * sizeof(float16_t));
    return vld1q_f16(reinterpret_cast<const float16_t*>(tmp_values));
  }

  // Store the first `count` lanes to `ptr`; a partial store goes through a
  // stack buffer so only `count` elements of the destination are written.
  void store(void* ptr, int64_t count = size()) const {
    if (count == size()) {
      vst1q_f16(reinterpret_cast<float16_t*>(ptr), values);
      return;
    } else {
      float16_t tmp_values[size()];
      vst1q_f16(reinterpret_cast<float16_t*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(float16_t));
    }
  }

  // Returns an int whose bit i is set iff lane i compares equal to zero.
  int zero_mask() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // vceqzq_f16 yields all-ones per zero lane; mask it down to 1, shift
    // each lane's bit left by its lane index (the `shift` vector packs the
    // constants 0..7 as eight s16 lanes), then horizontally add so each
    // lane lands in its own bit of the result.
    uint16x8_t is_zero_vec = vceqzq_f16(values);
    const int16x8_t shift = vcombine_s16(
        vcreate_s16(
            0x0 | (int64_t(0x1) << 16) | (int64_t(0x2) << 32) |
            (int64_t(0x3) << 48)),
        vcreate_s16(
            0x4 | (int64_t(0x5) << 16) | (int64_t(0x6) << 32) |
            (int64_t(0x7) << 48)));
    uint16x8_t bits_vec =
        vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift);
    return vaddvq_u16(bits_vec);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // use known working implementation: spill to a stack array and test
    // each element scalar-wise.
    __at_align__ value_type tmp[size()];
    store(tmp);
    int mask = 0;
    for (int i = 0; i < size(); ++i) {
      if (tmp[i] == 0) {
        mask |= (1 << i);
      }
    }
    return mask;
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  }

  // Per-lane NaN test: NaN lanes become all-ones bitmasks, others all-zeros.
  Vectorized<c10::Half> isnan() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // NaN is the only value that compares unequal to itself, so negating
    // a self-equality compare marks exactly the NaN lanes.
    return vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, values)));
#else
    // NOTE: we could make this faster by doing vectorized checks of
    // exponent/payload bits.
    __at_align__ c10::Half tmp[size()];
    __at_align__ c10::Half res[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i])) {
        std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(c10::Half));
      } else {
        std::memset(static_cast<void*>(&res[i]), 0, sizeof(c10::Half));
      }
    }
    return loadu(res);
#endif
  }

  // True if any lane is NaN or +/-infinity (scalar scan over a spill
  // buffer).
  bool has_inf_nan() const {
    __at_align__ c10::Half tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i]) || _isinf(tmp[i])) {
        return true;
      }
    }
    return false;
  }

  // Each of the unary math ops below uses the native fp16 NEON intrinsic
  // when available, and otherwise falls back to the float-widening helper.
  Vectorized<c10::Half> abs() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(vabsq_f16(values));
#else
    return map_with_vec_float_method(&Vectorized<float>::abs);
#endif
  }
  // Declared here, defined out of line (outside this class body).
  Vectorized<c10::Half> frac() const;
  Vectorized<c10::Half> neg() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(vnegq_f16(values));
#else
    return map_with_vec_float_method(&Vectorized<float>::neg);
#endif
  }
  Vectorized<c10::Half> trunc() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(vrndq_f16(values));
#else
    return map_with_vec_float_method(&Vectorized<float>::trunc);
#endif
  }
  Vectorized<c10::Half> sqrt() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(vsqrtq_f16(values));
#else
    return map_with_vec_float_method(&Vectorized<float>::sqrt);
#endif
  }
  Vectorized<c10::Half> reciprocal() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    auto ones = vdupq_n_f16(1.0f);
    return Vectorized<c10::Half>(vdivq_f16(ones, values));
#else
    return map_with_vec_float_method(&Vectorized<float>::reciprocal);
#endif
  }

  // The comparison operators below return per-lane bitmasks (all-ones for
  // true, all-zeros for false), reinterpreted as fp16 vectors. The
  // fallback path uses map2_bitmask_with_vec_float_method so the float
  // bitmask is narrowed bit-exactly rather than converted as a value.
  Vectorized<c10::Half> operator==(const Vectorized<c10::Half>& other) const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(
        vreinterpretq_f16_u16(vceqq_f16(values, other.values)));
#else
    return map2_bitmask_with_vec_float_method(
        other, &Vectorized<float>::operator==);
#endif
  }

  Vectorized<c10::Half> operator!=(const Vectorized<c10::Half>& other) const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(
        vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, other.values))));
#else
    return map2_bitmask_with_vec_float_method(
        other, &Vectorized<float>::operator!=);
#endif
  }

  Vectorized<c10::Half> operator<(const Vectorized<c10::Half>& other) const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(
        vreinterpretq_f16_u16(vcltq_f16(values, other.values)));
#else
    return map2_bitmask_with_vec_float_method(
        other, &Vectorized<float>::operator<);
#endif
  }

  Vectorized<c10::Half> operator<=(const Vectorized<c10::Half>& other) const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(
        vreinterpretq_f16_u16(vcleq_f16(values, other.values)));
#else
    return map2_bitmask_with_vec_float_method(
        other, &Vectorized<float>::operator<=);
#endif
  }

  Vectorized<c10::Half> operator>(const Vectorized<c10::Half>& other) const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(
        vreinterpretq_f16_u16(vcgtq_f16(values, other.values)));
#else
    return map2_bitmask_with_vec_float_method(
        other, &Vectorized<float>::operator>);
#endif
  }

  Vectorized<c10::Half> operator>=(const Vectorized<c10::Half>& other) const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    return Vectorized<c10::Half>(
        vreinterpretq_f16_u16(vcgeq_f16(values, other.values)));
#else
    return map2_bitmask_with_vec_float_method(
        other, &Vectorized<float>::operator>=);
#endif
  }

  // Named comparison helpers; declared here, defined outside this class
  // body (required by the vec_base.h interface).
  Vectorized<c10::Half> eq(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> ne(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> gt(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> ge(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> lt(const Vectorized<c10::Half>& other) const;
  Vectorized<c10::Half> le(const Vectorized<c10::Half>& other) const;
}; // Vectorized<Half>
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free