Home / Class / stage_pack_weights Class — pytorch Architecture

stage_pack_weights Class — pytorch Architecture

Architecture documentation for stage_pack_weights, a function template defined in Mm.h (aten/src/ATen/native/vulkan/ops) in the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/vulkan/ops/Mm.h lines 16–54

/**
 * Stages a batch of 2D weight matrices into a mapped staging buffer using a
 * four-plane layout, then packs that buffer into `v_weight`.
 *
 * Every source element at (batch b, row h, column w) is copied to
 *   b * (4 * dst_kh_sz * dst_kw_sz)            -- start of the batch
 *   + (2 * (h % 2) + (w % 2)) * plane size     -- one of four planes, selected
 *                                                 by the parity of h and w
 *   + (h / 2) * dst_kw_sz + (w / 2)            -- position inside the plane
 * Destination elements that receive no source element stay zero because the
 * mapped region is zero-filled first.
 *
 * @param context    Context used to create the staging buffer.
 * @param v_weight   Destination tensor; supplies the staging size
 *                   (gpu_numel) and the number of bytes cleared (nbytes).
 * @param weight     Source tensor, read through const_data_ptr<T>().
 * @param src_kb_sz  Number of source matrices (batch extent).
 * @param src_kh_sz  Source matrix height.
 * @param src_kw_sz  Source matrix width.
 * @param dst_kh_sz  Height of one destination plane.
 * @param dst_kw_sz  Width of one destination plane.
 *
 * NOTE(review): the staging buffer is always created with api::kFloat,
 * independent of T -- confirm this is intended for non-float T.
 * NOTE(review): no bounds are checked here; presumably callers guarantee the
 * destination plane extents cover ceil(src / 2) in each dimension -- verify.
 */
template <typename T>
void stage_pack_weights(
    api::Context* const context,
    vTensor& v_weight,
    const Tensor& weight,
    const int64_t src_kb_sz,
    const int64_t src_kh_sz,
    const int64_t src_kw_sz,
    const int64_t dst_kh_sz,
    const int64_t dst_kw_sz) {
  const T* const src_base = weight.const_data_ptr<T>();

  // Element strides: one source matrix, one destination plane, and one
  // destination batch entry (four planes).
  const int64_t src_batch_stride = src_kw_sz * src_kh_sz;
  const int64_t plane_stride = dst_kw_sz * dst_kh_sz;
  const int64_t dst_batch_stride = plane_stride * 4;

  api::StorageBuffer staging(context, api::kFloat, v_weight.gpu_numel());
  {
    // The mapping lives only inside this scope, so it is destroyed before the
    // staging buffer is handed to pack_staging_to_vtensor below.
    api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE);

    T* const dst_base = mapping.template data<T>();

    // Clear first so padding positions read as zero.
    memset(dst_base, 0, v_weight.nbytes());

    for (int64_t b = 0; b < src_kb_sz; ++b) {
      const int64_t src_batch_off = b * src_batch_stride;
      const int64_t dst_batch_off = b * dst_batch_stride;

      for (int64_t h = 0; h < src_kh_sz; ++h) {
        const int64_t src_row_off = src_batch_off + h * src_kw_sz;
        // Row parity contributes 0 or 2 to the plane number; h / 2 picks the
        // row inside that plane.
        const int64_t dst_row_off = dst_batch_off +
            (2 * (h % 2)) * plane_stride + (h / 2) * dst_kw_sz;

        for (int64_t w = 0; w < src_kw_sz; ++w) {
          // Column parity contributes 0 or 1 to the plane number; w / 2 picks
          // the column inside that plane.
          const int64_t dst_off =
              dst_row_off + (w % 2) * plane_stride + (w / 2);
          memcpy(dst_base + dst_off, src_base + (src_row_off + w), sizeof(T));
        }
      }
    }
  }
  utils::pack_staging_to_vtensor(staging.buffer(), v_weight);
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free