Home / Class / apply_triu_tril Class — pytorch Architecture

apply_triu_tril Class — pytorch Architecture

Architecture documentation for the apply_triu_tril function template in TriangularOps.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/TriangularOps.cpp lines 87–127

template <typename scalar_t>
void apply_triu_tril(const Tensor& result, const Tensor& self, bool inplace, int64_t k, bool upper) {
  auto n = self.size(-2);
  auto m = self.size(-1);
  auto self_data = self.const_data_ptr<scalar_t>();
  auto self_stride = (self.dim() > 2 && self.stride(-3) > 0) ? self.stride(-3) : 1;
  auto batchsize = batchCountTrilTriu(result);
  auto self_row_stride = self.stride(-2);
  auto self_col_stride = self.stride(-1);

  auto result_data = result.data_ptr<scalar_t>();
  int64_t result_stride = 0, result_row_stride = 0, result_col_stride = 0;
  if (result_data != self_data) {
    result_stride = (result.dim() > 2 && result.stride(-3) > 0) ? result.stride(-3) : 1;
    result_row_stride = result.stride(-2);
    result_col_stride = result.stride(-1);
  } else {
    result_stride = self_stride;
    result_row_stride = self_row_stride;
    result_col_stride = self_col_stride;
  }

  parallel_for(0, batchsize, 0, [&](int64_t start, int64_t end) {
    for (const auto b : c10::irange(start, end)) {
      const scalar_t* self_batch = &self_data[b * self_stride];
      scalar_t* result_batch = &result_data[b * result_stride];
      apply_triu_tril_single<scalar_t>(
          result_batch,
          self_batch,
          inplace,
          k,
          n,
          m,
          result_row_stride,
          result_col_stride,
          self_row_stride,
          self_col_stride,
          upper);
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free