Home / Class / algorithm Class — PyTorch Architecture

algorithm Class — PyTorch Architecture

Architecture documentation for the `algorithm` entity — a `BIN_SELECTION_ALGORITHM` non-type template parameter of `histogramdd_cpu_contiguous` — in HistogramKernel.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/HistogramKernel.cpp lines 79–205

/* Bins the rows of a contiguous 2-d input into a D-dimensional histogram.
 *
 * hist:      pre-allocated output tensor with one dimension per input column;
 *            hist.size(dim) bins in dimension dim.
 * bin_edges: D contiguous 1-d tensors; bin_edges[dim] holds the
 *            hist.size(dim) + 1 edges delimiting the bins of dimension dim.
 *            Assumed sorted ascending (required by std::upper_bound below) —
 *            callers are expected to guarantee this.
 * input:     (N, D) tensor: N elements, each with D coordinates.
 * weight:    optional 1-d tensor of N per-element weights; when absent each
 *            element contributes 1.
 *
 * The bin-selection strategy is fixed at compile time by the `algorithm`
 * template parameter, so the dispatch below uses `if constexpr` and only the
 * selected strategy is instantiated.
 */
template<typename input_t, BIN_SELECTION_ALGORITHM algorithm>
void histogramdd_cpu_contiguous(Tensor& hist, const TensorList& bin_edges,
        const Tensor& input, const std::optional<Tensor>& weight) {
    TORCH_INTERNAL_ASSERT(input.dim() == 2);

    const int64_t N = input.size(0);
    if (weight.has_value()) {
        TORCH_INTERNAL_ASSERT(weight.value().dim() == 1 && weight.value().numel() == N);
    }

    const int64_t D = input.size(1);
    TORCH_INTERNAL_ASSERT(int64_t(bin_edges.size()) == D);
    for (const auto dim : c10::irange(D)) {
        TORCH_INTERNAL_ASSERT(bin_edges[dim].is_contiguous());
        // Each dimension has one more edge than it has bins.
        TORCH_INTERNAL_ASSERT(hist.size(dim) + 1 == bin_edges[dim].numel());
    }

    if (D == 0) {
        // hist is an empty tensor in this case; nothing to do here
        return;
    }

    TensorAccessor<const input_t, 2> accessor_in = input.accessor<const input_t, 2>();

    /* Constructs a std::optional<TensorAccessor> containing an accessor if
     * the optional weight tensor has a value.
     */
    const auto accessor_wt = weight.has_value()
            ? std::optional<TensorAccessor<const input_t, 1>>(weight.value().accessor<const input_t, 1>())
            : std::optional<TensorAccessor<const input_t, 1>>();

    // Per-dimension raw edge pointers and cached extremal edges, hoisted out
    // of the hot loop so the element loop touches no Tensor machinery.
    std::vector<input_t*> bin_seq(D);
    std::vector<int64_t> num_bin_edges(D);
    std::vector<input_t> leftmost_edge(D), rightmost_edge(D);

    for (const auto dim : c10::irange(D)) {
        bin_seq[dim] = bin_edges[dim].data_ptr<input_t>();
        num_bin_edges[dim] = bin_edges[dim].numel();
        leftmost_edge[dim] = bin_seq[dim][0];
        rightmost_edge[dim] = bin_seq[dim][num_bin_edges[dim] - 1];
    }

    // Each element costs O(D) work, so scale the grain size down with D
    // (never below 1) to keep per-task work roughly constant.
    const int64_t GRAIN_SIZE = std::max(int64_t(1), HISTOGRAM_GRAIN_SIZE / D);

    /* Parallelizes processing of input using at::parallel_for.
     * Each thread accumulates a local result into their own slice of
     * thread_histograms which get summed together at the end.
     */
    const auto num_threads = at::get_num_threads();
    const auto hist_sizes = hist.sizes();
    DimVector thread_hist_sizes(hist_sizes.size() + 1);
    thread_hist_sizes[0] = num_threads;
    std::copy(hist_sizes.begin(), hist_sizes.end(),
              thread_hist_sizes.begin() + 1);
    Tensor thread_histograms = at::zeros(thread_hist_sizes, hist.dtype());
    TORCH_INTERNAL_ASSERT(thread_histograms.is_contiguous());

    at::parallel_for(0, N, GRAIN_SIZE, [&](int64_t start, int64_t end) {
        const auto tid = at::get_thread_num();
        auto hist_strides = thread_histograms.strides();
        input_t *hist_local_data = thread_histograms.data_ptr<input_t>();

        // View only this thread's local results
        hist_local_data += hist_strides[0] * tid;
        hist_strides = hist_strides.slice(1);

        for (const auto i : c10::irange(start, end)) {
            bool skip_elt = false;
            int64_t hist_index = 0;

            for (const auto dim : c10::irange(D)) {
                const input_t elt = accessor_in[i][dim];

                /* Skips elements which fall outside the specified bins and NaN
                 * elements: the negated form is deliberate, since any
                 * comparison against NaN is false and thus triggers the skip.
                 */
                if (!(elt >= leftmost_edge[dim] && elt <= rightmost_edge[dim])) {
                    skip_elt = true;
                    break;
                }

                int64_t pos = -1;

                if constexpr (algorithm == BINARY_SEARCH) {
                    // Handles the general case via binary search on the bin edges.
                    pos = std::upper_bound(bin_seq[dim], bin_seq[dim] + num_bin_edges[dim], elt)
                            - bin_seq[dim] - 1;
                } else if constexpr (algorithm == LINEAR_INTERPOLATION
                        || algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
                    /* When bin_edges is known to be a linear progression, maps elt to
                     * the appropriate bin via simple division.
                     */
                    pos = static_cast<int64_t>((elt - leftmost_edge[dim])
                            * (num_bin_edges[dim] - 1)
                            / (rightmost_edge[dim] - leftmost_edge[dim]));

                    /* Ensures consistency with bin_edges by checking the bins to the left and right
                     * of the selected position. Necessary for cases in which an element very close
                     * to a bin edge may be misclassified by simple division.
                     */
                    if constexpr (algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
                        int64_t pos_min = std::max(static_cast<int64_t>(0), pos - 1);
                        int64_t pos_max = std::min(pos + 2, num_bin_edges[dim]);
                        pos = std::upper_bound(bin_seq[dim] + pos_min, bin_seq[dim] + pos_max, elt)
                                - bin_seq[dim] - 1;
                    }
                } else {
                    // Unreachable for the instantiated algorithms above.
                    TORCH_INTERNAL_ASSERT(false);
                }

                // Unlike other bins, the rightmost bin includes its right boundary
                if (pos == (num_bin_edges[dim] - 1)) {
                    pos -= 1;
                }

                hist_index += hist_strides[dim] * pos;
            }

            if (!skip_elt) {
                // In the unweighted case, the default weight is 1
                const input_t wt = accessor_wt.has_value() ? accessor_wt.value()[i] : static_cast<input_t>(1);

                hist_local_data[hist_index] += wt;
            }
        }
    });

    // Reduce the per-thread partial histograms into the caller's output.
    at::sum_out(hist, thread_histograms, /*dim=*/{0});
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free