Home / Class / exec_unary_kernel_with_params Class — pytorch Architecture

exec_unary_kernel_with_params Class — pytorch Architecture

Architecture documentation for the exec_unary_kernel_with_params class in OperationUtils.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/mps/OperationUtils.h lines 635–707

// Runs the Metal unary kernel `name` over the single-input/single-output
// iterator `iter`, passing `params` to the shader as an extra constant buffer.
// The concrete pipeline is looked up by a mangled name combining contiguity
// ("dense"/"strided"), the output and input Metal dtype strings, and
// `params_type_name`, so a matching specialization must exist in the library.
template <typename T>
void MetalShaderLibrary::exec_unary_kernel_with_params(TensorIteratorBase& iter,
                                                       const std::string& name,
                                                       T params,
                                                       const std::string& params_type_name) {
  using namespace at::mps;
  // Decompose 64-bit tensor into 32-bit ones
  if (!iter.can_use_32bit_indexing()) {
    for (auto&& sub_iter : iter.with_32bit_indexing()) {
      exec_unary_kernel_with_params(sub_iter, name, params, params_type_name);
    }
    return;
  }

  auto inputTensor = iter.input(0);
  auto outputTensor = iter.output(0);
  // Safe cast: the 32-bit-indexing check above guarantees numel() fits in 32 bits.
  uint32_t length = iter.numel();
  // Nothing to dispatch for empty tensors.
  if (length == 0) {
    return;
  }
  // e.g. "foo_dense_float_half_MyParams" — must match a compiled shader variant.
  auto kernel_name = fmt::format("{}_{}_{}_{}{}",
                                 name,
                                 iter.is_contiguous() ? "dense" : "strided",
                                 scalarToMetalTypeString(outputTensor),
                                 scalarToMetalTypeString(inputTensor),
                                 fmt::format("_{}", params_type_name));
  @autoreleasepool {
    auto cplState = getPipelineStateForFunc(kernel_name);

    MPSStream* mpsStream = getCurrentMPSStream();
    // Encode synchronously on the stream's serial queue so command-buffer
    // mutation is not racing other work submitted to the same stream.
    dispatch_sync(mpsStream->queue(), ^() {
      auto computeEncoder = mpsStream->commandEncoder();

      getMPSProfiler().beginProfileKernel(cplState, name, {inputTensor});

      [computeEncoder setComputePipelineState:cplState];
      // Binds the iterator's tensors; presumably output at buffer index 0 and
      // input at index 1, given the explicit indices used below — TODO confirm
      // against bind_iter_tensors.
      bind_iter_tensors(computeEncoder, iter);
      if (!iter.is_contiguous()) {
        // Strided variant additionally needs shape/stride metadata at
        // buffer indices 2..5.
        mtl_setArgs<2>(computeEncoder,
                       outputTensor.sizes(),
                       inputTensor.strides(),
                       outputTensor.strides(),
                       inputTensor.ndimension());
      }
      // Params buffer goes right after the bound arguments: index 2 for the
      // dense variant, index 6 for strided (after the 4 metadata buffers).
      detail::mtl_setArg(computeEncoder, params, iter.is_contiguous() ? 2 : 6);
      // One thread per output element.
      mtl_dispatch1DJob(computeEncoder, cplState, length);

      getMPSProfiler().endProfileKernel(cplState);
    });
  }
}

// Binary counterpart of exec_unary_kernel_with_params: runs the Metal binary
// kernel `name` over `iter`, forwarding `params` to the shader.
// NOTE(review): the definition continues past the end of this excerpt; only
// the visible prologue is documented here.
template <typename T>
void MetalShaderLibrary::exec_binary_kernel_with_params(TensorIteratorBase& iter,
                                                        const std::string& name,
                                                        T params,
                                                        const std::string& params_type_name) {
  using namespace mps;
  // TODO: Figure a better place to downcast double scalars (probably in tensor iterator itself?)
  // Right now running something like 1.0-torch.rand(5, device='mps') will create iterator with
  // double as common dtype (because Python floating point are always 64-bit values)
  TORCH_CHECK(iter.output().scalar_type() != at::kDouble, "float64 is not supported on MPS");

  // Skip for empty iterators
  if (iter.numel() == 0) {
    return;
  }

  // Decompose 64-bit tensor into 32-bit ones
  if (!iter.can_use_32bit_indexing()) {
    for (auto&& sub_iter : iter.with_32bit_indexing()) {
      exec_binary_kernel_with_params(sub_iter, name, params, params_type_name);
    }

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free