Home / Class / s_addmm_out_sparse_dense_worker Class — pytorch Architecture

s_addmm_out_sparse_dense_worker Class — pytorch Architecture

Architecture documentation for the s_addmm_out_sparse_dense_worker function template (indexed here as a class entity) in SparseTensorMath.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/sparse/SparseTensorMath.cpp lines 1185–1233

/// CPU worker that writes `beta * t + alpha * (sparse * dense)` into `r`,
/// where the sparse operand is given in coordinate form by `indices`/`values`.
///
/// @param nnz      Number of (row, col, value) entries to process.
/// @param dim_i    Exclusive upper bound for row indices.
/// @param dim_j    Exclusive upper bound for column indices.
/// @param dim_k    Number of elements processed per entry along dimension 1
///                 of `dense` and `r`.
/// @param r        Output tensor; read and written through strides 0 and 1.
/// @param beta     Scale applied to `t` before accumulation.
/// @param t        Tensor whose scaled value seeds `r`; not read when beta == 0.
/// @param alpha    Scale applied to every sparse-times-dense contribution.
/// @param indices  Accessed as a 2-D int64 tensor: `indices[0][i]` is the row
///                 and `indices[1][i]` the column of the i-th entry.
/// @param values   Accessed as a 1-D tensor of `scalar_t`; `values[i]` is the
///                 i-th entry's value.
/// @param dense    Dense operand; row `col` (stride 0) is read for each entry.
///
/// Raises (via TORCH_CHECK) when any column index is outside [0, dim_j) or any
/// row index is outside [0, dim_i). Note that entries preceding the offending
/// one have already been accumulated into `r` at that point.
template <typename scalar_t>
static void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j, int64_t dim_k, Tensor& r, const Scalar& beta, const Tensor& t, const Scalar& alpha, const Tensor& indices, const Tensor& values, const Tensor& dense) {

  const scalar_t cast_alpha = alpha.to<scalar_t>();
  const scalar_t cast_beta = beta.to<scalar_t>();

  // Step 1: r = beta * t, with shortcuts for the two common beta values.
  if (cast_beta == static_cast<scalar_t>(0)) {
    // t is deliberately not read here; r simply starts from zero.
    r.zero_();
  } else if (cast_beta == static_cast<scalar_t>(1)) {
    // Skip the copy when the caller passed the same tensor for r and t.
    if (!is_same_tensor(r, t)) {
      r.copy_(t);
    }
  } else {
    at::mul_out(r, t, scalar_to_tensor(beta));
  }

  // Step 2: r += alpha * sparse * dense, one sparse entry at a time.
  auto indices_accessor = indices.accessor<int64_t, 2>();
  auto values_accessor = values.accessor<scalar_t, 1>();
  scalar_t* dense_ptr = dense.data_ptr<scalar_t>();
  scalar_t* r_ptr = r.data_ptr<scalar_t>();

  const int64_t dense_stride0 = dense.stride(0);
  const int64_t dense_stride1 = dense.stride(1);
  const int64_t r_stride0 = r.stride(0);
  const int64_t r_stride1 = r.stride(1);

  for (const auto i : c10::irange(nnz)) {
    const scalar_t val = values_accessor[i];
    const int64_t row = indices_accessor[0][i];
    const int64_t col = indices_accessor[1][i];

    // Indices are 0-based. The column is validated first so that, when both
    // indices are bad, the column error is the one reported.
    TORCH_CHECK(col >= 0 && col < dim_j,
        "addmm: index out of column bound: ", col, " not in range [0, ", dim_j, ")");
    TORCH_CHECK(row >= 0 && row < dim_i,
        "addmm: index out of row bound: ", row, " not in range [0, ", dim_i, ")");

    // AXPY call is no-op over an empty vector; indices are still validated
    // above even in that case.
    if (dim_k == 0) {
      continue;
    }

    // r[row, :] += (alpha * val) * dense[col, :]
    at::native::cpublas::axpy<scalar_t>(dim_k,
          cast_alpha * val,
          dense_ptr + col * dense_stride0, dense_stride1,
          r_ptr + row * r_stride0, r_stride1);
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free