MetalShaderLibrary Class — PyTorch Architecture
Architecture documentation for the MetalShaderLibrary class, declared in MetalShaderLibrary.h in the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/mps/MetalShaderLibrary.h lines 100–187
class MetalShaderLibrary {
public:
MetalShaderLibrary(std::string src)
: shaderSource(std::move(src)), nparams(0), compile_options(nullptr) {}
MetalShaderLibrary(std::string src, unsigned nparams_)
: shaderSource(std::move(src)),
nparams(nparams_),
compile_options(nullptr) {}
MetalShaderLibrary(
std::string src,
unsigned nparams_,
MTLCompileOptions* compile_options_)
: shaderSource(std::move(src)),
nparams(nparams_),
compile_options(compile_options_) {}
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
virtual ~MetalShaderLibrary();
std::vector<std::string> getFunctionNames();
std::shared_ptr<MetalKernelFunction> getKernelFunction(
const std::string& name);
// Returns a raw pointer to the kernel function for use in C APIs
MetalKernelFunction* getCachedKernelFunctionPtr(const std::string& name);
inline MTLComputePipelineState_t getPipelineStateForFunc(
const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname).first;
}
MTLComputePipelineState_t getPipelineStateForFunc(
const std::string& fname,
const std::initializer_list<std::string>& params) {
return getLibraryPipelineState(getLibrary(params), fname).first;
}
inline MTLFunction_t getMTLFunction(const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname).second;
}
MTLFunction_t getMTLFunction(
const std::string& fname,
const std::initializer_list<std::string>& params) {
return getLibraryPipelineState(getLibrary(params), fname).second;
}
static MetalShaderLibrary& getBundledLibrary();
void exec_unary_kernel(
TensorIteratorBase& iter,
const std::string& name,
const std::optional<c10::Scalar> alpha = std::nullopt,
const std::optional<c10::ScalarType> scalar_arg_type = std::nullopt);
void exec_binary_kernel(
TensorIteratorBase& iter,
const std::string& name,
const std::optional<c10::Scalar> alpha = std::nullopt,
const std::optional<c10::ScalarType> scalar_arg_type = std::nullopt);
void exec_ternary_kernel(TensorIteratorBase& iter, const std::string& name);
template <typename T>
void exec_unary_kernel_with_params(
TensorIteratorBase& iter,
const std::string& name,
T params,
const std::string& params_type_name);
template <typename T>
void exec_binary_kernel_with_params(
TensorIteratorBase& iter,
const std::string& name,
T params,
const std::string& params_type_name);
protected:
virtual MTLLibrary_t getLibrary();
virtual MTLLibrary_t getLibrary(
const std::initializer_list<std::string>& params);
MTLLibrary_t library = nullptr;
private:
std::pair<MTLComputePipelineState_t, MTLFunction_t> getLibraryPipelineState(
MTLLibrary_t lib,
const std::string& fname);
MTLLibrary_t compileLibrary(const std::string& src);
std::string shaderSource;
unsigned nparams;
MTLCompileOptions* compile_options;
std::unordered_map<std::string, MTLLibrary_t> libMap;
std::unordered_map<
std::string,
std::pair<MTLComputePipelineState_t, MTLFunction_t>>
cplMap;
// Cache for kernel functions returned by getCachedKernelFunctionPtr
std::unordered_map<std::string, std::unique_ptr<MetalKernelFunction>>
kernelCache;
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free