Home / Class / is_bmm Class — PyTorch Architecture

is_bmm Class — pytorch Architecture

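Architecture documentation for `is_bmm` in LinearAlgebra.cpp from the PyTorch codebase. Despite the "Class" label, `is_bmm` is not a class. It is a `bool` template parameter of `baddbmm_cpu_kernel` that selects between two behaviors: plain `bmm`, which overwrites the output, and `baddbmm`, which scales the existing output by beta and adds alpha times the product.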

Entity Profile

Source Code

aten/src/ATen/native/LinearAlgebra.cpp lines 1648–1693

// Naive CPU kernel shared by bmm and baddbmm.
//
// For every batch b, output row i and output column j it accumulates
//   acc = sum_k self[b][i][k] * mat2[b][k][j]
// in opmath_t (at::opmath_type<scalar_t>), then stores into `result`:
//   * is_bmm == true  : result[b][i][j] = acc   (alpha_ and beta_ are unused)
//   * is_bmm == false : result[b][i][j] = beta * result[b][i][j] + alpha * acc,
//                       except when beta == 0, where the previous contents of
//                       `result` are never read, so NaN/Inf already stored
//                       there do not propagate into the output.
// A zero contraction length (self.size(2) == 0) yields acc == 0.
//
// Parameters:
//   result - 3-D output accessor source; result.size(0..2) give the batch
//            count, output rows and output columns. For baddbmm it is also
//            the addend that is scaled by beta.
//   self   - 3-D left operand; self.size(2) is the contraction length.
//   mat2   - 3-D right operand, indexed as mat2[b][k][j].
//   beta_  - scale applied to the existing contents of `result` (baddbmm only).
//   alpha_ - scale applied to the product (baddbmm only).
//
// NOTE(review): shape compatibility between result/self/mat2 is assumed to be
// validated by the caller; this kernel does not check it.
template <typename scalar_t, bool is_bmm>
static inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const Tensor& mat2, const Scalar& beta_, const Scalar& alpha_) {
  const int64_t bs = result.size(0);
  const int64_t is = result.size(1);
  const int64_t js = result.size(2);
  const int64_t ks = self.size(2);

  using opmath_t = at::opmath_type<scalar_t>;
  const opmath_t alpha = alpha_.to<opmath_t>();
  const opmath_t beta = beta_.to<opmath_t>();

  auto r0 = result.accessor<scalar_t, 3>();
  auto s0 = self.accessor<const scalar_t, 3>();
  auto m0 = mat2.accessor<const scalar_t, 3>();

  // Multiply-adds performed per batch element. Any zero-sized dimension makes
  // this 0; clamp to 1 so the grain-size division below is never a division
  // by zero (which would be undefined behavior for integers).
  const int64_t work_per_batch = std::max(is * js * ks, static_cast<int64_t>(1));
  const int64_t grain_size = std::max(internal::GRAIN_SIZE / work_per_batch, static_cast<int64_t>(1));
  parallel_for(0, bs, grain_size, [&](int64_t b_begin, int64_t b_end) {
    for (const auto b : c10::irange(b_begin, b_end)) {
      auto r1 = r0[b];
      auto s1 = s0[b];
      auto m1 = m0[b];
      for (const auto i : c10::irange(is)) {
        auto r2 = r1[i];
        auto s2 = s1[i];
        for (const auto j : c10::irange(js)) {
          // Accumulate in the wider opmath type to limit rounding error for
          // reduced-precision scalar types.
          opmath_t acc_value = 0;
          for (const auto k : c10::irange(ks)) {
            acc_value += static_cast<opmath_t>(s2[k]) *
                static_cast<opmath_t>(m1[k][j]);
          }
          if (is_bmm) {
            r2[j] = acc_value;
          } else {
            // For beta == 0, the r's value will be ignored, especially for nan value.
            if (beta == opmath_t{0}) {
              r2[j] = alpha * acc_value;
            } else {
              r2[j] = static_cast<opmath_t>(r2[j]) * beta + alpha * acc_value;
            }
          }
        }
      }
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free