Home / Class / is_same_v — pytorch Architecture

is_same_v Class — pytorch Architecture

Architecture documentation for the mtl_setBuffer helper in OperationUtils.h from the pytorch codebase. (Note: std::is_same_v is a C++ standard-library variable template used here only as a SFINAE constraint on the encoder type — the documented entity is the mtl_setBuffer function, not a pytorch-defined class.)

Entity Profile

Source Code

aten/src/ATen/native/mps/OperationUtils.h lines 448–473

template <typename encoder_t,
          typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> ||
                                      std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigned idx) {
  if (C10_UNLIKELY(t.device().type() == kCPU)) {
    // Only a compute command encoder can accept inline bytes; for an argument
    // encoder a CPU-resident tensor is always a caller error.
    if constexpr (!std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t>) {
      TORCH_CHECK(false, "Passed CPU tensor to MPS op");
    } else {
      // A CPU tensor may only be passed inline when it is a 0-dim scalar.
      TORCH_CHECK(t.dim() == 0, "Passed CPU tensor to MPS op");
      switch (t.scalar_type()) {
        case kDouble: {
          // MPS does not support doubles, silently downcast CPU scalar to float
          const auto narrowed = static_cast<float>(*reinterpret_cast<const double*>(t.const_data_ptr()));
          [encoder setBytes:&narrowed length:sizeof(narrowed) atIndex:idx];
          break;
        }
        case kComplexDouble: {
          // Same downcast for the complex case: double components -> float components.
          const auto narrowed =
              static_cast<c10::complex<float>>(*reinterpret_cast<const c10::complex<double>*>(t.const_data_ptr()));
          [encoder setBytes:&narrowed length:sizeof(narrowed) atIndex:idx];
          break;
        }
        default:
          // Every other scalar type is copied through verbatim.
          [encoder setBytes:t.storage().data() length:t.element_size() atIndex:idx];
          break;
      }
    }
    return;
  }
  // MPS-resident tensor: bind its backing MTLBuffer at the tensor's byte offset.
  [encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx];
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free