compute_triu_tril Function Template — PyTorch Architecture

Architecture documentation for the compute_triu_tril function template in TriangularOps.cpp from the PyTorch codebase.

Source Code

aten/src/ATen/native/TriangularOps.cpp lines 139–172

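// Shared driver for triu and tril. Triangle is a tag type that supplies
// Triangle::op_name and Triangle::upper (both used in the dispatch below).
template <typename Triangle>
void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) {
  // Nothing to compute for an empty tensor.
  if (self.numel() == 0) {
    return;
  }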

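  // True when input and output are the same tensor (in-place call).
  bool inplace_op = self.is_same(result);

  // The helper returns a flag and a tensor. Judging by their use below, the
  // flag says whether self's layout allows updating it directly, and self_c
  // is the tensor the kernel reads from.
  bool inplace_update = false;
  Tensor self_c;
  std::tie(inplace_update, self_c) = checkTrilTriuBatchContiguous(self, inplace_op);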

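  // In-place call that cannot be updated directly: compute into a fresh
  // contiguous buffer and copy it back at the end. Otherwise write to result.
  Tensor result_c;
  if (inplace_op && !inplace_update) {
    result_c = at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  } else {
    result_c = result;
  }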

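  // Dispatch on self's dtype: the standard and complex types plus the four
  // extra types listed first. scalar_t is defined by the macro in the lambda.
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      ScalarType::ComplexHalf,
      ScalarType::BFloat16,
      ScalarType::Half,
      ScalarType::Bool,
      self.scalar_type(),
      Triangle::op_name,
      [&]{
        apply_triu_tril<scalar_t>(result_c, self_c, inplace_op && inplace_update, k, Triangle::upper);
      });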

  // Copy the temporary buffer back into the caller's tensor.
  if (inplace_op && !inplace_update) {
    result.copy_(result_c);
  }
}
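
The listing relies on a tag type to choose between the two operations and on a diagonal offset k. The following is a minimal standalone sketch of that idea on a plain row-major matrix, not the ATen implementation. The names UpperTag, LowerTag, and triangle_sketch are hypothetical. The tag members upper and op_name mirror the names used in the listing, but their real definitions are not shown there and are assumed here. The keep rule (j - i >= k for the upper triangle, j - i <= k for the lower) follows the documented behavior of torch.triu and torch.tril.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical tag types. The members mirror Triangle::upper and
// Triangle::op_name from the listing; their real definitions are assumed.
struct UpperTag {
  static constexpr bool upper = true;
  static constexpr const char* op_name = "triu";
};
struct LowerTag {
  static constexpr bool upper = false;
  static constexpr const char* op_name = "tril";
};

// Returns a copy of a row-major rows x cols matrix in which every element
// outside the selected triangle is zero. k shifts the boundary diagonal:
// k = 0 is the main diagonal, k > 0 lies above it, k < 0 below it.
template <typename Triangle, typename T>
std::vector<T> triangle_sketch(const std::vector<T>& in,
                               int64_t rows, int64_t cols, int64_t k) {
  assert(rows >= 0 && cols >= 0);
  assert(static_cast<int64_t>(in.size()) == rows * cols);
  std::vector<T> out(in.size(), T{});  // value-initialized, i.e. zeros
  for (int64_t i = 0; i < rows; ++i) {
    for (int64_t j = 0; j < cols; ++j) {
      // Upper keeps elements on or above diagonal k; lower keeps on or below.
      const bool keep = Triangle::upper ? (j - i >= k) : (j - i <= k);
      if (keep) {
        out[i * cols + j] = in[i * cols + j];
      }
    }
  }
  return out;
}
```

Calling triangle_sketch<UpperTag>(m, 3, 3, 0) on a 3 x 3 matrix keeps the main diagonal and everything above it, while triangle_sketch<LowerTag>(m, 3, 3, -1) keeps only the elements strictly below the main diagonal. Unlike the listing, the sketch always writes to a new buffer, so it sidesteps the in-place and contiguity handling entirely.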