
infer_size_impl — PyTorch Architecture

Architecture documentation for the infer_size_impl function template in InferSize.h from the PyTorch codebase. This helper resolves a single -1 ("inferred") dimension in a target shape against a known element count, as used by operations such as Tensor::view and reshape.

Entity Profile

Source Code

aten/src/ATen/InferSize.h lines 20–105

template <typename InputArrayRef, typename NumelType, typename ResultVec>
inline void infer_size_impl(
    InputArrayRef shape,
    NumelType numel,
    ResultVec& res) {
  NumelType newsize = 1;
  // N.B. this is an index, not a sym dim!
  std::optional<int64_t> infer_dim;
  for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
    if (TORCH_GUARD_OR_FALSE(sym_eq(shape[dim], -1))) {
      TORCH_CHECK(!infer_dim, "only one dimension can be inferred");
      infer_dim = dim;
    } else {
      // in case of unbacked shape[dim] we assume it's not -1 and add a runtime
      // assertion.
      TORCH_MAYBE_SYM_CHECK(
          sym_gt(shape[dim], -1),
          "invalid shape dimension ",
          shape[dim],
          " at index ",
          dim,
          " of shape ",
          shape);
      newsize *= shape[dim];
    }
  }

  if (infer_dim) {
    // numel is the product of the known sizes, so it has to be divisible by
    // newsize, and newsize should be positive unless newsize == numel (we
    // throw a different error message in that case).
    if constexpr (std::is_same_v<NumelType, c10::SymInt>) {
      auto v = newsize.maybe_as_int();
      if (v and *v == 0) {
        // Avoid a division by zero when sym_eq(numel % newsize, 0) is
        // constructed, which can only happen when newsize is not a symbol;
        // if it is a symbol, the division is deferred and won't run at
        // compile time anyway.
        TORCH_MAYBE_SYM_CHECK(
            numel == newsize,
            "shape '",
            shape,
            "' is invalid for input of size ",
            numel);
      } else {
        auto cond = sym_gt(newsize, 0)
                        .sym_and(sym_eq(numel % newsize, 0))
                        .sym_or(sym_eq(numel, newsize));
        TORCH_MAYBE_SYM_CHECK(
            cond, "shape '", shape, "' is invalid for input of size ", numel);
      }

    } else {
      TORCH_CHECK(
          (newsize > 0 && (numel % newsize == 0)) || numel == newsize,
          "shape '",
          shape,
          "' is invalid for input of size ",
          numel);
    }

    // We have a degree of freedom here to select the dimension size; follow
    // NumPy semantics and just bail.  However, a nice error message is needed
    // because users often use `view` as a way to flatten & unflatten
    // dimensions and will otherwise be confused why
    //   empty_tensor.view(0, 0)
    // works yet
    //   empty_tensor.view(-1, 0)
    // doesn't.
    TORCH_MAYBE_SYM_CHECK(
        newsize != 0,
        "cannot reshape tensor of 0 elements into shape ",
        shape,
        " because the unspecified dimension size -1 can be any "
        "value and is ambiguous");

    res[*infer_dim] = numel / newsize;
    return;
  }

  TORCH_MAYBE_SYM_CHECK(
      sym_eq(numel, newsize),
      "shape '",
      shape,
      "' is invalid for input of size ",
      numel);
}
