Home / Class / CachingHostAllocatorImpl Class — pytorch Architecture

CachingHostAllocatorImpl Class — pytorch Architecture

Architecture documentation for the CachingHostAllocatorImpl class in CachingHostAllocator.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/core/CachingHostAllocator.h lines 264–960

// Generic caching allocator for pinned (page-locked) host memory.
//
// Template parameters:
//   S - stream type (e.g. a CUDA stream wrapper)
//   E - event type used to track pending GPU work on freed blocks
//   B - block type; defaults to HostBlock<S>
//
// Allocations are rounded up to a power of two and cached in per-size-bucket
// free lists. Freed blocks that were used on streams are not reusable until
// the recorded events complete; those events are drained either inline on the
// allocation path or by an optional background thread. Private pools exist to
// support stream capture (CUDA graphs), where no events can be recorded.
template <
    typename S,
    typename E,
    typename B = HostBlock<S>>
struct CachingHostAllocatorImpl {

  using BlockPool = HostBlockPool<S, E, B>;
  using PrivatePool = HostPrivatePool<S, E, B>;

  virtual ~CachingHostAllocatorImpl() {
    // Signal the background event-processing thread (if this instance
    // launched one) to exit, and wait for it to drain.
    if (active_) {
      active_ = false;
      getBackgroundThreadPool()->waitWorkComplete();
    }
  }

  // Allocate `size` bytes of pinned host memory.
  // Returns {data pointer, opaque block context} — the context must be
  // passed back to free()/record_event(). Returns {nullptr, nullptr} for
  // size == 0.
  virtual std::pair<void*, void*> allocate(size_t size) {
    if (size == 0) {
      return {nullptr, nullptr};
    }

    auto&& [mempool_id, pool] = get_allocation_pool_for_current_stream();

    // If we are using background threads, we can process events in the
    // background.
    // In the case of a non-default pool, we never use the background
    // thread. Since allocations happen only during stream capture
    // time and there are no events to process anyway, speeding up
    // event processing with a helper thread is not helpful.
    if (!(pinned_use_background_threads() &&
          mempool_id.first == 0 && mempool_id.second == 0)) {
      process_events(pool);
    }

    // Round up the allocation to the nearest power of two to improve reuse.
    // These power of two sizes are also used to index into the free list.
    size_t roundSize = c10::llvm::PowerOf2Ceil(size);

    // First, try to allocate from the free list of the chosen pool
    auto* block = get_free_block(roundSize, pool);
    if (block) {
      block->was_allocated_during_stream_capture_ = current_stream_is_capturing_fast_path();
      return {block->ptr_, reinterpret_cast<void*>(block)};
    }

    // Check in the recently freed blocks with pending events to see if we
    // can reuse them. Call get_free_block again after processing events
    if (pinned_use_background_threads()) {
      // Launch the background thread and process events in a loop.
      // The function-local static ensures the thread is started at most once
      // (for the first instance to reach this point).
      static bool background_thread_flag [[maybe_unused]] = [this] {
        active_ = true;
        getBackgroundThreadPool()->run([&]() {
          while (active_) {
            // Background thread conservatively processes default pool
            // events only. Private pools never use a background
            // thread because cuda stream capture does not benefit
            // from asynchronous event processing.
            process_events(default_pool_);
            std::this_thread::sleep_for(std::chrono::microseconds(100));
          }
        });
        return true;
      }();
    }

    // Slow path: if we can't allocate from the cached free list, we need
    // to create a new block.
    void* ptr = nullptr;
    allocate_host_memory(roundSize, &ptr);

    // Then, create a new block.
    block = new B(roundSize, ptr);
    block->allocated_ = true;
    block->owning_pool_ = mempool_id;
    block->was_allocated_during_stream_capture_ = current_stream_is_capturing_fast_path();
    add_allocated_block(block, pool);
    return {block->ptr_, reinterpret_cast<void*>(block)};
  }

