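mtl_setBuffer Function Template (constrained with std::is_same_v) — pytorch Architecture
Architecture documentation for the mtl_setBuffer function template in OperationUtils.h from the pytorch codebase. The template uses std::is_same_v, a standard-library variable template rather than a class, to restrict the encoder type it accepts.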
Entity Profile
Source Code
aten/src/ATen/native/mps/OperationUtils.h lines 448–473
// Binds tensor `t` to argument slot `idx` of a Metal encoder.
//
// encoder_t is restricted (via enable_if) to id<MTLComputeCommandEncoder> or
// id<MTLArgumentEncoder>.
//
// Behavior:
//  - Non-CPU tensor: binds the buffer returned by getMTLBufferStorage(t) at a
//    byte offset of storage_offset() * element_size().
//  - CPU tensor with a compute encoder: the tensor must be 0-dim (TORCH_CHECK).
//    Its single element is copied inline with setBytes. kDouble and
//    kComplexDouble scalars are first narrowed to float / complex<float>,
//    because MPS does not support doubles.
//  - CPU tensor with an argument encoder: fails a TORCH_CHECK.
//
// Params:
//  encoder - the encoder receiving the binding.
//  t       - the tensor to bind.
//  idx     - the argument-table index to bind at.
template <typename encoder_t,
          typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> ||
                                      std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigned idx) {
  if (C10_UNLIKELY(t.device().type() == kCPU)) {
    if constexpr (std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t>) {
      TORCH_CHECK(t.dim() == 0, "Passed CPU tensor to MPS op");
      // Read the element through const_data_ptr(), which accounts for
      // storage_offset(). A 0-dim CPU tensor can be a view into a larger
      // storage, e.g. `cpu_vec[3]`. storage().data() points at the start of
      // the storage and would therefore bind the wrong element.
      const void* src = t.const_data_ptr();
      // MPS does not support doubles, silently downcast CPU scalar to float
      if (C10_UNLIKELY(t.scalar_type() == kDouble)) {
        auto val = static_cast<float>(*reinterpret_cast<const double*>(src));
        [encoder setBytes:&val length:sizeof(val) atIndex:idx];
        return;
      }
      if (C10_UNLIKELY(t.scalar_type() == kComplexDouble)) {
        auto val = static_cast<c10::complex<float>>(*reinterpret_cast<const c10::complex<double>*>(src));
        [encoder setBytes:&val length:sizeof(val) atIndex:idx];
        return;
      }
      [encoder setBytes:src length:t.element_size() atIndex:idx];
    } else {
      // Argument encoders have no inline-bytes path for CPU scalars.
      TORCH_CHECK(false, "Passed CPU tensor to MPS op");
    }
    return;
  }
  // Device tensor: bind its Metal buffer, offset to the tensor's first element.
  [encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx];
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free