lhs_values Class — pytorch Architecture
Architecture documentation for the lhs_values entity — a parameter of CPUValueSelectionIntersectionKernel::apply, not a standalone class — in SparseBinaryOpIntersectionKernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp lines 44–119
// CPU kernel that, for every element visited by the iterator, combines one
// lhs value with a run of rhs values through binary_op_t::apply and writes the
// accumulated result to the iterator's output operand.
template <typename binary_op_t>
struct CPUValueSelectionIntersectionKernel {
  // Builds the iterator from the value / selection-index / count tensors,
  // fills its operand 0 and returns that tensor.
  //
  // argsort is read as a flat array of 64-bit indices; each element's rhs
  // selection index is an offset into it, and the entries found there are the
  // positions (along dim 0) of the rhs values to combine.
  //
  // accumulate_matches: when true, all `count` matches of an element are
  // accumulated; when false, at most one match is used.
  static Tensor apply(
      const Tensor& lhs_values,
      const Tensor& lhs_select_idx,
      const Tensor& rhs_values,
      const Tensor& rhs_select_idx,
      const Tensor& intersection_counts,
      const Tensor& argsort,
      const bool accumulate_matches) {
    auto iter = make_value_selection_intersection_iter(
        lhs_values,
        lhs_select_idx,
        rhs_values,
        rhs_select_idx,
        intersection_counts);
    auto res_values = iter.tensor(0);

    // Dim-0 strides of the value tensors; applied to scalar_t pointers below.
    const auto lhs_nnz_stride = lhs_values.stride(0);
    const auto rhs_nnz_stride = rhs_values.stride(0);

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
        ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, at::ScalarType::ComplexHalf,
        res_values.scalar_type(),
        "binary_op_intersection_cpu", [&] {
          // COO indices are only 64-bit for now.
          using index_t = int64_t;
          using accscalar_t = at::acc_type<scalar_t, /*is_gpu=*/false>;

          auto loop = [&](char** data, const int64_t* strides, int64_t n) {
            const auto* sorted_rhs_positions = argsort.const_data_ptr<index_t>();

            for (int64_t elem = 0; elem < n; ++elem) {
              // Locate this element in each operand via its byte stride.
              auto* out = reinterpret_cast<scalar_t*>(data[0] + elem * strides[0]);
              const auto* lhs_base =
                  reinterpret_cast<const scalar_t*>(data[1] + elem * strides[1]);
              const auto lhs_row =
                  *reinterpret_cast<const index_t*>(data[2] + elem * strides[2]);
              const auto* rhs_base =
                  reinterpret_cast<const scalar_t*>(data[3] + elem * strides[3]);
              const auto rhs_first =
                  *reinterpret_cast<const index_t*>(data[4] + elem * strides[4]);
              const auto count =
                  *reinterpret_cast<const int64_t*>(data[5] + elem * strides[5]);

              // Without accumulation only the first match (if any) contributes.
              int64_t n_matches = count;
              if (!accumulate_matches) {
                n_matches = std::min<int64_t>(count, 1);
              }

              // The lhs operand is fixed for the element; it is loaded
              // unconditionally, even when there are no matches.
              accscalar_t lhs_scalar =
                  static_cast<accscalar_t>(lhs_base[lhs_row * lhs_nnz_stride]);

              // Consecutive argsort entries starting at rhs_first name the rhs
              // rows that take part in this element.
              const auto* rhs_run = sorted_rhs_positions + rhs_first;

              accscalar_t acc = 0;
              for (int64_t m = 0; m < n_matches; ++m) {
                const index_t rhs_row = rhs_run[m];
                accscalar_t rhs_scalar =
                    static_cast<accscalar_t>(rhs_base[rhs_row * rhs_nnz_stride]);
                acc += binary_op_t::apply(lhs_scalar, rhs_scalar);
              }

              *out = static_cast<scalar_t>(acc);
            }
          };
          iter.for_each(loop, at::internal::GRAIN_SIZE);
        });

    return res_values;
  }
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free