algorithm Class — pytorch Architecture
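Architecture documentation for the histogramdd_cpu_contiguous function template in HistogramKernel.cpp from the pytorch codebase. Note that `algorithm` is not a class: it is the function's non-type template parameter (of type BIN_SELECTION_ALGORITHM) that selects how elements are mapped to bins.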
Entity Profile
Source Code
aten/src/ATen/native/cpu/HistogramKernel.cpp lines 79–205
/* Accumulates a D-dimensional histogram of the rows of `input` into `hist`.
 *
 * Template parameters:
 *   input_t   - element type used to read `input`, `weight`, `bin_edges` and
 *               to write the per-thread histogram buffers.
 *               NOTE(review): the buffers are allocated with hist.dtype() but
 *               accessed as input_t, so hist is presumably expected to have
 *               the matching dtype - confirm against callers.
 *   algorithm - compile-time selection of how an element is mapped to a bin:
 *               BINARY_SEARCH, LINEAR_INTERPOLATION, or
 *               LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH.
 *
 * Arguments:
 *   hist      - output tensor; for every dim, hist.size(dim) + 1 must equal
 *               bin_edges[dim].numel().
 *   bin_edges - D contiguous tensors holding the bin edges of each dimension.
 *               The searches below use std::upper_bound, so the edges are
 *               presumably sorted in increasing order - confirm.
 *   input     - 2-D tensor of shape (N, D); each row is one sample.
 *   weight    - optional 1-D tensor with N elements; when absent every sample
 *               contributes a weight of 1.
 *
 * Samples with any coordinate outside [leftmost edge, rightmost edge], or with
 * a NaN coordinate, are ignored. The rightmost bin is closed on both sides;
 * all other bins are half-open.
 *
 * All preconditions are checked with TORCH_INTERNAL_ASSERT.
 */
template<typename input_t, BIN_SELECTION_ALGORITHM algorithm>
void histogramdd_cpu_contiguous(Tensor& hist, const TensorList& bin_edges,
        const Tensor& input, const std::optional<Tensor>& weight) {
  TORCH_INTERNAL_ASSERT(input.dim() == 2);

  const int64_t N = input.size(0);
  if (weight.has_value()) {
    TORCH_INTERNAL_ASSERT(weight.value().dim() == 1 && weight.value().numel() == N);
  }

  const int64_t D = input.size(1);
  TORCH_INTERNAL_ASSERT(int64_t(bin_edges.size()) == D);
  for (const auto dim : c10::irange(D)) {
    TORCH_INTERNAL_ASSERT(bin_edges[dim].is_contiguous());
    TORCH_INTERNAL_ASSERT(hist.size(dim) + 1 == bin_edges[dim].numel());
  }

  if (D == 0) {
    // hist is an empty tensor in this case; nothing to do here
    return;
  }

  TensorAccessor<const input_t, 2> accessor_in = input.accessor<const input_t, 2>();

  /* Constructs a std::optional<TensorAccessor> containing an accessor if
   * the optional weight tensor has a value.
   */
  const auto accessor_wt = weight.has_value()
      ? std::optional<TensorAccessor<const input_t, 1>>(weight.value().accessor<const input_t, 1>())
      : std::optional<TensorAccessor<const input_t, 1>>();

  // Per-dimension raw views of the bin edges plus their cached extremes, so
  // the hot loop below never has to go back through the Tensor API.
  std::vector<input_t*> bin_seq(D);
  std::vector<int64_t> num_bin_edges(D);
  std::vector<input_t> leftmost_edge(D), rightmost_edge(D);

  for (const auto dim : c10::irange(D)) {
    bin_seq[dim] = bin_edges[dim].data_ptr<input_t>();
    num_bin_edges[dim] = bin_edges[dim].numel();
    leftmost_edge[dim] = bin_seq[dim][0];
    rightmost_edge[dim] = bin_seq[dim][num_bin_edges[dim] - 1];
  }

  // Each sample costs work proportional to D, so the grain shrinks with D
  // (but never below one sample per task).
  const int64_t GRAIN_SIZE = std::max(int64_t(1), HISTOGRAM_GRAIN_SIZE / D);

  /* Parallelizes processing of input using at::parallel_for.
   * Each thread accumulates a local result into their own slice of
   * thread_histograms which get summed together at the end.
   */
  const auto num_threads = at::get_num_threads();
  const auto hist_sizes = hist.sizes();
  DimVector thread_hist_sizes(hist_sizes.size() + 1);
  thread_hist_sizes[0] = num_threads;
  std::copy(hist_sizes.begin(), hist_sizes.end(),
            thread_hist_sizes.begin() + 1);
  Tensor thread_histograms = at::zeros(thread_hist_sizes, hist.dtype());
  TORCH_INTERNAL_ASSERT(thread_histograms.is_contiguous());

  at::parallel_for(0, N, GRAIN_SIZE, [&](int64_t start, int64_t end) {
    const auto tid = at::get_thread_num();
    auto hist_strides = thread_histograms.strides();
    input_t *hist_local_data = thread_histograms.data_ptr<input_t>();

    // View only this thread's local results
    hist_local_data += hist_strides[0] * tid;
    hist_strides = hist_strides.slice(1);

    for (const auto i : c10::irange(start, end)) {
      bool skip_elt = false;
      int64_t hist_index = 0;

      for (const auto dim : c10::irange(D)) {
        const input_t elt = accessor_in[i][dim];

        // Skips elements which fall outside the specified bins and NaN elements.
        // The negated conjunction is deliberate: every comparison against NaN
        // is false, so a NaN element also takes this branch.
        if (!(elt >= leftmost_edge[dim] && elt <= rightmost_edge[dim])) {
          skip_elt = true;
          break;
        }

        int64_t pos = -1;

        // `algorithm` is a template parameter, so the strategy is selected at
        // compile time and the unused branches are discarded per instantiation.
        if constexpr (algorithm == BINARY_SEARCH) {
          // Handles the general case via binary search on the bin edges.
          pos = std::upper_bound(bin_seq[dim], bin_seq[dim] + num_bin_edges[dim], elt)
              - bin_seq[dim] - 1;
        } else if constexpr (algorithm == LINEAR_INTERPOLATION
            || algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
          /* When bin_edges is known to be a linear progression, maps elt to
           * the appropriate bin via simple division.
           */
          pos = static_cast<int64_t>((elt - leftmost_edge[dim])
              * (num_bin_edges[dim] - 1)
              / (rightmost_edge[dim] - leftmost_edge[dim]));

          /* Ensures consistency with bin_edges by checking the bins to the left and right
           * of the selected position. Necessary for cases in which an element very close
           * to a bin edge may be misclassified by simple division.
           */
          if constexpr (algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
            int64_t pos_min = std::max(static_cast<int64_t>(0), pos - 1);
            int64_t pos_max = std::min(pos + 2, num_bin_edges[dim]);
            pos = std::upper_bound(bin_seq[dim] + pos_min, bin_seq[dim] + pos_max, elt)
                - bin_seq[dim] - 1;
          }
        } else {
          // Unknown algorithm value: kept as a runtime internal assert so that
          // behavior for such an instantiation is unchanged.
          TORCH_INTERNAL_ASSERT(false);
        }

        // Unlike other bins, the rightmost bin includes its right boundary
        if (pos == (num_bin_edges[dim] - 1)) {
          pos -= 1;
        }

        hist_index += hist_strides[dim] * pos;
      }

      if (!skip_elt) {
        // In the unweighted case, the default weight is 1
        input_t wt = accessor_wt.has_value() ? accessor_wt.value()[i] : static_cast<input_t>(1);

        hist_local_data[hist_index] += wt;
      }
    }
  });

  // Reduce the per-thread partial histograms over the leading (thread) dimension.
  at::sum_out(hist, thread_histograms, /*dim=*/{0});
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free