caller_is_meff Template Parameter — PyTorch Architecture
Architecture documentation for the `caller_is_meff` template parameter of `check_head_dim_size_flash` in sdp_utils.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/sdp_utils.cpp lines 146–216
// Checks whether the head dimension (last dim) of query, key and value is
// acceptable for the flash-attention kernel. The same check is reused for the
// memory-efficient attention path when `caller_is_meff` is true (the warning
// then labels it "Efficient attention on ROCM").
//
// Template parameter:
//   caller_is_meff - when true, the warning prefix changes and an additional
//                    dtype-dependent alignment requirement on the last
//                    dimension is enforced (8 for Half/BFloat16, 4 otherwise).
// Parameters:
//   params - SDPA inputs; only the last-dim sizes of query/key/value and the
//            dtype of query are inspected.
//   debug  - when true, emit a TORCH_WARN describing why the check failed.
// Returns true if the head dimensions are supported, false otherwise.
template<bool caller_is_meff = false>
bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
#if USE_ROCM_ATTENTION
  // No visible device: the device-dependent limit below cannot be queried.
  if (at::cuda::device_count() == 0) {
    return false;
  }
  // AOTriton 0.9+ supports head_dim up to 512.
  // NOTE: this is a function-local static, so the limit is computed once, on
  // the first call, from whichever device is current at that time.
  const static auto max_hdim = []() {
#if AOTRITON_VERSION_CURRENT == AOTRITON_VERSION_INT(0, 11)
    // gfx11xx only supports hdim <= 256 on AOTriton 0.11
    auto dprops = at::cuda::getCurrentDeviceProperties();
    const c10::basic_string_view<char> arch(dprops->gcnArchName);
    if (arch.starts_with("gfx11")) {
      return 256;
    }
#endif // AOTriton 0.11
#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 9)
    return 512;
#else
    return 256;
#endif
  }();
  const auto max_size = c10::SymInt(max_hdim);
#else
  // All head_dim sizes must be equal and less than or equal to 256
  const auto max_size = c10::SymInt(256);
#endif
  const auto query_size_last = params.query.sym_size(-1);
  const auto key_size_last = params.key.sym_size(-1);
  const auto value_size_last = params.value.sym_size(-1);
  bool same_head_dim_size =
      query_size_last == key_size_last && query_size_last == value_size_last;
  if (!(same_head_dim_size && (query_size_last <= max_size))) {
    if (debug) {
      // Report the limit actually enforced: on ROCm it may be 512, not 256.
      TORCH_WARN(
          caller_is_meff ? "Efficient attention on ROCM" : "Flash attention",
          " requires q,k,v to have the same last dimension and to be less than or equal to ",
          max_size,
          ". Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
          key_size_last,
          ", Value.size(-1): ",
          value_size_last,
          " instead.");
    }
    return false;
  }
  if constexpr(caller_is_meff) {
    bool is_half = (params.query.dtype() == at::kHalf) ||
        (params.query.dtype() == at::kBFloat16);
    const int64_t alignment = is_half ? 8 : 4;
    // Key is not checked separately: the check above already guarantees
    // key_size_last == query_size_last.
    if (!(query_size_last % alignment == 0 && query_size_last > 0 &&
          value_size_last % alignment == 0 && value_size_last > 0)) {
      if (debug) {
        TORCH_WARN(
            "Mem efficient attention requires last dimension of inputs to be divisible by ",
            alignment,
            ". ",
            "Got Query.size(-1): ",
            query_size_last,
            ", Key.size(-1): ",
            key_size_last,
            ", Value.size(-1): ",
            value_size_last,
            " instead.");
      }
      return false;
    }
  }
  return true;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free