  // Free the block identified by `ctx` (the context returned by allocate()).
  // If the block was used on streams, events are recorded and the block is
  // queued until those events complete; otherwise it goes straight back to
  // the free list.
  virtual void free(void* ctx) {
    if (!ctx) {
      return;
    }

    // Note: we can assume that free is correctly paired with alloc, and thus we
    // do not need to look up the ctx in blocks_.
    auto* block = reinterpret_cast<B*>(ctx);

    std::optional<std::vector<E>> events;
    ska::flat_hash_set<S> streams;
    bool allocated_during_capture = false;
    {
      std::lock_guard<std::mutex> g(block->mutex_);
      block->allocated_ = false;
      allocated_during_capture = block->was_allocated_during_stream_capture_;
      if (block->streams_.empty()) {
        TORCH_INTERNAL_ASSERT(block->event_count_ == 0);
      } else {
        events = std::vector<E>();
        events->reserve(block->streams_.size());
        block->event_count_ += block->streams_.size();
        // Move out streams to avoid holding the mutex during event recording
        streams = std::move(block->streams_);
        block->streams_.clear();
      }
    }

    // cudaEventRecord() does not work for indicating when a block can
    // be reused while stream capture is happening, since no actual
    // GPU work happens during stream capture.
    if (!allocated_during_capture) {
      // Event recording must be done outside the mutex to avoid potential
      // deadlocks (e.g., when Python GIL is involved)
      for (auto stream : streams) {
        record_stream(events, stream);
      }
    }

    if (!events.has_value()) {
      // Block was never used on any stream: immediately reusable.
      auto& pool = pool_from_block(block);
      auto index = size_index(block->size_);
      std::lock_guard<std::mutex> g(pool.free_list_[index].mutex_);
      pool.free_list_[index].list_.push_back(block);
    } else if (allocated_during_capture) {
      // pass: No events are ever recorded during stream capture.

      // If the block was ever used, block->event_count_ will be above
      // 0 and thus can never be recycled by
      // process_events_for_specific_size. Thus, this block will never
      // be returned again. "Leaking" memory like this is intentional
      // to avoid subtle cuda graph problems described here:
      // https://github.com/pytorch/pytorch/pull/161583#issuecomment-3229885771

      // Otherwise, if the block was never used, block->event_count_
      // will be 0 and thus process_events_for_specific_size can
      // return this block.
    } else {
      // Queue the events recorded on the streams that used this block;
      // the block becomes reusable once all of them complete.
      auto& pool = pool_from_block(block);
      std::lock_guard<std::mutex> g(pool.events_mutex_);
      for (auto&& event : *events) {
        pool.events_.emplace_front(std::move(event), block);
      }
    }
  }

  // Mark that the allocation backing `ptr`/`ctx` is used on stream `s`, so
  // free() will wait for an event on that stream before recycling the block.
  // Returns false if no matching block is found.
  virtual bool record_event(void* ptr, void* ctx, c10::Stream s) {
    B *block = nullptr;
    if (block_exists(ctx)) {
      block = reinterpret_cast<B*>(ctx);
    } else {
      // `ctx` may not be one of our blocks (e.g. foreign allocation);
      // fall back to a pointer lookup.
      block = get_block_from_ptr(ptr);
    }
    if (block == nullptr) {
      return false;
    }
    S stream = S(s);
    std::lock_guard<std::mutex> gb(block->mutex_);
    TORCH_INTERNAL_ASSERT(block->allocated_);
    block->streams_.insert(stream);
    return true;
  }

  // Release every block currently sitting in the pool's free lists back to
  // the system, updating slow-path stats accordingly.
  void free_from_pool(BlockPool &pool) {
    for (size_t i = 0; i < pool.free_list_.size(); ++i) {
      // Acquire both mutexes deadlock-free, then adopt them into guards.
      std::lock(pool.free_list_[i].mutex_, pool.blocks_mutex_);
      std::lock_guard<std::mutex> gf(pool.free_list_[i].mutex_, std::adopt_lock);
      std::lock_guard<std::mutex> gb(pool.blocks_mutex_, std::adopt_lock);

      std::vector<B*> blocks_to_remove(pool.free_list_[i].list_.begin(),
                                       pool.free_list_[i].list_.end());
      pool.free_list_[i].list_.clear();

      for (auto* block : blocks_to_remove) {
        pool.blocks_.erase(block);
        pool.ptr_to_block_.erase(block->ptr_);
        auto index = size_index(block->size_);
        free_block(block);
        stats_.allocations.decrease(1);
        stats_.allocated_bytes.decrease(block->size_);
        stats_.allocation_bucket_stats[index].decrease(1);
        stats_.allocated_bytes_bucket_stats[index].decrease(block->size_);
        delete block;
      }
    }
  }

