Home / Class / MetalShaderLibrary Class — pytorch Architecture

MetalShaderLibrary Class — pytorch Architecture

Architecture documentation for the MetalShaderLibrary class in MetalShaderLibrary.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/mps/MetalShaderLibrary.h lines 100–187

// Owns the compilation and caching of a Metal shader library built from an
// MSL source string: compiled MTLLibrary objects, per-function compute
// pipeline states, and wrapped kernel-function objects are all cached here.
// Non-copyable (raw library/pipeline handles and owning caches); intended to
// be used via long-lived (often static) instances.
class MetalShaderLibrary {
 public:
  // Construct from shader source with no template parameters and default
  // compile options.
  MetalShaderLibrary(std::string src)
      : shaderSource(std::move(src)), nparams(0), compile_options(nullptr) {}
  // Construct from shader source that expects `nparams_` substitution
  // parameters (supplied later via the params-taking getters).
  MetalShaderLibrary(std::string src, unsigned nparams_)
      : shaderSource(std::move(src)),
        nparams(nparams_),
        compile_options(nullptr) {}
  // Construct with explicit Metal compile options (not owned-checked here;
  // presumably retained by the caller or ARC — confirm at call sites).
  MetalShaderLibrary(
      std::string src,
      unsigned nparams_,
      MTLCompileOptions* compile_options_)
      : shaderSource(std::move(src)),
        nparams(nparams_),
        compile_options(compile_options_) {}
  // Non-copyable: copying would shallow-copy raw Metal handles and the
  // owning caches below. Delete both copy operations (Rule of Five; the
  // deleted copy ctor alone left copy-assignment implicitly generated).
  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
  MetalShaderLibrary& operator=(const MetalShaderLibrary&) = delete;
  virtual ~MetalShaderLibrary();
  // Names of all functions compiled into the library.
  std::vector<std::string> getFunctionNames();
  // Shared-ownership handle to a kernel function (not cached by this call).
  std::shared_ptr<MetalKernelFunction> getKernelFunction(
      const std::string& name);
  // Returns a raw pointer to the kernel function for use in C APIs; the
  // pointee is owned by kernelCache and lives as long as this library.
  MetalKernelFunction* getCachedKernelFunctionPtr(const std::string& name);
  // Compute pipeline state for `fname` from the default (paramless) library.
  inline MTLComputePipelineState_t getPipelineStateForFunc(
      const std::string& fname) {
    return getLibraryPipelineState(getLibrary(), fname).first;
  }
  // Compute pipeline state for `fname` from a library specialized with
  // `params` (see the nparams constructors).
  MTLComputePipelineState_t getPipelineStateForFunc(
      const std::string& fname,
      const std::initializer_list<std::string>& params) {
    return getLibraryPipelineState(getLibrary(params), fname).first;
  }
  // Raw MTLFunction handle for `fname` from the default library.
  inline MTLFunction_t getMTLFunction(const std::string& fname) {
    return getLibraryPipelineState(getLibrary(), fname).second;
  }
  // Raw MTLFunction handle for `fname` from a params-specialized library.
  MTLFunction_t getMTLFunction(
      const std::string& fname,
      const std::initializer_list<std::string>& params) {
    return getLibraryPipelineState(getLibrary(params), fname).second;
  }
  // Process-wide library of shaders bundled with the binary.
  static MetalShaderLibrary& getBundledLibrary();
  // Execute the named elementwise kernel over `iter`. `alpha` is an optional
  // scalar argument; `scalar_arg_type` overrides its dtype when present.
  void exec_unary_kernel(
      TensorIteratorBase& iter,
      const std::string& name,
      const std::optional<c10::Scalar> alpha = std::nullopt,
      const std::optional<c10::ScalarType> scalar_arg_type = std::nullopt);
  void exec_binary_kernel(
      TensorIteratorBase& iter,
      const std::string& name,
      const std::optional<c10::Scalar> alpha = std::nullopt,
      const std::optional<c10::ScalarType> scalar_arg_type = std::nullopt);
  void exec_ternary_kernel(TensorIteratorBase& iter, const std::string& name);

  // Variants that pass a caller-provided params struct `T` to the kernel;
  // `params_type_name` names the Metal-side type of T.
  template <typename T>
  void exec_unary_kernel_with_params(
      TensorIteratorBase& iter,
      const std::string& name,
      T params,
      const std::string& params_type_name);
  template <typename T>
  void exec_binary_kernel_with_params(
      TensorIteratorBase& iter,
      const std::string& name,
      T params,
      const std::string& params_type_name);

 protected:
  // Virtual so subclasses can customize how the MTLLibrary is obtained
  // (e.g. lazily compiled vs. bundled).
  virtual MTLLibrary_t getLibrary();
  virtual MTLLibrary_t getLibrary(
      const std::initializer_list<std::string>& params);
  // Lazily-compiled default library handle; nullptr until first use.
  MTLLibrary_t library = nullptr;

 private:
  // Looks up (or creates and caches in cplMap) the pipeline state and
  // function for `fname` within `lib`.
  std::pair<MTLComputePipelineState_t, MTLFunction_t> getLibraryPipelineState(
      MTLLibrary_t lib,
      const std::string& fname);
  MTLLibrary_t compileLibrary(const std::string& src);
  std::string shaderSource;    // MSL source the library is compiled from
  unsigned nparams;            // number of expected template parameters
  MTLCompileOptions* compile_options;  // optional compile options (may be null)
  // Cache of compiled libraries keyed by their parameter-specialized source.
  std::unordered_map<std::string, MTLLibrary_t> libMap;
  // Cache of pipeline state + function pairs keyed by function name.
  std::unordered_map<
      std::string,
      std::pair<MTLComputePipelineState_t, MTLFunction_t>>
      cplMap;
  // Cache for kernel functions returned by getCachedKernelFunctionPtr
  std::unordered_map<std::string, std::unique_ptr<MetalKernelFunction>>
      kernelCache;
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free