DefaultScaledGemmOp Class — pytorch Architecture

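Architecture documentation for the DefaultScaledGemmOp class template in TunableGemm.h from the PyTorch codebase. Its Call method forwards every field of a ScaledGemmParams<T> to at::cuda::blas::scaled_gemm, passing std::nullopt for alpha, and then returns OK.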

Source Code

aten/src/ATen/cuda/tunable/TunableGemm.h lines 85–116

template <typename T>
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
  public:
    TuningStatus Call(const ScaledGemmParams<T>* params) override {
      at::cuda::blas::scaled_gemm(
          params->transa,
          params->transb,
          params->m,
          params->n,
          params->k,
          params->a,
          params->a_scale_ptr,
          params->lda,
          params->a_dtype,
          params->a_scale_dtype,
          params->a_scaling_type,
          params->b,
          params->b_scale_ptr,
          params->ldb,
          params->b_dtype,
          params->b_scale_dtype,
          params->b_scaling_type,
          params->bias_ptr,
          params->bias_dtype,
          params->c,
          params->c_scale_ptr,
          params->ldc,
          params->c_dtype,
          params->use_fast_accum,
          std::nullopt /* alpha */);
      return OK;
    }
};
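To show where a class like this sits, here is a minimal, self-contained C++ sketch of the pattern under a simplified model of TunableOp: a default Callable that forwards its parameters to a library routine, registered next to a more specialized candidate, with a tuner that times every candidate and keeps the fastest one that reports success. Every name here (ScaledGemmParamsLite, DefaultScaledGemmOpLite, NoTransOnlyOp, SimpleTunableOp, scaled_gemm_ref) is hypothetical and is not part of PyTorch. The sketch runs on the CPU with float data and single-value (tensorwise) scales, follows the BLAS column-major convention that the m/n/k, lda/ldb/ldc and transa/transb fields imply, and omits bias, the output scale, dtypes and fast accumulation.

```cpp
#include <cassert>
#include <chrono>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-ins for TunableOp types (not PyTorch's API).
enum TuningStatus { OK, FAIL };

// Simplified params: column-major storage, one float scale per operand.
struct ScaledGemmParamsLite {
  char transa, transb;  // 'n' = as stored, 't' = transposed (BLAS convention)
  int m, n, k;          // op(A) is m x k, op(B) is k x n, C is m x n
  const float* a; float a_scale; int lda;
  const float* b; float b_scale; int ldb;
  float* c; int ldc;
};

// Element (row, col) of op(X) for a column-major matrix X with leading dim ld.
inline float at_op(const float* x, int ld, char trans, int row, int col) {
  return (trans == 'n') ? x[row + col * ld] : x[col + row * ld];
}

// Reference tensorwise-scaled GEMM: C = a_scale * b_scale * op(A) * op(B).
void scaled_gemm_ref(const ScaledGemmParamsLite& p) {
  for (int j = 0; j < p.n; ++j)
    for (int i = 0; i < p.m; ++i) {
      float acc = 0.f;
      for (int q = 0; q < p.k; ++q)
        acc += at_op(p.a, p.lda, p.transa, i, q) * at_op(p.b, p.ldb, p.transb, q, j);
      p.c[i + j * p.ldc] = p.a_scale * p.b_scale * acc;
    }
}

template <typename Params>
struct Callable {
  virtual ~Callable() = default;
  virtual TuningStatus Call(const Params* params) = 0;
};

// Mirrors the shape of DefaultScaledGemmOp: forward to the backend, report OK.
struct DefaultScaledGemmOpLite : Callable<ScaledGemmParamsLite> {
  TuningStatus Call(const ScaledGemmParamsLite* p) override {
    scaled_gemm_ref(*p);
    return OK;
  }
};

// Hypothetical specialized candidate that only supports non-transposed inputs.
struct NoTransOnlyOp : Callable<ScaledGemmParamsLite> {
  TuningStatus Call(const ScaledGemmParamsLite* p) override {
    if (p->transa != 'n' || p->transb != 'n') return FAIL;
    scaled_gemm_ref(*p);
    return OK;
  }
};

// Hypothetical tuner: time each candidate once, keep the fastest that returns OK.
class SimpleTunableOp {
 public:
  void RegisterOp(const std::string& name,
                  std::unique_ptr<Callable<ScaledGemmParamsLite>> op) {
    ops_.emplace_back(name, std::move(op));
  }
  std::string FindFastest(const ScaledGemmParamsLite* p) {
    std::string best;  // stays empty if every candidate fails
    double best_time = std::numeric_limits<double>::infinity();
    for (auto& [name, op] : ops_) {
      auto t0 = std::chrono::steady_clock::now();
      TuningStatus s = op->Call(p);
      std::chrono::duration<double> dt = std::chrono::steady_clock::now() - t0;
      if (s == OK && dt.count() < best_time) {
        best_time = dt.count();
        best = name;
      }
    }
    return best;
  }

 private:
  std::vector<std::pair<std::string, std::unique_ptr<Callable<ScaledGemmParamsLite>>>> ops_;
};
```

In this simplified model, the value of a default op that just forwards to the library routine is that specialized candidates may reject some parameter combinations (here, NoTransOnlyOp returns FAIL for transposed inputs), while the default path accepts whatever the underlying call supports, so the tuner always has a working fallback.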
