
RocblasGemmStridedBatchedOp Class — PyTorch Architecture

Architecture documentation for the RocblasGemmStridedBatchedOp class in GemmRocblas.h from the PyTorch codebase.

Entity Profile

RocblasGemmStridedBatchedOp<T> is a small adapter that implements Callable<GemmStridedBatchedParams<T>>. Each instance stores a single rocBLAS solution index; its Call method casts alpha and beta to the appropriate host scalar type for half and bfloat16, forwards the strided-batched GEMM parameters to rocblas_gemm_strided_batched_ex with that solution index, and reports OK on success. It returns FAIL either when rocBLAS rejects the call or when TF32 math is requested for fp32 inputs, since rocBLAS has no TF32 support.

Source Code

aten/src/ATen/cuda/tunable/GemmRocblas.h lines 205–241

template <typename T>
class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
  public:
    RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {}

    TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
      auto input_output_type = RocBlasDataTypeFor<T>();
      if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && input_output_type == rocblas_datatype_f32_r)
        return FAIL;  // no support for TF32 in rocBLAS
      auto compute_type = RocBlasComputeTypeFor<T>();
      auto h_a = DoCastForHalfOrBfloat16(params->alpha);
      auto h_b = DoCastForHalfOrBfloat16(params->beta);
      auto status = rocblas_gemm_strided_batched_ex(
          (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
          _rocblasOpFromChar(params->transa),
          _rocblasOpFromChar(params->transb),
          params->m, params->n, params->k,
          &h_a,
          params->a, input_output_type, params->lda, params->stride_a,
          params->b, input_output_type, params->ldb, params->stride_b,
          &h_b,
          params->c, input_output_type, params->ldc, params->stride_c,
          params->c, input_output_type, params->ldc, params->stride_c,
          params->batch,
          compute_type,
          rocblas_gemm_algo_solution_index,
          solution_,
          rocblas_gemm_flags_none);
      if (status != rocblas_status_success) {
        return FAIL;
      }
      return OK;
    }

  private:
    int solution_;
};
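The class is only useful as one candidate among many: a caller constructs it with a rocBLAS solution index and invokes Call with the GEMM parameters. The sketch below, which assumes a ROCm build of PyTorch where GemmRocblas.h and the at::cuda::tunable namespace are available, shows how a caller could probe whether a given solution index handles a particular strided-batched GEMM; the helper name ProbeSolution is invented for illustration and is not part of the PyTorch API.

// Minimal sketch (not part of PyTorch): probing a single rocBLAS solution
// index through the Callable interface shown above. Assumes a ROCm build
// exposing ATen/cuda/tunable/GemmRocblas.h; ProbeSolution is a hypothetical
// helper name used only for this example.
#include <ATen/cuda/tunable/GemmRocblas.h>

namespace tunable = at::cuda::tunable;

template <typename T>
tunable::TuningStatus ProbeSolution(
    int solution_index,
    const tunable::GemmStridedBatchedParams<T>* params) {
  // Each op wraps exactly one solution index; Call() forwards the parameters
  // to rocblas_gemm_strided_batched_ex and reports OK only when rocBLAS
  // accepts that index for these shapes, types, and strides.
  tunable::RocblasGemmStridedBatchedOp<T> op{solution_index};
  return op.Call(params);
}

In the surrounding tunable GEMM framework, many such ops (one per available rocBLAS solution index) are timed against each other and the fastest is selected; that selection logic lives outside this class.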
