Home / Class / common_checks_baddbmm_bmm Class — pytorch Architecture

common_checks_baddbmm_bmm Class — pytorch Architecture

Architecture documentation for the common_checks_baddbmm_bmm class in LinearAlgebra.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/LinearAlgebra.cpp lines 289–334

template <typename Meta>
static void common_checks_baddbmm_bmm(Meta& meta, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
  TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
  TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");

  const auto batch1_sizes = batch1.sizes();
  const auto batch2_sizes = batch2.sizes();

  int64_t bs = batch1_sizes[0];
  int64_t contraction_size = batch1_sizes[2];
  int64_t res_rows = batch1_sizes[1];
  int64_t res_cols = batch2_sizes[2];
  std::vector<int64_t> output_size {bs, res_rows, res_cols};

  TORCH_CHECK(batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size,
              "Expected size for first two dimensions of batch2 tensor to be: [",
              bs, ", ", contraction_size, "] but got: [", batch2_sizes[0], ", ", batch2_sizes[1], "].");

  auto& result = meta.maybe_get_output(0);
  // 'set_output' does not resize for in-place calls
  meta.set_output_raw_strided(0, output_size, {}, batch2.options());
  const auto result_sizes = result.sizes();
  // Error is raised if called from in-place overload with incorrect shape
  TORCH_CHECK(result_sizes == output_size,
              "Expected an output tensor with shape [", output_size, "] but got shape ", result_sizes);

  std::vector<Dimname> outnames = {};
  if (!is_bmm) {
    if (self_baddbmm.has_value()) {
      const auto& self = self_baddbmm.value();
      if (beta.toComplexDouble() != 0.0) result.copy_(self);
      TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor");
      const auto self_sizes = self.sizes();
      TORCH_CHECK(self_sizes == output_size,
                  "Expected an input tensor shape with shape ", output_size, " but got shape: ", self_sizes);
      outnames = namedinference::compute_baddbmm_outnames(result, batch1, batch2, self);
    }
  } else {
    outnames = namedinference::compute_bmm_outnames(result, batch1, batch2);
  }

  namedinference::propagate_names_if_nonempty(
    result,
    outnames
  );
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free