Home / Class/ ALayout Class — pytorch Architecture

ALayout Class — pytorch Architecture

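Architecture documentation for the HipblasltGemmOp class template in GemmHipblaslt.h from the pytorch codebase. (This page is indexed under "ALayout", which is one of the template's non-type parameters rather than the name of the class.)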

Entity Profile

Source Code

aten/src/ATen/cuda/tunable/GemmHipblaslt.h lines 454–614

// Callable that runs a single GEMM through hipBLASLt with one fixed algorithm.
//
// Template parameters:
//   AT, BT, CT        element types of A, B and C; mapped to hip data types via HipDataTypeFor.
//   ALayout, BLayout  transposition this instantiation was built for; Call() checks that the
//                     runtime params->transa / params->transb agree with them.
//   ParamsT           parameter bundle consumed through the Get*FromParams<CT> helpers.
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
class HipblasltGemmOp : public Callable<ParamsT> {
  public:
    // Stores the hipBLASLt algorithm that every Call() will try to run.
    HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {}

    // Builds the hipBLASLt layouts and matmul descriptor described by `params` and launches
    // hipblasLtMatmul with the stored algorithm on the current CUDA stream.
    //
    // Returns FAIL (without launching anything) when matmulIsAlgoSupported does not report
    // success, or when the workspace it asks for is >= the available workspace size.
    // Returns OK after hipblasLtMatmul has been issued successfully.
    // Any other hipBLASLt error is raised through TORCH_HIPBLASLT_CHECK / TORCH_CHECK.
    TuningStatus Call(const ParamsT* params) override {
      // Owns one hipblasLtMatrixLayout_t and destroys it on every exit path (FAIL returns and
      // exceptions included). The destroy status is ignored because a destructor must not throw.
      struct MatrixLayoutGuard {
        hipblasLtMatrixLayout_t layout = nullptr;
        MatrixLayoutGuard() = default;
        MatrixLayoutGuard(const MatrixLayoutGuard&) = delete;
        MatrixLayoutGuard& operator=(const MatrixLayoutGuard&) = delete;
        ~MatrixLayoutGuard() {
          if (layout != nullptr) {
            static_cast<void>(hipblasLtMatrixLayoutDestroy(layout));
          }
        }
      };

      hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
      hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
      auto a_datatype = HipDataTypeFor<AT>();
      auto b_datatype = HipDataTypeFor<BT>();
      auto in_out_datatype = HipDataTypeFor<CT>();
      auto opa = _hipblasOpFromChar(params->transa);
      auto opb = _hipblasOpFromChar(params->transb);

      // The compile-time layout and the runtime transpose flags must describe the same problem.
      TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen");

      float alpha = GetAlphaFromParams<CT>(params);
      float beta = GetBetaFromParams<CT>(params);

      // Layout extents describe the matrices as stored, so rows/cols swap for transposed inputs.
      MatrixLayoutGuard mat_a, mat_b, mat_c;
      if (opa == HIPBLAS_OP_N) {
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a.layout, a_datatype, params->m, params->k, params->lda));
      }
      else {
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a.layout, a_datatype, params->k, params->m, params->lda));
      }
      if (opb == HIPBLAS_OP_N) {
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b.layout, b_datatype, params->k, params->n, params->ldb));
      }
      else {
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b.layout, b_datatype, params->n, params->k, params->ldb));
      }
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c.layout, in_out_datatype, params->m, params->n, params->ldc));

      // specific to batched gemm
      int batch = GetBatchFromParams<CT>(params);
      if (batch > 1) {
        int64_t stride_a = GetStrideAFromParams<CT>(params);
        int64_t stride_b = GetStrideBFromParams<CT>(params);
        int64_t stride_c = GetStrideCFromParams<CT>(params);
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
            mat_a.layout, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
            mat_a.layout, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
            mat_b.layout, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
            mat_b.layout, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
            mat_c.layout, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
            mat_c.layout, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
      }

      // Accumulate in fp32; switch to the fast TF32 compute type only when the global context
      // requests TF32 precision for CUDA matmuls.
      hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
      if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
        computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
      }
      HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F);
      matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa);
      matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb);

      // specific to scaled gemm
      const void* mat1_scale_ptr = GetAScalePointerFromParams<CT>(params);
      const void* mat2_scale_ptr = GetBScalePointerFromParams<CT>(params);
      const void* result_scale_ptr = GetDScalePointerFromParams<CT>(params);
      // Input scales are applied only when both A and B scale pointers are present.
      if (mat1_scale_ptr && mat2_scale_ptr) {
        hipblasLtMatmulDescAttributes_t a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER;
        hipblasLtMatmulDescAttributes_t b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER;
        // Row-wise scaling is expressed either as a scale mode or as a different pointer
        // attribute, depending on which macro the build defines. With neither macro defined the
        // default scale pointer attribute is used unchanged.
        if (GetAScalingTypeFromParams<CT>(params) == ScalingType::RowWise) {
#if defined(HIPBLASLT_OUTER_VEC)
          matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
#elif defined(HIPBLASLT_VEC_EXT)
          a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
#endif
        }
        if (GetBScalingTypeFromParams<CT>(params) == ScalingType::RowWise) {
#if defined(HIPBLASLT_OUTER_VEC)
          matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
#elif defined(HIPBLASLT_VEC_EXT)
          b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
#endif
        }
        matmul.setAttribute(a_scale_ptr_desc, mat1_scale_ptr);
        matmul.setAttribute(b_scale_ptr_desc, mat2_scale_ptr);
      }
      if (result_scale_ptr) {
        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
      }

      // Optional bias; the epilogue fuses the bias add with RELU or GELU when requested.
      const void* bias_ptr = GetBiasPointerFromParams<CT>(params);
      auto bias_datatype = GetBiasTypeFromParams<CT>(params);
      if (bias_ptr) {
        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype);
        auto activation = GetActivationFromParams<CT>(params);
        if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) {
          matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS);
        }
        else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) {
          matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS);
        }
        else {
          matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS);
        }
      }

      size_t workspace_size = at::cuda::getCUDABlasLtWorkspaceSize();

      auto op_handle = at::cuda::getCurrentCUDABlasLtHandle();

      // Ask hipBLASLt whether the stored algorithm can handle this problem and how much
      // workspace it needs. mat_c is passed for both the C and D layouts.
      size_t ret_workspace_size = 0;
      auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle,
          matmul.descriptor(),
          &alpha,
          mat_a.layout,
          mat_b.layout,
          &beta,
          mat_c.layout,
          mat_c.layout,
          algo_,
          ret_workspace_size);

      // Reject unsupported algorithms and ones whose workspace need is not strictly below the
      // available size. The layout guards release the layouts on this early return.
      if (status != HIPBLAS_STATUS_SUCCESS || ret_workspace_size >= workspace_size) {
        return FAIL;
      }

      void* workspace_buffer = at::cuda::getCUDABlasLtWorkspace();

      // params->c is passed as both the C input and the D output, so the result is written
      // in place.
      TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle,
            matmul.descriptor(),
            &alpha,
            params->a,
            mat_a.layout,
            params->b,
            mat_b.layout,
            &beta,
            params->c,
            mat_c.layout,
            params->c,
            mat_c.layout,
            &algo_,
            workspace_buffer,
            workspace_size,
            at::cuda::getCurrentCUDAStream()));

      // mat_a / mat_b / mat_c are destroyed by their guards when they go out of scope.
      return OK;
    }

  private:
    // Algorithm handed to matmulIsAlgoSupported and hipblasLtMatmul on every Call().
    hipblasLtMatmulAlgo_t algo_;
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free