Home / Class / CuFFTParamsLRUCache Class — pytorch Architecture

CuFFTParamsLRUCache Class — pytorch Architecture

Architecture documentation for the CuFFTParamsLRUCache class in CuFFTPlanCache.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/CuFFTPlanCache.h lines 379–481

// Least-recently-used cache that maps CuFFTParams keys to CuFFTConfig values.
//
// Storage layout:
//   * `_usage_list` owns the (params, config) pairs, ordered from most
//     recently used (front) to least recently used (back).
//   * `_cache_map` indexes that list. Each key is a std::reference_wrapper to
//     the CuFFTParams stored inside a list node and each mapped value is the
//     iterator to that node, so the key object is never duplicated.
// std::list keeps references and iterators to surviving nodes valid across
// splice/insert/erase of other nodes, which this layout relies on.
//
// The member functions below do not lock `mutex` themselves.
// NOTE(review): presumably callers are expected to hold `mutex` around
// lookup/resize/clear -- confirm against the call sites.
class CuFFTParamsLRUCache {
public:
  using kv_t = typename std::pair<CuFFTParams, CuFFTConfig>;
  using map_t = typename std::unordered_map<std::reference_wrapper<CuFFTParams>,
                                            typename std::list<kv_t>::iterator,
                                            ParamsHash<CuFFTParams>,
                                            ParamsEqual<CuFFTParams>>;
  using map_kkv_iter_t = typename map_t::iterator;


  // Creates an empty cache whose capacity is CUFFT_DEFAULT_CACHE_SIZE.
  CuFFTParamsLRUCache() : CuFFTParamsLRUCache(CUFFT_DEFAULT_CACHE_SIZE) {}

  // Creates an empty cache holding at most `max_size` entries.
  // Fails the TORCH_CHECKs in _set_max_size() when `max_size` is negative or
  // larger than CUFFT_MAX_PLAN_NUM.
  // Left non-explicit so that any existing implicit conversions keep compiling.
  CuFFTParamsLRUCache(int64_t max_size) {
    _set_max_size(max_size);
  }

  // Takes over the entries and the capacity of `other`. `mutex` is not
  // transferred; the new object gets its own default-constructed mutex.
  CuFFTParamsLRUCache(CuFFTParamsLRUCache&& other) noexcept :
    _usage_list(std::move(other._usage_list)),
    _cache_map(std::move(other._cache_map)),
    _max_size(other._max_size) {}

  // Replaces the entries and the capacity of this cache with those of `other`.
  // `mutex` is left untouched.
  CuFFTParamsLRUCache& operator=(CuFFTParamsLRUCache&& other) noexcept {
    // Guard against self-move: the map keys refer to nodes of the list, so
    // leaving either container in an unspecified moved-from state would risk
    // keys that refer to destroyed nodes.
    if (this != &other) {
      _usage_list = std::move(other._usage_list);
      _cache_map = std::move(other._cache_map);
      _max_size = other._max_size;
    }
    return *this;
  }

  // If key is in this cache, return the cached config. Otherwise, emplace the
  // config in this cache and return it.
  // Return const reference because CuFFTConfig shouldn't be tampered with once
  // created.
  //
  // Asserts that the cache capacity is non-zero. `params` is taken by value
  // because the map key type is std::reference_wrapper<CuFFTParams>, which
  // needs a non-const lvalue to bind to for the find() below.
  const CuFFTConfig &lookup(CuFFTParams params) {
    AT_ASSERT(_max_size > 0);

    map_kkv_iter_t map_it = _cache_map.find(params);
    // Hit, put to list front
    if (map_it != _cache_map.end()) {
      _usage_list.splice(_usage_list.begin(), _usage_list, map_it->second);
      return map_it->second->second;
    }

    // Miss
    // Evict the least recently used entry (list back) if the cache is full.
    // The map entry is erased first because its key refers to the list node.
    if (_usage_list.size() >= _max_size) {
      auto last = _usage_list.end();
      --last;
      _cache_map.erase(last->first);
      _usage_list.pop_back();
    }

    // construct new plan at list front, then insert into _cache_map
    _usage_list.emplace_front(std::piecewise_construct,
                       std::forward_as_tuple(params),
                       std::forward_as_tuple(params));
    auto kv_it = _usage_list.begin();
    try {
      // The key refers to the CuFFTParams stored in the new list node, not to
      // the local `params`, so it stays valid after this function returns.
      _cache_map.emplace(std::piecewise_construct,
                  std::forward_as_tuple(kv_it->first),
                  std::forward_as_tuple(kv_it));
    } catch (...) {
      // Keep list and map in sync: without the rollback the list would hold
      // an entry that the map cannot find.
      _usage_list.pop_front();
      throw;
    }
    return kv_it->second;
  }

  // Drops every cached entry. The capacity is unchanged.
  void clear() {
    // Clear the map first since its keys refer to nodes owned by the list.
    _cache_map.clear();
    _usage_list.clear();
  }

  // Sets the capacity to `new_size` (validated by _set_max_size()) and, when
  // the cache currently holds more entries than that, evicts entries from the
  // least recently used end until it fits.
  void resize(int64_t new_size) {
    _set_max_size(new_size);
    auto cur_size = _usage_list.size();
    if (cur_size > _max_size) {
      // Walk backwards from the list end, removing the map entry of every
      // node that is about to be dropped, then erase the whole tail at once.
      auto delete_it = _usage_list.end();
      for (size_t i = 0; i < cur_size - _max_size; ++i) {
        --delete_it;
        _cache_map.erase(delete_it->first);
      }
      _usage_list.erase(delete_it, _usage_list.end());
    }
  }

  // Number of entries currently cached.
  size_t size() const noexcept { return _cache_map.size(); }

  // Maximum number of entries the cache may hold.
  size_t max_size() const noexcept { return _max_size; }

  // Public mutex; not locked by any member function of this class.
  std::mutex mutex;

private:
  // Only sets size and does value check. Does not resize the data structures.
  void _set_max_size(int64_t new_size) {
    // We check that 0 <= new_size <= CUFFT_MAX_PLAN_NUM here. Since
    // CUFFT_MAX_PLAN_NUM is of type size_t, we need to do non-negativity check
    // first.
    TORCH_CHECK(new_size >= 0,
             "cuFFT plan cache size must be non-negative, but got ", new_size);
    TORCH_CHECK(new_size <= CUFFT_MAX_PLAN_NUM,
             "cuFFT plan cache size can not be larger than ", CUFFT_MAX_PLAN_NUM, ", but got ", new_size);
    _max_size = static_cast<size_t>(new_size);
  }

  // (params, config) pairs, most recently used at the front.
  std::list<kv_t> _usage_list;
  // Index into _usage_list; keys refer to the params stored in the list nodes.
  map_t _cache_map;
  // Capacity; always assigned through _set_max_size() or copied on move.
  size_t _max_size;
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free