_rowwise_prune_helper Function — pytorch Architecture
Architecture documentation for the _rowwise_prune_helper function template in RowwisePrune.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/RowwisePrune.cpp lines 20–58
// Prunes rows of a 2-D `weights` tensor according to a boolean row `mask`.
//
// For every index i where mask[i] is true, row i of `weights` is packed into
// the next free row of the output tensor, and compressed_indices_mapping[i]
// records that output row index. Rows with mask[i] == false get a mapping
// entry of -1 and are dropped from the output.
//
// Args:
//   weights: 2-D tensor with mask.numel() rows.
//            NOTE(review): assumed contiguous — the memcpy below treats it as
//            a flat row-major buffer; confirm callers guarantee this.
//   mask: bool tensor with one entry per row of `weights` (made contiguous
//         here before its data pointer is read).
//   compressed_indices_dtype: dtype of the index-mapping tensor; must match
//         the template parameter `input_t` used to write into it.
//
// Returns:
//   std::tuple of (pruned_2d_tensor, compressed_indices_mapping) as
//   described above.
template <typename input_t>
std::tuple<Tensor, Tensor> _rowwise_prune_helper(
    const Tensor& weights, const Tensor& mask,
    ScalarType compressed_indices_dtype) {
  // Use int64_t throughout: numel()/size() return int64_t, and 32-bit `int`
  // counters (as well as the `i * num_cols` offset arithmetic they induce)
  // can overflow for very large tensors.
  int64_t num_non_masked_rows = 0;
  auto mask_contig = mask.contiguous();
  auto mask_data = mask_contig.data_ptr<bool>();
  const int64_t num_mask_entries = mask.numel();
  for (const auto i : c10::irange(num_mask_entries)) {
    if (mask_data[i]) {
      num_non_masked_rows++;
    }
  }
  const int64_t num_cols = weights.size(1);
  auto pruned_2d_tensor = at::empty({num_non_masked_rows, num_cols},
                                    weights.options());
  auto compressed_indices_mapping = at::empty({num_mask_entries},
                                              compressed_indices_dtype);
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half,
                             at::ScalarType::BFloat16,
                             weights.scalar_type(),
                             "rowwise_prune_helper", [&]() {
    auto* pruned_2d_tensor_data = pruned_2d_tensor.data_ptr<scalar_t>();
    auto compressed_indices_mapping_data =
        compressed_indices_mapping.data_ptr<input_t>();
    auto weights_data = weights.data_ptr<scalar_t>();
    // Index of the next free row in the packed output.
    int64_t last_row_kept = 0;
    for (const auto i : c10::irange(num_mask_entries)) {
      if (mask_data[i]) {
        // Copy one full kept row into the packed output tensor.
        memcpy(pruned_2d_tensor_data + last_row_kept * num_cols,
               weights_data + i * num_cols,
               num_cols * sizeof(scalar_t));
        compressed_indices_mapping_data[i] =
            static_cast<input_t>(last_row_kept);
        last_row_kept++;
      } else {
        // Row was pruned away; -1 marks "no corresponding output row".
        compressed_indices_mapping_data[i] = -1;
      }
    }
  });
  return std::tuple<Tensor, Tensor>(pruned_2d_tensor,
                                    compressed_indices_mapping);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free