Home / Class/ addmm_out_sparse_csr_native_cpu Class — pytorch Architecture

addmm_out_sparse_csr_native_cpu Class — pytorch Architecture

Architecture documentation for the addmm_out_sparse_csr_native_cpu function template (a file-local static helper, not a class) in SparseCsrTensorMath.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp lines 527–584

template <typename scalar_t>
static void addmm_out_sparse_csr_native_cpu(
    const Tensor& sparse,
    const Tensor& dense,
    const Tensor& r,
    Scalar alpha,
    const Scalar& beta) {
  auto dim_i = sparse.size(0);
  auto dim_k = dense.size(1);

  auto csr = sparse.crow_indices();
  auto col_indices = sparse.col_indices();
  auto values = sparse.values();

  scalar_t cast_alpha = alpha.to<scalar_t>();
  // If beta is zero NaN and Inf should not be propagated to the result
  if (beta.toComplexDouble() == 0.) {
    r.zero_();
  } else {
    r.mul_(beta);
  }
  AT_DISPATCH_INDEX_TYPES(
      col_indices.scalar_type(), "csr_mm_crow_indices", [&]() {
        auto csr_accessor = csr.accessor<index_t, 1>();
        auto col_indices_accessor = col_indices.accessor<index_t, 1>();

        auto values_accessor = values.accessor<scalar_t, 1>();
        scalar_t* dense_ptr = dense.data_ptr<scalar_t>();
        scalar_t* r_ptr = r.data_ptr<scalar_t>();

        int64_t dense_stride0 = dense.stride(0);
        int64_t dense_stride1 = dense.stride(1);
        int64_t r_stride0 = r.stride(0);
        int64_t r_stride1 = r.stride(1);

        at::parallel_for(
            0,
            dim_i,
            internal::GRAIN_SIZE,
            [&](int64_t irow_start, int64_t irow_end) {
              for (index_t h = irow_start; h < irow_end; ++h) {
                index_t i_start = csr_accessor[h];
                index_t i_end = csr_accessor[h + 1];
                for (index_t i = i_start; i < i_end; i++) {
                  scalar_t val = values_accessor[i];
                  index_t col = col_indices_accessor[i];
                  at::native::cpublas::axpy<scalar_t>(
                      dim_k,
                      cast_alpha * val,
                      dense_ptr + col * dense_stride0,
                      dense_stride1,
                      r_ptr + h * r_stride0,
                      r_stride1);
                }
              }
            });
      });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free