GemmStridedBatchedTunableOp Class — pytorch Architecture
Architecture documentation for the GemmStridedBatchedTunableOp class in TunableGemm.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/cuda/tunable/TunableGemm.h lines 271–303
// Tunable operator for strided-batched GEMM: a TunableOp over
// GemmStridedBatchedParams<T> whose constructor registers the candidate
// implementations under string names.
//
// NOTE(review): T, ALayout and BLayout are used below but are not declared in
// this excerpt; presumably they are template parameters introduced on a
// `template <...>` line directly above the class -- confirm against the file.
class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>> {
public:
// Registers the candidate ops in a fixed order: "Default" first, then (only
// when built with USE_ROCM) the rocBLAS and hipBLASLt candidates, then
// "Default" once more.
GemmStridedBatchedTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
#ifdef USE_ROCM
// The environment lookup is stored in a function-local static, so it is
// evaluated once rather than on every construction.
static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
// rocBLAS candidates are registered when the lookup yields no value or
// yields true, i.e. they are on unless explicitly switched off.
if (!env_rocblas.has_value() || env_rocblas.value()) {
for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
this->RegisterOp(std::move(name), std::move(op));
}
}
// Same opt-out pattern for the hipBLASLt candidates.
static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
// disallow tuning of hipblaslt with c10::complex
// (the exclusion is a compile-time `if constexpr`, so for complex<float> and
// complex<double> the hipBLASLt registration loop is discarded entirely)
if constexpr (
!std::is_same_v<T, c10::complex<float>> &&
!std::is_same_v<T, c10::complex<double>>) {
for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
this->RegisterOp(std::move(name), std::move(op));
}
}
}
#endif
// NOTE(review): "Default" is registered a second time here, with a fresh
// DefaultGemmStridedBatchedOp<T>. Whether this is intentional (e.g. so the
// default is also the last candidate) or a redundant duplicate depends on how
// RegisterOp treats repeated names, which is not visible in this excerpt --
// confirm before removing or reordering.
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
}
// Returns the identifying string for this op, formatted as
// "GemmStridedBatchedTunableOp_<type>_<a><b>", where <type> comes from
// TypeName<T>(T{}) and <a>/<b> are the characters produced by BlasOpToString
// for ALayout and BLayout.
std::string Signature() override {
return fmt::sprintf("GemmStridedBatchedTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
}
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free