Home / Class/ ignore_singleton_dim Class — pytorch Architecture

ignore_singleton_dim Class — pytorch Architecture

Architecture documentation for the check_last_dim_stride_equals_1_dense function template (parameterized on the boolean template argument ignore_singleton_dim) in sdp_utils_cpp.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/transformers/sdp_utils_cpp.h lines 491–533

template<bool ignore_singleton_dim>
inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) {
  // Verifies that query, key and value (and, off-CPU, the optional attention
  // mask) have stride 1 along their last dimension, as the fused attention
  // kernels require. Returns true when the layout is acceptable; otherwise
  // returns false and, if `debug` is set, emits a warning with the strides.
  //
  // NestedTensor inputs are not handled here: their strides are checked
  // inside the kernel, which calls .contiguous when needed.
  bool last_dims_unit_stride = params.query.sym_stride(-1) == 1 &&
      params.key.sym_stride(-1) == 1 && params.value.sym_stride(-1) == 1;

  // https://github.com/pytorch/pytorch/issues/116333
  // A head_dim of size 1 makes the stride irrelevant. This runs before the
  // head_dim gets padded, so the singleton case is accepted explicitly here.
  // The size is only inspected when the plain stride check has failed.
  if (ignore_singleton_dim && !last_dims_unit_stride) {
    last_dims_unit_stride = params.query.sym_size(-1) == 1;
  }

  const bool on_cpu = params.query.device().type() == c10::DeviceType::CPU;

  // With no mask supplied there is nothing to validate.
  bool mask_unit_stride = true;
  if (params.attn_mask.has_value()) {
    mask_unit_stride = params.attn_mask.value().sym_stride(-1) == 1;
  }

  // The mask stride constraint is waived on CPU.
  const bool mask_acceptable = on_cpu || mask_unit_stride;

  if (last_dims_unit_stride && mask_acceptable) {
    return true;
  }

  if (debug) {
    std::ostringstream warning;
    warning
        << "All fused kernels require the last dimension of the input to have stride 1. "
        << "Got Query.stride(-1): " << params.query.sym_stride(-1)
        << ", Key.stride(-1): " << params.key.sym_stride(-1)
        << ", Value.stride(-1): " << params.value.sym_stride(-1);

    if (params.attn_mask.has_value()) {
      warning
          << ", Attn_mask.stride(-1): "
          << params.attn_mask.value().sym_stride(-1)
          << " (GPU backends require attn_mask's last dimension to have stride 1 while the CPU does not).";
    }
    TORCH_WARN(warning.str());
  }

  return false;
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free