GemmStridedBatchedTunableOp Class — pytorch Architecture

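Architecture documentation for the GemmStridedBatchedTunableOp class in TunableGemm.h from the PyTorch codebase. The class derives from TunableOp<GemmStridedBatchedParams<T>> and is parameterized on the element type T and the operand layouts ALayout and BLayout. Its constructor registers the candidate strided-batched GEMM implementations: a default op and, on ROCm builds, rocBLAS and hipBLASLt variants, each registered unless its environment variable is explicitly set to a false value. hipBLASLt candidates are additionally skipped for c10::complex element types. Signature() returns a string that identifies the op by type name and layout pair.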

Source Code

aten/src/ATen/cuda/tunable/TunableGemm.h lines 271–303

class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>> {
 public:
  GemmStridedBatchedTunableOp() {
    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());

#ifdef USE_ROCM
    static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
    if (!env_rocblas.has_value() || env_rocblas.value()) {
      for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
        this->RegisterOp(std::move(name), std::move(op));
      }
    }

    static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
    if (!env_hipblaslt.has_value() || env_hipblaslt.value()) {
      // disallow tuning of hipblaslt with c10::complex
      if constexpr (
          !std::is_same_v<T, c10::complex<float>> &&
          !std::is_same_v<T, c10::complex<double>>) {
        for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
          this->RegisterOp(std::move(name), std::move(op));
        }
      }
    }
#endif

    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
  }

  std::string Signature() override {
    return fmt::sprintf("GemmStridedBatchedTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
  }
};
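
The two pieces of logic that are easiest to misread in the listing above are the opt-out check on the environment flags and the format of the signature string. The following is a minimal standalone sketch of both, not the PyTorch implementation. It assumes that c10::utils::check_env yields an optional boolean, as its use with has_value() and value() suggests, and that BlasOpToString maps a layout to a single character. The BlasOp enum, the BlasOpToString body, BackendEnabled, and MakeSignature below are hypothetical stand-ins written for illustration, and std::snprintf replaces fmt::sprintf so the sketch needs only the standard library.

```cpp
#include <cassert>
#include <cstdio>
#include <optional>
#include <string>

// Hypothetical stand-in for the layout enum used by the class; the real
// definition lives elsewhere in the PyTorch sources.
enum class BlasOp { N, T };

// Assumed mapping from a layout to the character used in signatures.
char BlasOpToString(BlasOp op) {
  return op == BlasOp::T ? 'T' : 'N';
}

// Mirrors the constructor's test `!flag.has_value() || flag.value()`:
// a backend is registered when the flag is unset or explicitly true,
// and skipped only when the flag is explicitly false.
bool BackendEnabled(const std::optional<bool>& flag) {
  return !flag.has_value() || flag.value();
}

// Rebuilds the signature with the same format string as Signature().
// type_name stands in for the result of TypeName<T>(T{}).
std::string MakeSignature(const std::string& type_name, BlasOp a, BlasOp b) {
  char buf[128];
  int n = std::snprintf(buf, sizeof(buf), "GemmStridedBatchedTunableOp_%s_%c%c",
                        type_name.c_str(), BlasOpToString(a), BlasOpToString(b));
  // The sketch does not handle truncation; it only checks that none occurred.
  assert(n > 0 && n < static_cast<int>(sizeof(buf)));
  return std::string(buf);
}
```

With these stand-ins, BackendEnabled(std::nullopt) and BackendEnabled(true) both return true, BackendEnabled(false) returns false, and MakeSignature("float", BlasOp::N, BlasOp::T) returns GemmStridedBatchedTunableOp_float_NT. The flag check is equivalent to flag.value_or(true), which makes the opt-out default explicit.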