
_EmbeddingBagKernelCacheImpl Class — pytorch Architecture

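Architecture documentation for the _EmbeddingBagKernelCacheImpl class in EmbeddingBag.h from the PyTorch codebase. The class is a variadic template that privately inherits from a list of storage mixins. Each mixin holds one generated kernel (callback) plus the block size it was generated for, and getCallback returns that cached kernel only when the requested block size matches.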


Source Code

aten/src/ATen/native/EmbeddingBag.h lines 88–110

template<typename... StorageMixins>
struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {

    _EmbeddingBagKernelCacheImpl() = default;
    // use each of the mixins to store corresponding kernel and block size
    explicit _EmbeddingBagKernelCacheImpl(std::optional<int64_t> maybe_block_size)
      : StorageMixins(maybe_block_size)...
    {}

    // this method is thread safe (call sites may call from different threads)
    template<bool has_weight, typename TIndex, typename TData>
    typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
    getCallback(int64_t block_size) const {
        // if the cache doesn't store the kernel for the incoming block size
        // (so it is different from the one stored in corresponding mixin)
        // regenerate the kernel (not writing it into the cache so we avoid locks)
        if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
            return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
        }
        // else retrieve the cached kernel from the corresponding mixin
        return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
    }
};
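The lookup logic above can be exercised outside PyTorch with a small standalone sketch. Everything in it is hypothetical: `CallbackAndBlockSize`, `KernelCache`, the `generate_calls` counter and the summing "kernel" are stand-ins invented for illustration, and the mixin is templated only on the data type rather than on `has_weight`, `TIndex` and `TData`. It shows two points from the excerpt. First, the constructor forwards one optional block size to every mixin through a pack expansion. Second, `getCallback` only reads the cache and regenerates a kernel on a block-size mismatch instead of storing it.

```cpp
#include <atomic>
#include <cstdint>
#include <functional>
#include <optional>

// Hypothetical storage mixin, loosely modeled on the _CallbackAndBlockSize
// mixin referenced above but templated only on the data type. It caches one
// kernel together with the block size that kernel was generated for.
template <typename TData>
struct CallbackAndBlockSize {
    // Hypothetical kernel signature: reduce one block of values to a sum.
    using TCallback = std::function<TData(const TData*)>;

    // Counts kernel generations so cache hits and misses can be observed.
    static inline std::atomic<int> generate_calls{0};

    std::int64_t blockSize = -1;  // -1 means "nothing cached"
    TCallback callback = nullptr;

    // Hypothetical generator: builds a kernel specialized for block_size.
    static TCallback generateCallback(std::int64_t block_size) {
        ++generate_calls;
        return [block_size](const TData* data) {
            TData acc{};
            for (std::int64_t i = 0; i < block_size; ++i) {
                acc += data[i];
            }
            return acc;
        };
    }

    CallbackAndBlockSize() = default;

    // Pre-generate the kernel when a block size is known up front.
    explicit CallbackAndBlockSize(std::optional<std::int64_t> maybe_block_size)
        : blockSize(maybe_block_size.value_or(-1)),
          callback(maybe_block_size.has_value()
                       ? generateCallback(*maybe_block_size)
                       : TCallback{}) {}
};

// Hypothetical cache with the same shape as _EmbeddingBagKernelCacheImpl:
// one privately inherited mixin per kernel variant it can serve.
template <typename... StorageMixins>
struct KernelCache : private StorageMixins... {
    KernelCache() = default;

    // Forward the same optional block size to every mixin (pack expansion).
    explicit KernelCache(std::optional<std::int64_t> maybe_block_size)
        : StorageMixins(maybe_block_size)... {}

    // Read-only lookup: a hit returns the stored kernel; a miss builds a
    // fresh kernel and returns it without writing it back, so concurrent
    // callers never modify shared state.
    template <typename TData>
    typename CallbackAndBlockSize<TData>::TCallback
    getCallback(std::int64_t block_size) const {
        using Mixin = CallbackAndBlockSize<TData>;
        if (block_size != Mixin::blockSize) {
            return Mixin::generateCallback(block_size);
        }
        return Mixin::callback;
    }
};

// One cache serving both float and double kernels.
using Cache = KernelCache<CallbackAndBlockSize<float>, CallbackAndBlockSize<double>>;
```

For example, a `Cache` built with block size 4 returns its stored `float` kernel for `getCallback<float>(4)`, while every `getCallback<float>(2)` call generates a new kernel. That is the trade-off the source comment describes: because a miss never writes to the mixins, concurrent callers need no lock, at the cost of repeated kernel generation whenever the requested block size differs from the cached one.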
