
ScaledGemmParams Class — pytorch Architecture

Architecture documentation for the ScaledGemmParams class in GemmCommon.h from the pytorch codebase. ScaledGemmParams packages the operands, leading dimensions, scale pointers, and dtypes of a scaled GEMM so that the TunableOp framework can build tuning signatures, size and duplicate device buffers, and numerically validate candidate kernels.

Entity Profile

Source Code

aten/src/ATen/cuda/tunable/GemmCommon.h lines 588–716

template <typename T>
struct ScaledGemmParams : OpParams {
  ScaledGemmParams() = default;
  ScaledGemmParams(const ScaledGemmParams&) = default;
  ScaledGemmParams(ScaledGemmParams&&) noexcept = default;
  ScaledGemmParams& operator=(const ScaledGemmParams&) = default;
  ScaledGemmParams& operator=(ScaledGemmParams&&) noexcept = default;
  ~ScaledGemmParams() override = default;

  std::string BLASSignature() const override {
    // Excluding use_fast_accum and use_rowwise booleans for now
    if (bias_ptr == nullptr) {
      return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, "
        "transA: %c, transB: %c, batch_count: 1, scaleA: f32_r, scaleB: f32_r, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, compute_type: %s }",
        m, n, k, lda, ldb, ldc, ldc, transa, transb,
        ScalarTypeToBLASType(a_dtype), ScalarTypeToBLASType(b_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(c_dtype),
        ComputeTypeFor<T>(), ComputeTypeFor<T>());
    }
    else {
      return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, "
        "transA: %c, transB: %c, batch_count: 1, scaleA: f32_r, scaleB: f32_r, a_type: %s, b_type: %s, c_type: %s, d_type: %s, bias_type: %s, scale_type: %s, compute_type: %s }",
        m, n, k, lda, ldb, ldc, ldc, transa, transb,
        ScalarTypeToBLASType(a_dtype), ScalarTypeToBLASType(b_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(bias_dtype),
        ComputeTypeFor<T>(), ComputeTypeFor<T>());
    }
  }
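  // Editorial note: ldc is passed twice, so the ldd field always mirrors ldc,
  // and the stride and batch_count fields are hard-coded to 0 and 1, i.e. the
  // YAML-style record always describes a single, non-batched GEMM.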

  std::string Signature() const override {
    // In Blas.cpp, code defaults to a bias_dtype of Half even when there is no bias vector.
    // Search for this line:
    // params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
    //
    // In TunableOp, we must distinguish in param signature these two cases: with and without a bias vector.
    return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld_rw_%d_bias_%s",
      transa, transb, m, n, k, lda, ldb, ldc,
      a_scaling_type == ScalingType::RowWise && b_scaling_type == ScalingType::RowWise,
      bias_ptr == nullptr ? "None" : at::toString(bias_dtype));
  }
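  // Illustrative example (not part of the source): with transa='n',
  // transb='t', m=1024, n=2048, k=512, lda=1024, ldb=2048, ldc=1024,
  // row-wise scaling on both operands, and bias_ptr == nullptr, this
  // returns "nt_1024_2048_512_ld_1024_2048_1024_rw_1_bias_None".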

  size_t GetSizeA() const {
    size_t size_stride = lda * ((transa == 'n' || transa == 'N') ? k : m);
    size_t size_dense = m * k;
    return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
  }

  size_t GetSizeB() const {
    size_t size_stride = ldb * ((transb == 'n' || transb == 'N') ? n : k);
    size_t size_dense = k * n;
    return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
  }

  size_t GetSizeC() const {
    size_t size_stride = ldc * n;
    size_t size_dense = m * n;
    return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
  }
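  // Worked example (illustrative): for a one-byte element type such as
  // Float8_e4m3fn, with transa == 'n', m = 1024, k = 512, lda = 2048:
  //   size_stride = lda * k = 2048 * 512 = 1048576 elements
  //   size_dense  = m * k   = 1024 * 512 =  524288 elements
  // so GetSizeA() returns 1048576 bytes, covering the padded leading
  // dimension rather than just the dense extent. Note that all three
  // helpers use sizeof(T), even though a_dtype, b_dtype, and c_dtype
  // may differ.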

  size_t GetSize(bool duplicate_inputs) const {
    size_t size = GetSizeC();
    if (duplicate_inputs) {
      size += GetSizeA();
      size += GetSizeB();
    }
    return size;
  }

  ScaledGemmParams* DeepCopy(bool duplicate_inputs) const {
    ScaledGemmParams* copy = new ScaledGemmParams(*this);
    c10::DeviceIndex device = 0;
    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
    size_t c_size = GetSizeC();
    copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size);
    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
    if (duplicate_inputs) {
      size_t a_size = GetSizeA();
      size_t b_size = GetSizeB();
      copy->a = c10::cuda::CUDACachingAllocator::raw_alloc(a_size);
      copy->b = c10::cuda::CUDACachingAllocator::raw_alloc(b_size);
      copy->duplicate_inputs_ = true;
    }
    return copy;
  }
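
  // Editorial note: with duplicate_inputs set, fresh a/b buffers are
  // allocated above but their contents are not copied; only c is copied
  // via memcpyAsync.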

  // only call on object returned by DeepCopy
  void Delete() {
    c10::cuda::CUDACachingAllocator::raw_delete(c);
    if (duplicate_inputs_) {
      // NOLINTNEXTLINE(*const-cast*)
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(a));
      // NOLINTNEXTLINE(*const-cast*)
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(b));
    }
  }

  TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
  }
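  // Editorial note: compares this->c (reference output) against other->c
  // (candidate output), elementwise over GetSizeC()/sizeof(T) values, using
  // the tuning context's numerical-check configuration.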

  char transa{};
  char transb{};
  int64_t m{};
  int64_t n{};
  int64_t k{};
  const void* a{};
  const void* a_scale_ptr{};
  int64_t lda{};
  ScalarType a_dtype{};
  ScalarType a_scale_dtype{};
  ScalingType a_scaling_type{};
  const void* b{};
  const void* b_scale_ptr{};
  int64_t ldb{};
  ScalarType b_dtype{};
  ScalarType b_scale_dtype{};
  ScalingType b_scaling_type{};
  const void* bias_ptr{};
  ScalarType bias_dtype{};
  void* c{};
  const void* c_scale_ptr{};
  int64_t ldc{};
  ScalarType c_dtype{};
  void* amax_ptr{};
  bool use_fast_accum{};
private:
  bool duplicate_inputs_{false};
};
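
Usage Sketch

A minimal, hypothetical driver is sketched below to make the lifecycle of these methods concrete. The function name BenchmarkCandidate, the float instantiation, and the assumption that the caller has already populated the params (as Blas.cpp does in the source tree) are illustrative, not taken from the pytorch codebase.

#include <string>

#include <ATen/cuda/tunable/GemmCommon.h>

using at::cuda::tunable::ScaledGemmParams;
using at::cuda::tunable::TuningStatus;

// Hypothetical tuning-loop body: exercises Signature, GetSize, DeepCopy,
// NumericalCheck, and Delete in the order a tuning framework would.
void BenchmarkCandidate(ScaledGemmParams<float>& reference) {
  // Stable cache key for this shape / layout / scaling / bias configuration.
  std::string key = reference.Signature();

  // Workspace estimate: output buffer plus optionally duplicated inputs.
  size_t bytes = reference.GetSize(/*duplicate_inputs=*/true);

  // Run the candidate against a deep copy so the reference output held in
  // reference.c is preserved for comparison.
  ScaledGemmParams<float>* candidate =
      reference.DeepCopy(/*duplicate_inputs=*/true);

  // ... launch the candidate scaled-GEMM kernel on candidate->a/b/c ...

  // Elementwise comparison of reference.c against candidate->c.
  TuningStatus status = reference.NumericalCheck(candidate);

  // Buffers handed out by DeepCopy must be released through Delete()
  // before the object itself is freed.
  candidate->Delete();
  delete candidate;

  (void)key; (void)bytes; (void)status;
}

In the actual source tree this orchestration is performed by the TunableOp framework itself rather than by user code.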
