baddbmm_cpu_kernel (is_bmm template parameter) — pytorch Architecture
Documentation for the baddbmm_cpu_kernel function template in LinearAlgebra.cpp from the pytorch codebase; `is_bmm` is a compile-time bool template parameter selecting plain bmm versus baddbmm semantics, not a class.
Entity Profile
Source Code
aten/src/ATen/native/LinearAlgebra.cpp lines 1648–1693
// Naive CPU fallback kernel for baddbmm / bmm over 3-D tensors.
//
// Computes, for every batch b:
//   is_bmm == true:   result[b] = self[b] @ mat2[b]              (beta_/alpha_ ignored)
//   is_bmm == false:  result[b] = beta * result[b] + alpha * (self[b] @ mat2[b])
//
// Shapes (assumed validated by the caller — TODO confirm at call sites):
//   result: [bs, is, js], self: [bs, is, ks], mat2: [bs, ks, js]
//
// Accumulation is done in opmath_t (e.g. float for half/bfloat16 inputs) to
// limit precision loss from summing in a reduced-precision scalar_t.
template <typename scalar_t, bool is_bmm>
static inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const Tensor& mat2, const Scalar& beta_, const Scalar& alpha_) {
  int64_t bs = result.size(0);  // batch count
  int64_t is = result.size(1);  // output rows
  int64_t js = result.size(2);  // output cols
  int64_t ks = self.size(2);    // contraction (inner) dimension
  using opmath_t = at::opmath_type<scalar_t>;
  opmath_t alpha = alpha_.to<opmath_t>();
  opmath_t beta = beta_.to<opmath_t>();
  auto r0 = result.accessor<scalar_t, 3>();
  auto s0 = self.accessor<const scalar_t, 3>();
  auto m0 = mat2.accessor<const scalar_t, 3>();
  // Batches per parallel task, chosen so each task performs roughly
  // GRAIN_SIZE scalar multiply-adds (is*js*ks per batch), clamped to >= 1.
  int64_t grain_size = std::max(internal::GRAIN_SIZE / (is * js * ks), static_cast<int64_t>(1));
  parallel_for(0, bs, grain_size, [&](int64_t b_begin, int64_t b_end) {
    for (const auto b : c10::irange(b_begin, b_end)) {
      auto r1 = r0[b];
      auto s1 = s0[b];
      auto m1 = m0[b];
      for (const auto i : c10::irange(is)) {
        auto r2 = r1[i];
        auto s2 = s1[i];
        for (const auto j : c10::irange(js)) {
          // Dot product of row i of self[b] with column j of mat2[b],
          // accumulated in opmath_t.
          opmath_t acc_value = 0;
          for (const auto k : c10::irange(ks)) {
            acc_value += static_cast<opmath_t>(s2[k]) *
                static_cast<opmath_t>(m1[k][j]);
          }
          // is_bmm is a template parameter: resolve the branch at compile
          // time so each instantiation contains only its own path.
          if constexpr (is_bmm) {
            r2[j] = acc_value;
          } else {
            // When beta == 0 the prior contents of result are ignored
            // entirely, so a pre-existing NaN in result does not propagate
            // (matches BLAS gemm beta==0 semantics).
            if (beta == opmath_t{0}) {
              r2[j] = alpha * acc_value;
            } else {
              r2[j] = static_cast<opmath_t>(r2[j]) * beta + alpha * acc_value;
            }
          }
        }
      }
    }
  });
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free