Home / Class / caller_is_meff Class — pytorch Architecture

caller_is_meff Class — pytorch Architecture

Architecture documentation for caller_is_meff, the boolean template parameter of the check_head_dim_size_flash function template in sdp_utils.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp lines 146–216

// Validates the head dimension (last dimension) of query, key and value.
//
// Template parameter:
//   caller_is_meff - when true, the failure warning is prefixed with
//                    "Efficient attention on ROCM" instead of "Flash attention",
//                    and an additional alignment check on the last dimension is
//                    performed (multiple of 8 for half/bfloat16, 4 otherwise).
//
// Parameters:
//   params - the attention inputs; only query/key/value (sizes and the query
//            dtype) are inspected here.
//   debug  - when true, a TORCH_WARN explaining the rejection is emitted
//            before returning false.
//
// Returns true when q, k and v share the same last dimension, that dimension
// does not exceed the supported maximum, and (for caller_is_meff) the
// alignment requirement holds. Returns false otherwise. Under
// USE_ROCM_ATTENTION it also returns false when no device is available.
template<bool caller_is_meff = false>
bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
#if USE_ROCM_ATTENTION
  if (at::cuda::device_count() == 0) {
    return false;
  }
  // AOTriton 0.9+ supports head_dim up to 512
  // The limit is computed once (function-local static) on first use.
  const static auto max_hdim = []() {
#if AOTRITON_VERSION_CURRENT == AOTRITON_VERSION_INT(0, 11)
    // gfx11xx only support hdim <= 256 on AOTriton 0.11
    auto dprops = at::cuda::getCurrentDeviceProperties();
    const c10::basic_string_view<char> arch(dprops->gcnArchName);
    if (arch.starts_with("gfx11")) {
      return 256;
    }
#endif // AOTriton 0.11
#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 9)
    return 512;
#else
    return 256;
#endif
  }();
  const auto max_size = c10::SymInt(max_hdim);
#else
  // All head_dim sizes must be equal and less than or equal to 256
  const auto max_size = c10::SymInt(256);
#endif
  const auto query_size_last = params.query.sym_size(-1);
  const auto key_size_last = params.key.sym_size(-1);
  const auto value_size_last = params.value.sym_size(-1);
  bool same_head_dim_size =
      query_size_last == key_size_last && query_size_last == value_size_last;
  if (!(same_head_dim_size && (query_size_last <= max_size))) {
    if (debug) {
      // Report the limit that was actually enforced: max_size is not always
      // 256 (it can be 512 on ROCm builds), so it must not be hard-coded in
      // the message.
      TORCH_WARN(
          caller_is_meff ? "Efficient attention on ROCM" : "Flash attention",
          " requires q,k,v to have the same last dimension and to be less than or equal to ",
          max_size,
          ".",
          " Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
          key_size_last,
          ", Value.size(-1): ",
          value_size_last,
          " instead.");
    }
    return false;
  }
  if constexpr(caller_is_meff) {
    // Half-precision dtypes require a stricter alignment of the last
    // dimension than the other dtypes.
    bool is_half = (params.query.dtype() == at::kHalf) ||
      (params.query.dtype() == at::kBFloat16);
    const int64_t alignment = is_half ? 8 : 4;
    // Key is not tested separately: at this point all three last dimensions
    // are known to be equal.
    if (!(query_size_last % alignment == 0 && query_size_last > 0 &&
          value_size_last % alignment == 0 && value_size_last > 0)) {
      if (debug) {
        TORCH_WARN(
            "Mem efficient attention requires last dimension of inputs to be divisible by ",
            alignment,
            ". ",
            "Got Query.size(-1): ",
            query_size_last,
            ", Key.size(-1): ",
            key_size_last,
            ", Value.size(-1): ",
            value_size_last,
            " instead.");
      }
      return false;
    }
  }
  return true;
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free