common_checks_baddbmm_bmm Function — pytorch Architecture
Architecture documentation for the common_checks_baddbmm_bmm function template in LinearAlgebra.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/LinearAlgebra.cpp lines 289–334
// Shared validation / output setup for the bmm and baddbmm meta functions.
//
// Requires batch1 to be 3-D [bs, res_rows, contraction] and batch2 to be 3-D
// [bs, contraction, res_cols]. Declares output 0 as [bs, res_rows, res_cols]
// via meta.set_output_raw_strided (with batch2's options), then checks that the
// resulting output actually has that shape. For baddbmm (is_bmm == false) with
// self_baddbmm provided, self must be 3-D with exactly the output shape; when
// beta != 0 self is copied into the output. Finally the named-tensor dimension
// names computed for the op are propagated onto the output (if any).
//
// @param meta          object owning output 0; must provide maybe_get_output()
//                      and set_output_raw_strided() (presumably the
//                      structured-kernel meta/impl class — confirm at callers).
// @param batch1        left batched operand.
// @param batch2        right batched operand; its options are used for output 0.
// @param beta          only its zero-ness is inspected here (decides whether
//                      self is copied into the output).
// @param alpha         not read by this helper.
// @param is_bmm        true for bmm; false for baddbmm.
// @param self_baddbmm  baddbmm's `self` input; ignored when is_bmm is true.
// @throws c10::Error (via TORCH_CHECK) on any rank/shape mismatch.
template <typename Meta>
static void common_checks_baddbmm_bmm(Meta& meta, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
  TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
  TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");

  const auto batch1_sizes = batch1.sizes();
  const auto batch2_sizes = batch2.sizes();

  const int64_t bs = batch1_sizes[0];
  const int64_t contraction_size = batch1_sizes[2];
  const int64_t res_rows = batch1_sizes[1];
  const int64_t res_cols = batch2_sizes[2];
  std::vector<int64_t> output_size {bs, res_rows, res_cols};

  TORCH_CHECK(batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size,
              "Expected size for first two dimensions of batch2 tensor to be: [",
              bs, ", ", contraction_size, "] but got: [", batch2_sizes[0], ", ", batch2_sizes[1], "].");

  auto& result = meta.maybe_get_output(0);
  // 'set_output' does not resize for in-place calls
  meta.set_output_raw_strided(0, output_size, {}, batch2.options());
  const auto result_sizes = result.sizes();
  // Error is raised if called from in-place overload with incorrect shape
  TORCH_CHECK(result_sizes == output_size,
              "Expected an output tensor with shape [", output_size, "] but got shape ", result_sizes);

  std::vector<Dimname> outnames = {};
  if (!is_bmm) {
    if (self_baddbmm.has_value()) {
      const auto& self = self_baddbmm.value();
      // Validate self BEFORE writing into result. Copying first would either
      // partially overwrite the output with a merely-broadcastable self before
      // the shape check fires, or surface copy_'s broadcast error instead of
      // the intended message below.
      TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor");
      const auto self_sizes = self.sizes();
      TORCH_CHECK(self_sizes == output_size,
                  "Expected an input tensor shape with shape ", output_size, " but got shape: ", self_sizes);
      // With beta == 0 self does not contribute, so the copy is skipped
      // (presumably the kernel then overwrites result outright — confirm).
      if (beta.toComplexDouble() != 0.0) {
        result.copy_(self);
      }
      outnames = namedinference::compute_baddbmm_outnames(result, batch1, batch2, self);
    }
  } else {
    outnames = namedinference::compute_bmm_outnames(result, batch1, batch2);
  }
  namedinference::propagate_names_if_nonempty(result, outnames);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free