ApplyLogSumExp Class — pytorch Architecture
Architecture documentation for the ApplyLogSumExp class in epilogue_thread_apply_logsumexp.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_thread_apply_logsumexp.h lines 106–167
/// Thread-level epilogue functor that combines an accumulator fragment with a
/// per-element log-sum-exp (LSE) fragment. For every element it subtracts the
/// LSE value from the accumulator value in `ElementCompute` precision, passes
/// the difference through `detail::ArrayExponential`, and converts the result
/// to `ElementOutput`.
///
/// NOTE(review): `detail::ArrayExponential` is defined outside this view; its
/// name suggests an element-wise exp(), i.e. out = exp(AB - lse) -- confirm
/// against its definition.
class ApplyLogSumExp {
 public:
  //
  // Element types (forwarded from the enclosing template's parameters)
  //
  using ElementOutput = ElementOutput_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using ElementLSE = ElementLSE_;

  /// Number of elements handled per call to operator().
  static constexpr int kElementsPerAccess = ElementsPerAccess;
  /// Element count of the output fragment; same as kElementsPerAccess.
  static constexpr int kCount = kElementsPerAccess;
  /// Scaling mode advertised by this functor.
  static constexpr ScaleType::Kind kScale =
      cutlass::epilogue::thread::ScaleType::NoBetaScaling;

  //
  // Fragment types
  //
  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentLSE = Array<ElementLSE, kElementsPerAccess>;
  /// Alias consumed by epilogue_smem_accumulator.h.
  using FragmentScaleBias = FragmentLSE;

  //
  // Methods
  //

  /// Constructs the functor; there is no state to initialise.
  CUTLASS_HOST_DEVICE
  ApplyLogSumExp() {}

  /// Always reports that a source operand is needed.
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return true;
  }

  /// No-op. Functionally required for serial reduction in the epilogue.
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {}

  /// Applies the LSE correction to one accumulator fragment.
  ///
  /// @param AB            accumulator fragment
  /// @param scale_unused  ignored; present only to match the expected
  ///                      (accumulator, scale, bias) call shape
  /// @param bias          fragment carrying the LSE values (bias used as LSE)
  /// @return the fragment produced by `detail::ArrayExponential` applied to
  ///         (AB - bias), converted to `ElementOutput`
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
      FragmentAccumulator const& AB,
      FragmentLSE const& scale_unused,
      FragmentLSE const& bias) const {
    using AccumulatorToCompute = NumericArrayConverter<
        ElementCompute,
        ElementAccumulator,
        kElementsPerAccess>;
    using LseToCompute =
        NumericArrayConverter<ElementCompute, ElementLSE, kElementsPerAccess>;
    using ComputeToOutput = NumericArrayConverter<
        ElementOutput,
        ElementCompute,
        kElementsPerAccess>;

    // Promote both operands to the compute type before doing any arithmetic.
    AccumulatorToCompute widen_accumulator;
    LseToCompute widen_lse;
    FragmentCompute accum_values = widen_accumulator(AB);
    FragmentCompute lse_values = widen_lse(bias);

    // Subtract the LSE from the accumulators, then run the exponential functor.
    minus<FragmentCompute> subtract;
    detail::ArrayExponential<ElementCompute, kElementsPerAccess> exponentiate;
    FragmentCompute result = exponentiate(subtract(accum_values, lse_values));

    // Narrow back to the output element type.
    ComputeToOutput narrow_to_output;
    return narrow_to_output(result);
  }
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free