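Count Template Parameter — PyTorch Architecture
Architecture documentation for the `Count` template parameter (elements processed per operation) of the `MemoryEfficientAttentionNormalize` class template in epilogue_rescale_output.h from the PyTorch codebase.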
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h lines 62–190
/// Epilogue output operator for memory-efficient attention.
///
/// For each output row it combines the accumulator fragment with (optionally)
/// a previously stored source fragment:
///   - isFirst == true : D = alpha * accumulator
///   - isFirst == false: D = alpha * accumulator + (alpha * m_prime[row]) * source
/// where alpha = 1 / s_prime[row] when isLast is true (with a zero row sum
/// replaced by 1), and alpha = 1 otherwise.
template <
    typename ElementOutput_, ///< Data type used to store tensors
    typename ElementSource_, ///< Data type for source (usually matches
                             ///< `ElementOutput`)
    int Count, ///< Number of elements computed per operation.
               ///< Usually it is 128/sizeof_bits<ElementOutput_>,
               ///< but we use 64 or 32 sometimes when there are not enough
               ///< data to store
    typename ElementAccumulator_, ///< Accumulator data type
    typename ElementCompute_, ///< Data type used to compute linear combination
    bool isFirst, ///< If true, no source fragment is read
                  ///< (see is_source_needed()).
    bool isLast, ///< If true, the result is divided by the row sum s_prime[row].
    typename FragmentAlphaBeta_, ///< Per-row container indexed by `row`
                                 ///< (holds s_prime / m_prime values).
    FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class MemoryEfficientAttentionNormalize {
 public:
  using ElementOutput = ElementOutput_;
  using ElementSource = ElementSource_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  static int const kCount = Count;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentSource = Array<ElementSource, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using ComputeFragment = Array<ElementCompute, kCount>;
  using FragmentAlphaBeta = FragmentAlphaBeta_;

  static FloatRoundStyle const kRound = Round;

 private:
  //
  // Data members
  //

  // Held by reference: the referenced fragments must outlive this object.
  FragmentAlphaBeta const& s_prime_;
  FragmentAlphaBeta const& m_prime_;

  /// Returns the scale applied to the accumulator (and, via beta, to the
  /// source) for `row`.
  ///
  /// When isLast is true this is 1 / s_prime_[row]. A row sum of 0 (which the
  /// original authors note happens for fully masked-out rows) is replaced by 1
  /// so the division cannot produce Inf/NaN. When isLast is false it is 1.
  CUTLASS_HOST_DEVICE
  ElementCompute compute_alpha(int row) const {
    ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row];
    return isLast ? (1 / denom) : 1;
  }

 public:
  /// Constructs the function object, possibly loading from pointers in host
  /// memory
  CUTLASS_HOST_DEVICE
  MemoryEfficientAttentionNormalize(
      FragmentAlphaBeta const& s_prime,
      FragmentAlphaBeta const& m_prime)
      : s_prime_(s_prime), m_prime_(m_prime) {}

  /// Returns true if source is needed (every pass except the first).
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return !isFirst;
  }

  /// Functionally required for serial reduction in the epilogue; this
  /// operator does not use the partition information.
  CUTLASS_HOST_DEVICE
  void set_k_partition(int /*k_partition*/, int /*k_partition_count*/) {}

  /// Computes linear scaling: D = alpha * accumulator + beta * source,
  /// with beta = alpha * m_prime[row]. Only valid when isFirst is false.
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
      int row,
      FragmentAccumulator const& accumulator,
      FragmentSource const& source) const {
    assert(!isFirst);

    // Convert inputs to the internal compute numeric type.
    NumericArrayConverter<ElementCompute, ElementSource, kCount, Round>
        source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;
    // Convert the result to the destination numeric type.
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
        destination_converter;

    ComputeFragment converted_source = source_converter(source);
    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    multiplies<ComputeFragment> mul_source;
    multiply_add<ComputeFragment> mul_add_accumulator;

    ElementCompute alpha = compute_alpha(row);
    ElementCompute beta = alpha * m_prime_[row];

    ComputeFragment intermediate;
    intermediate = mul_source(beta, converted_source); // X = beta * C
    intermediate = mul_add_accumulator(
        alpha, converted_accumulator, intermediate); // D = alpha * Accum + X

    return destination_converter(intermediate);
  }

  /// Computes linear scaling: D = alpha * accumulator.
  /// Only valid when isFirst is true (no source fragment is read).
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(int row, FragmentAccumulator const& accumulator)
      const {
    assert(isFirst);

    // Convert the accumulator to the internal compute numeric type.
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;
    // Convert the result to the destination numeric type.
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
        destination_converter;

    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    multiplies<ComputeFragment> mul_accumulator;

    ElementCompute alpha = compute_alpha(row);

    ComputeFragment intermediate;
    intermediate =
        mul_accumulator(alpha, converted_accumulator); // D = alpha * Accum

    return destination_converter(intermediate);
  }
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free