  // Drain pending events and release all cached blocks, in the default pool
  // and in any releasable private pools (erasing private pools that become
  // empty).
  // TODO: Make this take a pool id like in CUDACachingAllocator
  virtual void empty_cache() {
    process_events(default_pool_);
    free_from_pool(default_pool_);

    {
      std::unique_lock<std::shared_mutex> lg(instance_mutex_);
      for (auto it = graph_pools_freeable_.begin(); it != graph_pools_freeable_.end();) {
        process_events(it->second->blocks);
        free_from_pool(it->second->blocks);
        if (it->second->blocks.blocks_.empty()) {
          auto erase_count = graph_pools_.erase(it->first);
          TORCH_INTERNAL_ASSERT(erase_count == 1);
          it = graph_pools_freeable_.erase(it);
        } else {
          ++it;
        }
      }
    }
  }

  // Map a (power-of-two-rounded) size to its free-list bucket index.
  inline size_t size_index(size_t size) {
    return c10::llvm::Log2_64_Ceil(size);
  }

  // Whether event processing should be offloaded to a background thread
  // (driven by allocator configuration).
  virtual bool pinned_use_background_threads() {
    return c10::CachingAllocator::AcceleratorAllocatorConfig::
        pinned_use_background_threads();
  }

  virtual void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for copy_data");
  }

  // Snapshot allocator statistics for the default pool. Best-effort: bucket
  // peaks are summed, which over-estimates the true aggregate peak.
  HostStats getStats() {
    HostStats stats;

    // To keep getStats lightweight we do *not* flush any available blocks
    // into the free_list. This may skew the stats a bit.

    auto add_bucket_stats = [](Stat& accumulator, const Stat& other) {
      accumulator.allocated += other.allocated;
      accumulator.current += other.current;
      accumulator.freed += other.freed;
      // Since peaks are measured per bucket independently, we add them up
      // to estimate the total peak. This is not strictly correct, but it is
      // the best approximation we can get after the fact.
      accumulator.peak += other.peak;
    };

    // Accurate reading of memory stats requires concurrently holding both the
    // free list mutexes and the blocks mutex. Previously, this was only done in
    // empty_cache function.
    for (size_t i = 0; i < default_pool_.free_list_.size(); ++i) {
      std::lock(
          default_pool_.free_list_[i].mutex_, default_pool_.blocks_mutex_);
      std::lock_guard<std::mutex> gf(
          default_pool_.free_list_[i].mutex_, std::adopt_lock);
      std::lock_guard<std::mutex> gb(
          default_pool_.blocks_mutex_, std::adopt_lock);

      // We collect the slow-path stats only once, since they are not collected
      // per bucket (we pick index 0 arbitrarily). These are also all the host
      // allocations, not taking into account caching and free lists.
      if (i == 0) {
        stats.allocations = stats_.allocations;
        stats.allocated_bytes = stats_.allocated_bytes;
        stats.num_host_alloc = stats.allocations.allocated;
        stats.num_host_free = stats.allocations.freed;
      }

      // Bucket stats need to be merged with the slow-path stats. We do this in
      // a best effort manner, since we can't really replay the cached events per bucket.
      add_bucket_stats(stats.active_requests, stats_.active_bucket_stats[i]);
      add_bucket_stats(stats.active_bytes, stats_.active_bytes_bucket_stats[i]);
      stats.bucket_allocation[i] = stats_.allocation_bucket_stats[i].allocated;
    }

    // Get the timing stats
    {
      std::lock_guard<std::mutex> g(stats_.timing_mutex_);

      stats.host_alloc_time = stats_.host_alloc_time;
      stats.host_free_time = stats_.host_free_time;
    }

    return stats;
  }

  // Reset the accumulated (lifetime) counters for all stats; current/peak
  // values are untouched.
  void resetAccumulatedStats() {
    // Resetting accumulated memory stats requires concurrently holding both the
    // free list mutexes and the blocks mutex. Previously, this was only done in
    // empty_cache function.
    for (size_t i = 0; i < default_pool_.free_list_.size(); ++i) {
      std::lock(
          default_pool_.free_list_[i].mutex_, default_pool_.blocks_mutex_);
      std::lock_guard<std::mutex> gf(
          default_pool_.free_list_[i].mutex_, std::adopt_lock);
      std::lock_guard<std::mutex> gb(
          default_pool_.blocks_mutex_, std::adopt_lock);

      if (i == 0) {
        stats_.allocations.reset_accumulated();
        stats_.allocated_bytes.reset_accumulated();
      }
      stats_.active_bucket_stats[i].reset_accumulated();
      stats_.active_bytes_bucket_stats[i].reset_accumulated();
      stats_.allocation_bucket_stats[i].reset_accumulated();
      stats_.allocated_bytes_bucket_stats[i].reset_accumulated();
    }

    // Also reset timing stats
    {
      std::lock_guard<std::mutex> g(stats_.timing_mutex_);
      stats_.host_alloc_time.reset_accumulated();
      stats_.host_free_time.reset_accumulated();
    }
  }

  // Reset the recorded peak values for all stats.
  void resetPeakStats() {
    // Resetting peak memory stats requires concurrently holding both the
    // free list mutexes and the blocks mutex. Previously, this was only done in
    // empty_cache function.
    for (size_t i = 0; i < default_pool_.free_list_.size(); ++i) {
      std::lock(
          default_pool_.free_list_[i].mutex_, default_pool_.blocks_mutex_);
      std::lock_guard<std::mutex> gf(
          default_pool_.free_list_[i].mutex_, std::adopt_lock);
      std::lock_guard<std::mutex> gb(
          default_pool_.blocks_mutex_, std::adopt_lock);

      if (i == 0) {
        stats_.allocations.reset_peak();
        stats_.allocated_bytes.reset_peak();
      }
      stats_.active_bucket_stats[i].reset_peak();
      stats_.active_bytes_bucket_stats[i].reset_peak();
      stats_.allocation_bucket_stats[i].reset_peak();
      stats_.allocated_bytes_bucket_stats[i].reset_peak();
    }

    // Also reset timing stats
    {
      std::lock_guard<std::mutex> g(stats_.timing_mutex_);
      stats_.host_alloc_time.reset_peak();
      stats_.host_free_time.reset_peak();
    }
  }

 private:
  // Register a freshly allocated block with its pool and update slow-path
  // and per-bucket stats.
  virtual void add_allocated_block(B* block, BlockPool& pool) {
    std::lock_guard<std::mutex> g(pool.blocks_mutex_);
    pool.blocks_.insert(block);
    stats_.allocations.increase(1);
    stats_.allocated_bytes.increase(block->size_);
    pool.ptr_to_block_.insert({block->ptr_, block});

    // Unfortunately, we have to, on the slow path, quickly
    // lock the bucket to record the allocation. This should
    // be a rare event once the cache is warmed up.
    auto size = block->size_;
    auto index = size_index(size);
    {
      std::lock_guard<std::mutex> g(pool.free_list_[index].mutex_);
      stats_.allocation_bucket_stats[index].increase(1);
      stats_.allocated_bytes_bucket_stats[index].increase(size);
      stats_.active_bucket_stats[index].increase(1);
      stats_.active_bytes_bucket_stats[index].increase(size);
    }
  }

  // Pop a block of exactly `size` (already rounded) from the pool's free
  // list, or return nullptr if the bucket is empty.
  virtual B* get_free_block(size_t size, BlockPool& pool) {
    auto index = size_index(size);
    std::lock_guard<std::mutex> g(pool.free_list_[index].mutex_);
    if (!pool.free_list_[index].list_.empty()) {
      B* block = pool.free_list_[index].list_.back();
      pool.free_list_[index].list_.pop_back();
      block->allocated_ = true;
      stats_.active_bucket_stats[index].increase(1);
      stats_.active_bytes_bucket_stats[index].increase(size);
      return block;
    }
    return nullptr;
  }

  virtual void process_events(BlockPool& pool) {
    // process all events in the given pool until the last unready event.
    process_events_for_specific_size(-1, pool);
  }

  // If size is -1, process all events from backwards until the last unready
  // event. Otherwise, process events for a specific size and on first ready block
  // is found, add it to the free list and return.
  virtual void process_events_for_specific_size(int64_t size, BlockPool& pool) {
    size_t event_count = 0;
    size_t max_events = 0;
    {
      std::lock_guard<std::mutex> g(pool.events_mutex_);
      max_events = pool.events_.size();
    }

    while (true) {
      // Avoid calling cudaEventDestroy while holding a mutex, so move
      // intermediate events out of the lock into this object.
      // process the last event
      std::optional<std::pair<E, B*>> processed;
      {
        std::lock_guard<std::mutex> g(pool.events_mutex_);
        if (!pool.events_.empty()) {
          processed = std::move(pool.events_.back());
          pool.events_.pop_back();
        }
      }

      if (!processed) {
        return;
      }

      if (size != -1) {
        // Bound the scan so we never revisit events we already requeued.
        if (event_count++ > max_events) {
          {
            std::lock_guard<std::mutex> g(pool.events_mutex_);
            pool.events_.push_front(std::move(*processed));
          }
          return;
        }
        if (size != (int64_t)processed->second->size_) {
          // if we are processing a specific size, and the size of the block
          // doesn't match, we can't use it.
          {
            std::lock_guard<std::mutex> g(pool.events_mutex_);
            pool.events_.push_front(std::move(*processed));
          }
          continue;
        }
      }

      // otherwise, query the event
      {
        // now, see if we can handle this element
        auto& event = processed->first;
        if (!query_event(event)) {
          // push the event onto the back if it's not ready.
          {
            std::lock_guard<std::mutex> g(pool.events_mutex_);
            if (size == -1) {
              pool.events_.push_back(std::move(*processed));
              return;
            } else {
              pool.events_.push_front(std::move(*processed));
              continue;
            }
          }
        }
      }

      // Process the events.
      TORCH_INTERNAL_ASSERT(processed);
      auto* block = processed->second;
      bool available = false;
      {
        std::lock_guard<std::mutex> g(block->mutex_);
        TORCH_INTERNAL_ASSERT(!block->allocated_);
        block->event_count_--;
        if (block->event_count_ == 0) {
          available = true;
        }
      }

      if (available) {
        auto index = size_index(block->size_);
        std::lock_guard<std::mutex> g(pool.free_list_[index].mutex_);
        pool.free_list_[index].list_.push_back(block);
        stats_.active_bucket_stats[index].decrease(1);
        // BUGFIX: use the block's actual size, not the `size` parameter.
        // When called via process_events(), `size` is -1 and would be
        // converted to a huge size_t, corrupting the active-bytes stat.
        // This also mirrors the increase() done in get_free_block().
        stats_.active_bytes_bucket_stats[index].decrease(block->size_);
        if (size != -1) {
          return;
        }
      }
    }
  }

  // Lazily-constructed single-thread pool shared by all instantiations;
  // intentionally leaked so it outlives static destruction order issues.
  TaskThreadPool* getBackgroundThreadPool() {
    static TaskThreadPool* pool = new TaskThreadPool(1);
    return pool;
  }

 public:
  // Begin routing allocations whose stream matches `filter` into the private
  // pool `pool_id` (creating it or bumping its refcount).
  void begin_allocate_to_pool(
      c10::MempoolId_t pool_id,
      std::function<bool(c10::Stream)> filter) {
    std::unique_lock<std::shared_mutex> lg(instance_mutex_);
    create_or_incref_pool_under_lock(pool_id);
    for (auto it2 = captures_underway_.begin(); it2 != captures_underway_.end();
         ++it2) {
      TORCH_CHECK(
          it2->first != pool_id,
          "beginAllocateToPool: already recording to mempool_id");
    }
    captures_underway_.emplace_back(pool_id, std::move(filter));
    captures_underway_empty_.store(false, std::memory_order_relaxed);
  }

  // Stop routing allocations into private pool `pool_id`.
  void end_allocate_to_pool(c10::MempoolId_t pool_id) {
    std::unique_lock<std::shared_mutex> lg(instance_mutex_);
    for (auto it = captures_underway_.begin(); it != captures_underway_.end();
         ++it) {
      if (it->first == pool_id) {
        captures_underway_.erase(it);
        captures_underway_empty_.store(captures_underway_.empty(), std::memory_order_relaxed);
        return;
      }
    }
    TORCH_CHECK(false, "endAllocatePool: not currently recording to mempool_id");
  }

  // Caller must hold instance_mutex_ exclusively.
  void create_or_incref_pool_under_lock(c10::MempoolId_t pool_id) {
    auto it = graph_pools_.find(pool_id);
    if (it == graph_pools_.end()) {
      graph_pools_.emplace(pool_id, std::make_unique<PrivatePool>(pool_id));
    } else {
      TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
      it->second->use_count++;
    }
  }

  // Drop a reference to private pool `pool_id`; once the refcount hits zero
  // the pool becomes eligible for destruction in empty_cache().
  void release_pool(c10::MempoolId_t pool_id) {
    std::unique_lock<std::shared_mutex> lg(instance_mutex_);
    auto* pp = graph_pools_.at(pool_id).get();
    TORCH_INTERNAL_ASSERT(pp != nullptr);
    auto uc = --(pp->use_count);
    TORCH_INTERNAL_ASSERT(uc >= 0);
    if (uc == 0) {
      bool inserted = graph_pools_freeable_.insert({pool_id, pp}).second;
      TORCH_INTERNAL_ASSERT(inserted);
    }
  }

 private:
  // Select the pool for the current stream: a private pool if a matching
  // capture is underway, otherwise the default pool. The fast path avoids
  // taking instance_mutex_ by checking the relaxed atomic first.
  std::tuple<c10::MempoolId_t, BlockPool&> get_allocation_pool_for_current_stream() {
    if (C10_LIKELY(captures_underway_empty_.load(std::memory_order_relaxed))) {
      return {c10::MempoolId_t{0, 0}, default_pool_};
    }

    std::shared_lock<std::shared_mutex> lg(instance_mutex_);
    S stream = get_current_stream();
    for (auto& entry : captures_underway_) {
      if (entry.second(stream)) {
        auto it = graph_pools_.find(entry.first);
        TORCH_INTERNAL_ASSERT(it != graph_pools_.end());
        auto id = entry.first;
        if (C10_UNLIKELY(!stream_is_capturing(stream))) {
          TORCH_WARN_ONCE("CachingHostAllocator is allocating to private pool (",
                          id.first, ",", id.second,
                          ") but the current stream is not capturing. Private "
                          "pools have been tested only during graph capture. "
                          "See https://github.com/pytorch/pytorch/pull/167507#discussion_r2561698011");
        }
        return {id, it->second->blocks};
      }
    }
    return {c10::MempoolId_t{0, 0}, default_pool_};
  }

  // Helper: return the pool containing a block, based on its owning_pool_.
  BlockPool& pool_from_block(B* block) {
    auto id = block->owning_pool_;
    if (id == c10::MempoolId_t{0, 0}) {
      return default_pool_;
    }
    std::shared_lock<std::shared_mutex> lg(instance_mutex_);
    auto it = graph_pools_.find(id);
    TORCH_INTERNAL_ASSERT(it != graph_pools_.end());
    return it->second->blocks;
  }

  // We want to keep the non stream capture case as fast as possible,
  // since memory allocation is often in the critical path in non
  // stream capture code. This function will skip the overhead of
  // locking on instance_mutex_, two virtual function calls, and a
  // call into the CUDA API in order to prevent overheads whenever
  // possible.
  bool current_stream_is_capturing_fast_path() const {
    // Stream capture can allocate only to private pools. If there
    // are no private pools for which capture is currently underway,
    // then by modus tollens the current stream is not capturing.
    if (C10_LIKELY(captures_underway_empty_.load(std::memory_order_relaxed))) {
      return false;
    }
    return stream_is_capturing(get_current_stream());
  }

  // Look up the block owning `ptr`, searching the default pool first and
  // then every private pool. Returns nullptr if not found.
  B* get_block_from_ptr(void *ptr) {
    std::shared_lock<std::shared_mutex> lk(instance_mutex_);
    {
      std::lock_guard<std::mutex> lk(default_pool_.blocks_mutex_);
      if (default_pool_.ptr_to_block_.count(ptr)) {
        return default_pool_.ptr_to_block_.at(ptr);
      }
    }
    if (C10_LIKELY(graph_pools_.empty())) {
      return nullptr;
    } else {
      for (auto &&[_, private_pool]: graph_pools_) {
        BlockPool& pool = private_pool->blocks;
        std::lock_guard<std::mutex> lk(pool.blocks_mutex_);
        if (pool.ptr_to_block_.count(ptr)) {
          return pool.ptr_to_block_.at(ptr);
        }
      }
      return nullptr;
    }
  }

  // Whether `block_` is a block tracked by any pool of this allocator.
  bool block_exists(void *block_) {
    B *block = reinterpret_cast<B*>(block_);
    std::shared_lock<std::shared_mutex> lk(instance_mutex_);
    {
      std::lock_guard<std::mutex> lk(default_pool_.blocks_mutex_);
      if (default_pool_.blocks_.count(block)) {
        return true;
      }
    }
    if (C10_LIKELY(graph_pools_.empty())) {
      return false;
    } else {
      for (auto &&[_, private_pool]: graph_pools_) {
        BlockPool& pool = private_pool->blocks;
        std::lock_guard<std::mutex> lk(pool.blocks_mutex_);
        if (pool.blocks_.count(block)) {
          return true;
        }
      }
      return false;
    }
  }

  /* These following functions are runtime-related. */

  // Allocate page-locked memory on the host.
  virtual void allocate_host_memory(size_t size, void** ptr) {
    TORCH_CHECK_NOT_IMPLEMENTED(
        false, "Not implemented for allocate_host_memory");
  }

  // Free block and release the pointer contained in block.
  virtual void free_block(B* block) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for free_block");
  }

  // Record an event on stream and store event into events.
  virtual void record_stream(std::optional<std::vector<E>>& events, S stream) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for record_stream");
  }

  // Query event if it is completed.
  virtual bool query_event(E& event) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
  }

  virtual S get_current_stream() const {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for get_current_stream");
  }

  virtual bool stream_is_capturing(S s) const {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for stream_is_capturing");
  }

  // instance variables

  // instance_mutex_ protects graphs_pools_, graph_pools_freeable_,
  // and captures_underway_, as well as the use_count field of
  // PrivatePools in graph_pools_ and graph_pools_freeable_.  We use a
  // shared mutex because we want to allow for multiple private pools
  // to be allocated to concurrently.  Does not protect default_pool_,
  // which has its own mutex.
  alignas(hardware_destructive_interference_size) mutable std::shared_mutex instance_mutex_;

  // We manually maintain the invariant that captures_underway_empty_
  // == captures_underway_.empty() outside of zones guarded by
  // instance_mutex_. This trick allows for us to check whether any
  // captures are currently underway without ever taking a lock on
  // instance_mutex_ (which guards captures_underway_), which is much
  // more expensive than a relaxed memory load on this atomic. Read
  // more here:
  // https://github.com/pytorch/pytorch/pull/167507#discussion_r2586418965
  // It is important to use only "relaxed" loads and stores.
  std::atomic<bool> captures_underway_empty_{true};

  // Private pools for captures
  ska::flat_hash_map<c10::MempoolId_t, std::unique_ptr<PrivatePool>, c10::MempoolIdHash>
      graph_pools_;

  // Private pools whose refcount reached zero and that empty_cache() may
  // destroy once their blocks drain. Non-owning pointers into graph_pools_.
  ska::flat_hash_map<c10::MempoolId_t, PrivatePool*, c10::MempoolIdHash>
      graph_pools_freeable_;

  // Track active capture contexts requesting allocations to specific pools
  std::vector<
      std::pair<c10::MempoolId_t, std::function<bool(c10::Stream)>>> captures_underway_;

  // corresponds to c10::MempoolId_t{0,0}
  BlockPool default_pool_;

  // Indicates whether the event-processing thread pool is active.
  // Set to false in the destructor to signal background threads to stop.
  std::atomic<bool> active_{false};
protected:
  alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free