train (template parameter): PyTorch Architecture

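Architecture documentation for train, the bool non-type template parameter of check_sparse_mm_reduce_impl_inputs in SparseCsrTensorMath.h from the PyTorch codebase. With train = true, the function validates the inputs of the backward pass, including grad_out, and labels its errors "sparse_mm_reduce_backward". With train = false, it checks only self and other and labels its errors "sparse_mm_reduce".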

Source Code

aten/src/ATen/native/sparse/SparseCsrTensorMath.h lines 62–82

template <bool train>
inline void check_sparse_mm_reduce_impl_inputs(
    const Tensor& self,
    const Tensor& grad_out,
    const Tensor& other) {
  TORCH_INTERNAL_ASSERT(self.is_sparse_csr());

  const auto input_scalar_type = self.values().scalar_type();
  CheckedFrom c = train ? "sparse_mm_reduce_backward" : "sparse_mm_reduce";
  if (train) {
    checkLayout(c, grad_out, kStrided);
    checkScalarType(c, {grad_out, "grad_out", 1}, input_scalar_type);
    check_dim_size(grad_out, 2, 0, self.size(0));
    check_dim_size(grad_out, 2, 1, other.size(1));
  }

  int pos = train ? 2 : 1;
  checkLayout(c, other, kStrided);
  checkScalarType(c, {other, "other", pos}, input_scalar_type);
  check_dim_size(other, 2, 0, self.size(1));
}
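To make the control flow easier to follow outside of ATen, here is a minimal, dependency-free C++ sketch that mirrors the checks above. It is not PyTorch code: TensorMeta and check_inputs_sketch are hypothetical names, each tensor is reduced to a layout flag, a dtype string and a 2-D shape (so the dim() == 2 part of check_dim_size is assumed rather than checked), and the TORCH_INTERNAL_ASSERT(self.is_sparse_csr()) guard is omitted. The shape rules follow from the matrix product: self is M x K, other must be K x N, and in the backward pass grad_out must be M x N.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the tensor metadata the real checks read from
// at::Tensor. Not an ATen type; it carries only what this sketch needs.
struct TensorMeta {
  bool strided;       // plays the role of layout() == kStrided
  std::string dtype;  // plays the role of scalar_type() (self.values() for CSR)
  int64_t rows;       // size(0); tensors are assumed to be 2-D
  int64_t cols;       // size(1)
};

// Mirrors the control flow of check_sparse_mm_reduce_impl_inputs<train>:
// self is the CSR matrix (M x K), other is dense (K x N), and, only when
// train is true, grad_out must be dense (M x N).
template <bool train>
void check_inputs_sketch(const TensorMeta& self,
                         const TensorMeta& grad_out,
                         const TensorMeta& other) {
  // Same op labels the real code passes as CheckedFrom.
  const std::string op =
      train ? "sparse_mm_reduce_backward" : "sparse_mm_reduce";
  auto fail = [&](const std::string& msg) {
    throw std::invalid_argument(op + ": " + msg);
  };

  if (train) {
    // Backward only: grad_out must be dense, match the dtype of self's
    // values, and have the shape of the forward output (M x N).
    if (!grad_out.strided) fail("grad_out must be strided");
    if (grad_out.dtype != self.dtype) fail("grad_out (arg 1) dtype mismatch");
    if (grad_out.rows != self.rows) fail("grad_out.size(0) != self.size(0)");
    if (grad_out.cols != other.cols) fail("grad_out.size(1) != other.size(1)");
  }

  // other is argument 1 in the forward op but argument 2 in the backward
  // op, because grad_out occupies position 1 there.
  const int pos = train ? 2 : 1;
  if (!other.strided) fail("other must be strided");
  if (other.dtype != self.dtype)
    fail("other (arg " + std::to_string(pos) + ") dtype mismatch");
  // Inner dimensions of the product must agree: K == other.size(0).
  if (other.rows != self.cols) fail("other.size(0) != self.size(1)");
}
```

Because train is a compile-time constant, check_inputs_sketch<true> and check_inputs_sketch<false> are separate instantiations, and the compiler can fold the if (train) branch away in the forward one. From the caller's side, the forward variant never inspects grad_out and reports other as argument 1 rather than 2.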
