Home / Class / DeviceCachingAllocator Class — pytorch Architecture

DeviceCachingAllocator Class — pytorch Architecture

Architecture documentation for the DeviceCachingAllocator class in CUDACachingAllocator.cpp from the pytorch codebase.

Entity Profile

Source Code

c10/cuda/CUDACachingAllocator.cpp lines 1184–3846

class DeviceCachingAllocator {
 private:
  // lock around all operations
  // (recursive: the owning thread may lock it again while already holding it)
  mutable std::recursive_mutex mutex;

  // device statistics
  DeviceStats stats;

  // index of the device this allocator serves; set from the constructor
  // argument
  c10::DeviceIndex device_id;

  // unallocated cached blocks larger than 1 MB
  BlockPool large_blocks;

  // unallocated cached blocks 1 MB or smaller
  BlockPool small_blocks;

  // allocated or in use by a stream. Holds all active allocations,
  // whether they came from graph_pools or one of the BlockPools above.
  ska::flat_hash_set<Block*> active_blocks;

  // captures_underway tracks if we are diverting some
  // allocations to a specific pool.
  // Most of the time it's empty, in which case malloc can avoid calling
  // cudaStreamGetCaptureInfo in the hot path.
  // Each entry pairs a pool id with a predicate over streams.
  std::vector<std::pair<MempoolId_t, std::function<bool(cudaStream_t)>>>
      captures_underway;

  // tracks which pools we can use as a last resort before ooming
  ska::flat_hash_set<MempoolId_t, MempoolIdHash> use_on_oom_pools;

  // tracks which pools should not split a segment
  ska::flat_hash_set<MempoolId_t, MempoolIdHash> no_split_pools;

  // Map of blocks whose freeing is deferred until after CUDA graph capture.
  //   - Key: Block* to be freed.
  //   - Value: List of "empty nodes" inserted as free markers during capture.
  //     If the vector is empty, the block must always be deferred until capture
  //     ends.
  ska::flat_hash_map<Block*, std::vector<cudaGraphNode_t>> deferred_blocks;

  // Incremental reverse-traversal state cached per graph.
  // We never re-traverse nodes we've already seen
  struct GraphReuseContext {
    // per stream: the graph nodes already reached by the reverse traversal
    ska::flat_hash_map<cudaStream_t, ska::flat_hash_set<cudaGraphNode_t>>
        visited;
  };
  // capture id recorded for a memory pool (filled in by
  // free_safe_blocks_in_capture)
  ska::flat_hash_map<MempoolId_t, CaptureId_t, MempoolIdHash>
      mempool_to_capture_id;
  // traversal state keyed by capture id
  ska::flat_hash_map<CaptureId_t, GraphReuseContext> graph_reuse_context;

  // outstanding cuda events
  ska::flat_hash_map<
      cuda::CUDAStream,
      std::deque<std::pair<EventPool::Event, Block*>>>
      cuda_events;

  // record used memory.
  size_t total_allocated_memory = 0;

  // device properties, filled by cudaGetDeviceProperties in the constructor
  cudaDeviceProp device_prop;

  // maximum amount of memory that device is allowed to
  // allocate. This is set iff memory fraction is less than 1
  std::optional<size_t> allowed_memory_maximum{std::nullopt};

  // all live expandable segments
  std::vector<ExpandableSegment*> expandable_segments_;
  // NOTE(review): presumably the devices granted peer access to this
  // allocator's memory - not established by the code visible here; confirm.
  std::vector<c10::DeviceIndex> devices_with_peer_access_;

  // history-recording switch, set by recordHistory
  bool record_history = false;
  // trace actions parsed from the `skip_actions` strings given to
  // recordHistory
  std::unordered_set<TraceEntry::Action> skip_actions_list;

  // context-gathering callback; atomic because maybeGatherContext loads it
  // without holding `mutex`
  std::atomic<CreateContextFn> context_recorder_;
  // level compared against the requested level in maybeGatherContext
  RecordContext record_context_ = RecordContext::NEVER;

  // Ring buffer for memory snapshot TraceEntry's
  RingBuffer<TraceEntry> alloc_buffer;

  // Members specific to CUDA graphs

  // Private pools for CUDA graphs
  ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash>
      graph_pools;
  // Pools no longer referenced by any graph. Their BlockPools are eligible for
  // free_blocks. Can't be a vector or deque because we might erase entries in
  // any order. Could be an std::list, but we don't care much, access and
  // insert/erase are rare.
  ska::flat_hash_map<MempoolId_t, PrivatePool*, MempoolIdHash>
      graph_pools_freeable;

  // XXX - maybe we should generalize and have multiple events
  // Callbacks invoked by malloc when an allocation fails with out-of-memory.
  std::vector<OutOfMemoryObserver> oom_observers_;

  // trackers registered through attachAllocatorTraceTracker
  std::vector<AllocatorTraceTracker> trace_trackers_;

  // mapping from block to a stream_set, containing streams on which the block
  // was used while cudagraph capturing
  std::unordered_map<Block*, stream_set> block_to_cudagraph_stream_uses;

  // per-thread stack of compile contexts. Declared static, so it is shared by
  // every allocator instance on a thread rather than kept per device.
  static thread_local std::stack<std::string> compile_context;

  // thread local user metadata for annotating allocations
  // (static as well: one value per thread across all allocator instances)
  static thread_local std::string user_metadata;

 public:
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  explicit DeviceCachingAllocator(c10::DeviceIndex id)
      : device_id(id),
        large_blocks(/*small=*/false),
        small_blocks(/*small=*/true) {
    // No context recorder is installed until recordHistory provides one.
    context_recorder_.store(nullptr);

    // device_prop has to be populated first: setMemoryFraction derives the
    // byte limit from device_prop.totalGlobalMem.
    C10_CUDA_CHECK(cudaGetDeviceProperties(&device_prop, id));
    setMemoryFraction(CUDAAllocatorConfig::per_process_memory_fraction());

    // Publish the configured split threshold through the device stats.
    const auto configured_max_split =
        AcceleratorAllocatorConfig::max_split_size();
    stats.max_split_size = static_cast<int64_t>(configured_max_split);
  }

  // Enables, disables or reconfigures allocation-history recording.
  //   enabled                  - whether history is recorded from now on.
  //   context_recorder         - callback used to gather a context; must be
  //                              non-null unless `when` is NEVER.
  //   alloc_buffer_max_entries - capacity passed to the trace ring buffer.
  //   when                     - context level to use while enabled.
  //   clearHistory             - discard the entries recorded so far.
  //   skip_actions             - action names parsed into the skip set.
  void recordHistory(
      bool enabled,
      CreateContextFn context_recorder,
      size_t alloc_buffer_max_entries,
      RecordContext when,
      bool clearHistory,
      const std::vector<std::string>& skip_actions) {
    std::unique_lock<std::recursive_mutex> lock(mutex);
    TORCH_CHECK(when == RecordContext::NEVER || context_recorder);
    record_history = enabled;

    // Rebuild the skip set from scratch out of the textual action names.
    skip_actions_list.clear();
    for (const std::string& action_name : skip_actions) {
      skip_actions_list.insert(parseTraceEntryAction(action_name));
    }

    // While disabled there is no recorder and no context level.
    if (enabled) {
      context_recorder_.store(context_recorder);
    } else {
      context_recorder_.store(nullptr);
    }
    alloc_buffer.setMaxEntries(alloc_buffer_max_entries);
    if (enabled) {
      record_context_ = when;
    } else {
      record_context_ = RecordContext::NEVER;
    }

    // Disabling always drops the history; otherwise only on request.
    const bool drop_entries = clearHistory || !enabled;
    if (drop_entries) {
      alloc_buffer.clear();
    }
  }

  // Reports the history-recording switch last set by recordHistory.
  // Reads record_history without taking `mutex`.
  bool isHistoryEnabled() const {
    return record_history;
  }

  // Pushes a copy of `md` onto the calling thread's compile-context stack.
  // The argument is only copied, never modified, so it is taken by const
  // reference; const strings and temporaries are accepted as well.
  void pushCompileContext(const std::string& md) {
    compile_context.push(md);
  }

  // Removes the most recently pushed compile context of the calling thread.
  // Does nothing when the stack is already empty.
  void popCompileContext() {
    if (compile_context.empty()) {
      return;
    }
    compile_context.pop();
  }

  // Replaces the calling thread's user metadata string with a copy of
  // `metadata` (user_metadata is a static thread_local member).
  void setUserMetadata(const std::string& metadata) {
    user_metadata = metadata;
  }

  // Returns a copy of the calling thread's user metadata string.
  std::string getUserMetadata() {
    return user_metadata;
  }

  // Returns true iff the pointers currently allocated out of the private pool
  // `mempool_id` are exactly the pointers in `expected_live_allocations`.
  // Fails a TORCH_CHECK when no private pool with that id exists.
  bool checkPoolLiveAllocations(
      MempoolId_t mempool_id,
      const std::unordered_set<void*>& expected_live_allocations) const {
    std::unique_lock<std::recursive_mutex> lock(mutex);

    const auto pool_entry = graph_pools.find(mempool_id);
    TORCH_CHECK(pool_entry != graph_pools.end(), "Could not find pool of id");
    PrivatePool* const pool = pool_entry->second.get();
    TORCH_INTERNAL_ASSERT(pool != nullptr);

    size_t live_in_pool = 0;
    for (Block* b : active_blocks) {
      TORCH_INTERNAL_ASSERT(b != nullptr);
      TORCH_INTERNAL_ASSERT(b->pool != nullptr);

      const bool owned_by_pool =
          b->allocated && b->pool->owner_PrivatePool == pool;
      if (!owned_by_pool) {
        continue;
      }
      // A live allocation the caller did not list: the sets cannot match.
      if (expected_live_allocations.find(b->ptr) ==
          expected_live_allocations.end()) {
        return false;
      }
      ++live_in_pool;
    }

    // Every live block was expected, so equal counts mean nothing expected
    // is missing either.
    return live_in_pool == expected_live_allocations.size();
  }

  // Registers `observer` to be called when an allocation fails with
  // out-of-memory.
  // The allocator mutex is held while appending: malloc copies
  // oom_observers_ under that same lock, so an unsynchronized emplace_back
  // here could race with (and reallocate the vector underneath) that copy.
  void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
    std::lock_guard<std::recursive_mutex> lock(mutex);
    oom_observers_.emplace_back(std::move(observer));
  }

  // Appends `tracker` to the registered trace trackers while holding the
  // allocator mutex.
  void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) {
    std::lock_guard<std::recursive_mutex> guard(mutex);
    trace_trackers_.push_back(std::move(tracker));
  }

  // Must be called outside of `mutex` or deadlocks are possible with Python
  std::shared_ptr<GatheredContext> maybeGatherContext(RecordContext level) {
    // Recording is configured below the requested level: nothing to gather.
    if (record_context_ < level) {
      return nullptr;
    }
    // Load the recorder exactly once and check it before calling. This
    // function runs without the mutex, so recordHistory may reset the
    // recorder to nullptr between the level check above and this load;
    // calling through the unchecked pointer would then crash.
    const CreateContextFn recorder = context_recorder_.load();
    if (recorder == nullptr) {
      return nullptr;
    }
    return recorder();
  }

  // All public methods (except the above) acquire the allocator mutex.
  // Thus, do not call a public method from another public method.

  // Allocates a block able to hold `orig_size` bytes for use on `stream`.
  //
  // The request is rounded with round_size and served, in order, from:
  //   1. a free cached block of the matching pool (retrying once after the
  //      free-memory callbacks ran),
  //   2. a fresh allocation via alloc_block,
  //   3. a pool that opted in as out-of-memory fallback,
  //   4. a fresh allocation after releasing cached blocks.
  // Returns the block produced by alloc_found_block. Raises OutOfMemoryError
  // (after notifying the registered OOM observers) when every step fails.
  Block* malloc(size_t orig_size, cudaStream_t stream) {
    // done outside the lock because we don't know what locks the recorder needs
    // to have...
    auto context = maybeGatherContext(RecordContext::STATE);

    std::unique_lock<std::recursive_mutex> lock(mutex);

    if (C10_LIKELY(captures_underway.empty())) {
      // Processes end-of-life events for outstanding allocations used on
      // multiple streams (checks if their GPU-side uses are complete and
      // recycles their memory if so)
      //
      // Q. Why skip process_events if a capture might be underway?
      // A. process_events involves cudaEventQueries, illegal during CUDA graph
      //    capture.
      //    Dumb simple solution: defer reclaiming these allocations until after
      //    capture. Cross-stream memory use is uncommon, so the deferral's
      //    effect on memory use during capture should be small.
      process_events(context);
    } else {
      if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) {
        // We check if there is some block that is safe to reuse on this stream
        free_safe_blocks_in_capture(context, stream);
      }
    }
    size_t size = round_size(orig_size);
    auto& pool = get_pool(size, stream);
    const size_t alloc_size = get_allocation_size(size);
    bool active_user_pool =
        pool.owner_PrivatePool && pool.owner_PrivatePool->allocator();
    // The expandable segments are only active on the default pool.
    bool is_expandable_segments_active =
        CUDAAllocatorConfig::expandable_segments() && !active_user_pool;
    AllocParams params(
        device_id,
        size,
        stream,
        &pool,
        alloc_size,
        is_expandable_segments_active);
    params.stat_types = get_stat_types_for_pool(pool);

    // First, try to get a block from the existing pool.
    bool block_found =
        // Search pool
        get_free_block(params)
        // Trigger callbacks and retry search
        || (trigger_free_memory_callbacks(params) && get_free_block(params));

    // Can't reuse an existing block; try to get a new one.
    if (!block_found) {
      // Do garbage collection if the flag is set.
      if (C10_UNLIKELY(
              allowed_memory_maximum.has_value() &&
              AcceleratorAllocatorConfig::garbage_collection_threshold() >
                  0.0)) {
        garbage_collect_cached_blocks(context);
      }

      // Attempt allocate
      // WARNING: alloc_block may release the allocator lock when calling
      // cudaMalloc. So far this function has not modified allocator state, but
      // keep in mind that any observed allocator state may change across calls
      // to alloc_block since it may release the lock.
      block_found = alloc_block(params, false, context, lock)
          // Try to use memory pools that have opted in as overflow before
          // expensive memory freeing operations.
          || try_mempool_fallback(
                        params, size, stream, device_id, alloc_size, stats)
          // Free enough available cached blocks to satisfy alloc and retry
          // alloc.
          || (release_available_cached_blocks(params, context) &&
              alloc_block(params, false, context, lock))
          // Free all non-split cached blocks and retry alloc.
          || (C10_LIKELY(captures_underway.empty()) &&
              release_cached_blocks(context, {0, 0}) &&
              alloc_block(params, true, context, lock));
    }

    if (!block_found) {
      // For any error code other than cudaErrorMemoryAllocation,
      // alloc_block should have thrown an exception already.
      TORCH_INTERNAL_ASSERT(params.err == cudaErrorMemoryAllocation);

      size_t device_free = 0;
      size_t device_total = 0;
      C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
      std::string allowed_info;

      if (allowed_memory_maximum.has_value()) {
        allowed_info =
            format_size(allowed_memory_maximum.value()) + " allowed; ";
      }

      std::string proc_info = reportProcessMemoryInfo(device_prop);

      record_trace(
          TraceEntry::OOM,
          device_free,
          params.size(),
          params.stream(),
          params.device(),
          params.pool->owner_MempoolId(),
          std::move(context));
      stats.num_ooms += 1;

      c10::reportOutOfMemoryToProfiler(
          static_cast<int64_t>(size),
          stats.allocated_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
              .current,
          stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
              .current,
          c10::Device(c10::DeviceType::CUDA, device_id));

      auto allocated_bytes =
          stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
              .current;
      auto reserved_bytes =
          stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
              .current;
      auto observers_local = oom_observers_;

      size_t allocated_in_private_pools = 0;
      auto get_size_block = [](const BlockPool& pool) {
        size_t res = 0;
        for (const auto& block : pool.blocks) {
          res += block->size;
        }
        return res;
      };
      for (const auto& p : graph_pools) {
        allocated_in_private_pools += get_size_block(p.second->large_blocks);
        allocated_in_private_pools += get_size_block(p.second->small_blocks);
      }

      std::string private_pool_msg;

      if (allocated_in_private_pools > 0) {
        private_pool_msg = "with " + format_size(allocated_in_private_pools) +
            " allocated in private pools (e.g., CUDA Graphs), ";
      }

      // Snapshot the memory limit while the lock is still held, so that no
      // allocator state has to be read once the lock is released below.
      const size_t memory_limit = allowed_memory_maximum.value_or(device_total);

      // Make sure we do not have the device lock before calling our
      // observers which might need hold the GIL
      // It is safe to release at this point because will no longer
      // be reading any allocator state.

      lock.unlock();

      for (const auto& obs : observers_local) {
        obs(device_id, alloc_size, memory_limit, device_free);
      }

      // "total capacity": total global memory on GPU
      // "allowed": memory is allowed to use, which set by fraction.
      // "already allocated": memory allocated by the program using the
      //                      caching allocator
      // "free": free memory as reported by the CUDA API
      // "cached": memory held by the allocator but not used by the program
      //
      // The "allocated" amount  does not include memory allocated outside
      // of the caching allocator, such as memory allocated by other programs
      // or memory held by the driver.
      //
      // The sum of "allocated" + "free" + "cached" may be less than the
      // total capacity due to memory held by the driver and usage by other
      // programs.
      //
      // Note that at this point free_cached_blocks has already returned all
      // possible "cached" memory to the driver. The only remaining "cached"
      // memory is split from a larger block that is partially in-use.
      TORCH_CHECK_WITH(
          OutOfMemoryError,
          false,
          "CUDA out of memory. Tried to allocate ",
          format_size(alloc_size),
          ". GPU ",
          static_cast<int>(device_id),
          " has a total capacity of ",
          format_size(device_total),
          " of which ",
          format_size(device_free),
          " is free. ",
          proc_info,
          allowed_info,
          "Of the allocated memory ",
          format_size(allocated_bytes + allocated_in_private_pools),
          " is allocated by PyTorch, ",
          private_pool_msg,
          "and ",
          format_size(
              reserved_bytes - allocated_bytes - allocated_in_private_pools),
          " is reserved by PyTorch but unallocated.",
          " If reserved but unallocated memory is large try setting",
          " PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid"
          " fragmentation.  See documentation for Memory Management "
          " (https://docs.pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf)");
    }

    // A block is available: decide whether its unused tail should be split
    // off, then hand it over as an active allocation.
    bool split_remainder = should_split(
        params.block, params.size(), params.is_expandable_segments_active);
    return alloc_found_block(
        params, orig_size, std::move(context), split_remainder);
  }

  // Last-resort lookup before freeing cached memory: searches each pool in
  // use_on_oom_pools for a free block that satisfies the request.
  // On success `params` is replaced by the parameters of the pool that
  // provided the block and true is returned; otherwise `params` is left
  // untouched and false is returned. Requests that already target a private
  // pool never fall back.
  bool try_mempool_fallback(
      AllocParams& params,
      size_t size,
      cudaStream_t stream,
      c10::DeviceIndex device_idx,
      size_t alloc_size,
      DeviceStats& device_stats) {
    // if already trying to use a mempool, then just oom
    if (params.pool->owner_PrivatePool) {
      return false;
    }

    // Route allocations to the candidate pool only for the calling thread.
    const auto caller_thread = std::this_thread::get_id();
    auto only_caller_thread = [caller_thread](cudaStream_t) {
      return std::this_thread::get_id() == caller_thread;
    };

    for (MempoolId_t mempool_id : use_on_oom_pools) {
      beginAllocateToPool(mempool_id, only_caller_thread);
      auto& candidate_pool = get_pool(size, stream);
      AllocParams candidate_params(
          device_idx, size, stream, &candidate_pool, alloc_size, false);
      candidate_params.stat_types = get_stat_types_for_pool(candidate_pool);
      const bool found_in_candidate = get_free_block(candidate_params);
      endAllocateToPool(mempool_id);
      releasePool(mempool_id);

      if (found_in_candidate) {
        params = candidate_params;
        return true;
      }
    }
    return false;
  }

  // Turns the free block selected in `params.block` into an active
  // allocation and returns it.
  //   params          - allocation parameters; must carry a successful result
  //                     (err == cudaSuccess and a block with a non-null ptr).
  //   orig_size       - the size originally requested by the caller; stored
  //                     as the block's requested_size.
  //   context         - gathered context attached to the block and the trace.
  //   split_remainder - when true, only the first params.size() bytes are
  //                     handed out and the rest goes back into the pool.
  // Updates the split/allocation statistics, records an ALLOC trace entry and
  // reports the new usage to the gauge and the profiler.
  Block* alloc_found_block(
      const AllocParams& params,
      size_t orig_size,
      std::shared_ptr<GatheredContext> context,
      bool split_remainder) {
    auto size = params.size();
    auto device = params.device();
    auto pool = params.pool;
    auto stream = params.stream();

    TORCH_INTERNAL_ASSERT(
        params.err == cudaSuccess && params.block != nullptr &&
        params.block->ptr != nullptr);
    Block* block = params.block;
    Block* remaining = nullptr;

    // Captured before the split below rewires prev/next.
    const bool already_split = block->is_split();
    if (split_remainder) {
      // The found block becomes the remainder; a new Block takes over the
      // first `size` bytes and is linked in front of it.
      remaining = block;

      block = new Block(device, stream, size, pool, block->ptr);
      block->expandable_segment_ = remaining->expandable_segment_;
      block->prev = remaining->prev;
      if (block->prev) {
        block->prev->next = block;
      }
      block->next = remaining;

      // Shrink the remainder to the tail and return it to the pool.
      remaining->prev = block;
      remaining->ptr = static_cast<char*>(remaining->ptr) + size;
      remaining->size -= size;
      // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
      bool inserted = pool->insert_into_blocks(remaining).second;
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);

      // Split statistics are only maintained for blocks that are not part of
      // an expandable segment.
      if (already_split && !block->expandable_segment_) {
        // An already-split inactive block is being shrunk by size bytes.
        decrease_stat_array(
            stats.inactive_split_bytes, block->size, params.stat_types);
      } else if (!block->expandable_segment_) {
        // A new split inactive block is being created from a previously unsplit
        // block, size remaining->size bytes.
        for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
          stats.inactive_split_bytes[stat_type].increase(remaining->size);
          stats.inactive_split[stat_type].increase(1);
        });
      }

    } else if (already_split && !block->expandable_segment_) {
      // An already-split block is becoming active
      for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
        stats.inactive_split_bytes[stat_type].decrease(block->size);
        stats.inactive_split[stat_type].decrease(1);
      });
    }

    block->allocated = true;
    block->requested_size = orig_size;

    // The block keeps the context; the trace entry shares it.
    block->context_when_allocated = std::move(context);
    record_trace(
        TraceEntry::ALLOC,
        int64_t(block->ptr),
        orig_size,
        block->stream,
        block->device,
        block->pool->owner_MempoolId(),
        block->context_when_allocated);

    // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
    bool inserted = active_blocks.insert(block).second;
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);

    // Account for the allocation in every stat type selected for the pool.
    for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
      stats.allocation[stat_type].increase(1);
      stats.allocated_bytes[stat_type].increase(block->size);
      stats.active[stat_type].increase(1);
      stats.active_bytes[stat_type].increase(block->size);
      stats.requested_bytes[stat_type].increase(block->requested_size);
    });
    // Blocks at or above max_split_size count as oversize allocations.
    if (block->size >= AcceleratorAllocatorConfig::max_split_size())
      stats.oversize_allocations.increase(1);

    auto allocated_bytes_gauge =
        STATIC_GAUGE(pytorch.CUDACachingAllocator.allocated_bytes);
    allocated_bytes_gauge.record(
        stats.allocated_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
            .current);

    c10::reportMemoryUsageToProfiler(
        block->ptr,
        static_cast<int64_t>(block->size),
        stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
        stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
        c10::Device(c10::DeviceType::CUDA, device));

    return block;
  }

  // Snapshot of a stream's capture state as filled in by
  // stream_get_capture_info. The defaults describe a stream that is not
  // capturing.
  struct CaptureInfo {
    // graph being captured
    cudaGraph_t graph{};
    // identifier of the capture
    CaptureId_t capture_id = 0;
    // array of `num_terminals` terminal nodes
    const cudaGraphNode_t* terminals = nullptr;
    size_t num_terminals = 0;
    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
  };

  // Queries the capture state of `stream` and returns it as a CaptureInfo
  // (status, capture id, graph and terminal nodes).
  // Asserts that the reported status is not "invalidated".
  CaptureInfo stream_get_capture_info(cudaStream_t stream) {
    CaptureInfo info{};
    // The entry point and its argument list depend on the CUDA version: from
    // CUDA 13 on the call takes one extra pointer argument, passed as nullptr.
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
    C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
        stream,
        &info.status,
        &info.capture_id,
        &info.graph,
        &info.terminals,
        nullptr,
        &info.num_terminals));
#else
    C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
        stream,
        &info.status,
        &info.capture_id,
        &info.graph,
        &info.terminals,
        &info.num_terminals));
#endif
    TORCH_INTERNAL_ASSERT(
        info.status != cudaStreamCaptureStatusInvalidated,
        "Invalid stream capture status");

    return info;
  }

  // Record "free marker" of the CUDA graph for all streams that
  // have used the block, including the allocation stream. These nodes mark the
  // last use of the block in the capture graph. Returns a vector of the
  // inserted nodes, or an empty vector if any stream is not capturing.
  std::vector<cudaGraphNode_t> record_free_markers(Block* block) {
    // It is possible to have the same marker recorded multiple times, so the
    // markers are gathered in a set to avoid duplicates.
    ska::flat_hash_set<cudaGraphNode_t> markers;
    cudaGraph_t owning_graph = nullptr;

    // Adds the current terminal nodes of `s` to `markers`.
    // Returns false when `s` is not capturing, i.e. the free must be deferred.
    auto try_record = [&](cudaStream_t s) -> bool {
      const auto info = stream_get_capture_info(s);
      if (info.status == cudaStreamCaptureStatusNone) {
        return false;
      }

      // The first capturing stream fixes the graph; all others must match it.
      if (owning_graph == nullptr) {
        owning_graph = info.graph;
      }
      TORCH_INTERNAL_ASSERT(
          info.graph == owning_graph,
          "All streams in the same capture should agree on the graph");

      // The current terminals serve as the free markers for this stream.
      for (size_t idx = 0; idx != info.num_terminals; ++idx) {
        markers.insert(info.terminals[idx]);
      }
      return true;
    };

    // An empty result tells the caller to defer the free until after capture;
    // that is the outcome as soon as one involved stream is not capturing.

    // Allocation stream first ...
    bool all_capturing = try_record(block->stream);
    if (!all_capturing) {
      return {};
    }
    // ... then every additional stream that used the block.
    for (const auto& use : block->stream_uses) {
      all_capturing = try_record(use.stream());
      if (!all_capturing) {
        return {};
      }
    }

    std::vector<cudaGraphNode_t> unique_markers(markers.begin(), markers.end());
    return unique_markers;
  }

  // Extends `visited` with every graph node reachable by walking dependency
  // edges backwards from the capture's current terminal nodes, i.e. the
  // terminals themselves and all of their transitive predecessors.
  // Nodes already contained in `visited` are not traversed again, so
  // repeated calls for the same capture only process newly added nodes.
  void update_visited(
      const CaptureInfo& info,
      ska::flat_hash_set<cudaGraphNode_t>& visited) {
    // This is the versioned cudaGraphNodeGetDependencies helper function.
    auto node_get_dependencies =
        [](cudaGraphNode_t n, cudaGraphNode_t* deps, size_t* count) -> void {
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
      if (deps == nullptr) {
        C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, nullptr, count));
      } else {
        SmallVector<cudaGraphEdgeData> edgeData;
        edgeData.resize(*count);
        C10_CUDA_CHECK(
            cudaGraphNodeGetDependencies(n, deps, edgeData.data(), count));
      }
#else
      C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, count));
#endif
    };

    // Helper to retrieve all parent nodes (dependencies) of a given node.
    auto get_parents =
        [&](cudaGraphNode_t node) -> std::vector<cudaGraphNode_t> {
      size_t count = 0;

      node_get_dependencies(node, nullptr, &count);
      std::vector<cudaGraphNode_t> out(count);
      if (count) {
        node_get_dependencies(node, out.data(), &count);
        out.resize(count);
      }
      return out;
    };

    // For each terminal node, perform a reverse DFS to count, for each free
    // marker, how many terminals it can reach (i.e., for how many terminals it
    // is a predecessor). A free marker is reusable if it is a predecessor of
    // all terminal nodes.
    std::deque<cudaGraphNode_t> dfs;
    for (size_t i = 0; i < info.num_terminals; ++i) {
      dfs.push_back(info.terminals[i]);
    }

    while (!dfs.empty()) {
      auto v = dfs.back();
      dfs.pop_back();

      if (visited.count(v)) {
        continue;
      }
      visited.insert(v);

      auto parents = get_parents(v);
      for (auto p : parents) {
        dfs.push_back(p);
      }
    }
  }

  // A block is considered reusable during CUDA graph capture if every free
  // marker associated with the block is a predecessor of every
  // terminal node.
  //
  // This ensures that any new operation added to the graph will be attached
  // after all terminal nodes, which themselves are after all free markers. As a
  // result, all future work is guaranteed to occur after the block's last use
  // on every stream, so the block's previous lifetime ends before any new
  // lifetime begins. This check relies solely on the DAG topology and does not
  // require event queries, making it safe to use during capture.
  void free_safe_blocks_in_capture(
      const std::shared_ptr<GatheredContext>& context,
      cudaStream_t stream) {
    auto info = stream_get_capture_info(stream);

    // If there are no reusable empty nodes (e.g., not currently capturing),
    // there is nothing to do.
    if (info.status == cudaStreamCaptureStatusNone || info.num_terminals == 0) {
      return;
    }
    if (graph_reuse_context.find(info.capture_id) ==
        graph_reuse_context.end()) {
      bool found = false;
      // Use the reverse iterator to search captures_underway in LIFO order.
      for (auto it = captures_underway.rbegin(); it != captures_underway.rend();
           ++it) {
        if (it->second(stream)) {
          auto graph_pool = graph_pools.find(it->first);
          TORCH_INTERNAL_ASSERT(
              graph_pool != graph_pools.end(),
              "Could not find graph pool for capture.");
          auto mempool_id = graph_pool->first;
          graph_reuse_context[info.capture_id] = GraphReuseContext{};
          mempool_to_capture_id[mempool_id] = info.capture_id;
          found = true;
          break;
        }
      }
      TORCH_INTERNAL_ASSERT(
          found, "Could not find memory pool id for capture.");
    }
    auto& graph_context = graph_reuse_context[info.capture_id];
    auto& visited = graph_context.visited[stream];
    update_visited(info, visited);

    std::vector<Block*> blocks_to_erase;
    for (auto& [block, markers] : deferred_blocks) {
      // Skip this block if it has no markers, as we defer its freeing until
      // after graph capture. Also skip if the block was not allocated on the
      // current stream; such blocks will be freed when
      // free_safe_blocks_in_capture is attempted on that stream.
      if (markers.empty() || block->stream != stream) {
        continue;
      }

      bool is_reusable = true;
      for (auto m : markers) {
        if (!visited.count(m)) {
          is_reusable = false;
          break;
        }
      }

      if (is_reusable) {
        // Clear stream uses since the graph ensures proper synchronization.
        // No need to insert events.
        block->stream_uses.clear();

        free_block(block, context);
        blocks_to_erase.push_back(block);
      }
    }

    // Remove blocks that were freed from the deferred_blocks map.
    for (auto* block : blocks_to_erase) {
      deferred_blocks.erase(block);
    }
  }

  // Releases an allocation.
  // Marks `block` as no longer allocated, updates the allocation statistics
  // and records a FREE_REQUESTED trace entry. A block without cross-stream
  // uses is passed to free_block right away. A block that was used on other
  // streams is either put into deferred_blocks (while a capture is underway)
  // or handed to insert_events.
  void free(Block* block) {
    // Gathered before taking the lock, like in malloc.
    std::shared_ptr<GatheredContext> context =
        maybeGatherContext(RecordContext::ALL);
    std::lock_guard<std::recursive_mutex> lock(mutex);

    block->allocated = false;

    // following logic might modifying underlying Block, causing the size
    // changed. We store ahead for reporting
    auto orig_block_ptr = block->ptr;
    auto orig_block_size = block->size;

    StatTypes stat_types = get_stat_types_for_pool(*block->pool);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
      stats.allocation[stat_type].decrease(1);
      stats.allocated_bytes[stat_type].decrease(block->size);
    });
    auto allocated_bytes_gauge =
        STATIC_GAUGE(pytorch.CUDACachingAllocator.allocated_bytes);
    allocated_bytes_gauge.record(
        stats.allocated_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
            .current);

    // Falls back to the allocation-time context when none was gathered now.
    record_trace(
        TraceEntry::FREE_REQUESTED,
        int64_t(block->ptr),
        block->requested_size,
        block->stream,
        block->device,
        block->pool->owner_MempoolId(),
        context ? context : block->context_when_allocated);

    // Mirrors the oversize accounting done when the block was allocated.
    if (block->size >= AcceleratorAllocatorConfig::max_split_size())
      stats.oversize_allocations.decrease(1);

    // If the block has been used on more than one stream, handle accordingly.
    if (!block->stream_uses.empty()) {
      if (C10_UNLIKELY(!captures_underway.empty())) {
        if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) {
          // record_free_markers returns a vector of free markers,
          // or an empty vector if any associated stream is not currently
          // capturing. The empty vector means that we will defer the free until
          // capture is finished.
          deferred_blocks.emplace(block, record_free_markers(block));
        } else {
          // If graph_capture_record_stream_reuse is not enabled, always defer
          // the free until capture is finished.
          deferred_blocks.emplace(block, std::vector<cudaGraphNode_t>{});
        }
      } else {
        // If not in a capture, insert events for the block.
        insert_events(block);
      }
    } else {
      free_block(block, context);
    }

    // Report the release with the pointer and size captured above.
    c10::reportMemoryUsageToProfiler(
        orig_block_ptr,
        -static_cast<int64_t>(orig_block_size),
        stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
        stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
        c10::Device(c10::DeviceType::CUDA, block->device));
  }

  // Returns the start address of the chain of blocks that `block` belongs to
  // (the ptr of the first block reached by following `prev`).
  // When `outSize` is non-null it receives the summed size of all blocks in
  // that chain. Blocks of an expandable segment are rejected.
  void* getBaseAllocation(Block* block, size_t* outSize) {
    std::lock_guard<std::recursive_mutex> lock(mutex);
    TORCH_CHECK(
        !block->expandable_segment_,
        "Tensors allocated with expandable_segments:True cannot be shared between processes. Consider using expandable_segments:False in data loading workers via torch.cuda.memory._set_allocator_settings('expandable_segments:False')");

    // Rewind to the first block of the chain.
    Block* head = block;
    for (; head->prev; head = head->prev) {
    }

    if (outSize) {
      // Add up every block from the head to the end of the chain.
      size_t total_bytes = 0;
      for (Block* cursor = head; cursor; cursor = cursor->next) {
        total_bytes += cursor->size;
      }
      *outSize = total_bytes;
    }
    return head->ptr;
  }

  // Builds a handle through which `block` can be shared.
  // The serialized payload starts with the handle version, followed by a tag
  // for the kind of allocation and the kind-specific data. The returned
  // offset is the distance of `block` from the start of the shared range.
  ShareableHandle shareIpcHandle(Block* block) {
    std::lock_guard<std::recursive_mutex> lock(mutex);
    std::ostringstream payload;
    payload.put(SHAREABLE_HANDLE_VERSION);

    ptrdiff_t offset = 0;
    if (block->expandable_segment_) {
      // Expandable segment: the segment serializes the range itself.
      payload.put(SHAREABLE_CUDA_EXPANDABLE_SEGMENT);
      auto full_range = block->expandable_segment_->share(
          SegmentRange(block->ptr, block->size), payload);
      offset = static_cast<const char*>(block->ptr) - full_range.ptr;
    } else {
      // Plain allocation: share the IPC handle of the first block in the
      // chain and address `block` relative to it.
      payload.put(SHAREABLE_CUDA_MALLOC);
      Block* base_block = block;
      for (; base_block->prev; base_block = base_block->prev) {
      }
      offset = static_cast<const char*>(block->ptr) -
          static_cast<const char*>(base_block->ptr);
      cudaIpcMemHandle_t handle;
      C10_CUDA_CHECK(cudaIpcGetMemHandle(&handle, base_block->ptr));
      payload.write(
          reinterpret_cast<const char*>(&handle), CUDA_IPC_HANDLE_SIZE);
    }
    return ShareableHandle{offset, payload.str()};
  }

  /**
   * Records that `block` is used on `stream`. Uses on the block's own
   * allocation stream are ignored.
   */
  void recordStream(Block* block, cuda::CUDAStream stream) {
    std::lock_guard<std::recursive_mutex> lock(mutex);
    // Uses on the allocation stream don't require any special
    // synchronization, so only other streams are tracked.
    if (stream.stream() != block->stream) {
      block->stream_uses.insert(stream);
      if (C10_UNLIKELY(!captures_underway.empty())) {
        // Additionally remember the use while a capture is underway.
        block_to_cudagraph_stream_uses[block].insert(stream);
      }
    }
  }

  /** get memory fraction limiting maximum allocated memory **/
  double getMemoryFraction() {
    // With no limit configured the reported fraction is 1.0; otherwise it is
    // the limit divided by the device's total global memory.
    return allowed_memory_maximum.has_value()
        ? static_cast<double>(allowed_memory_maximum.value()) /
            static_cast<double>(device_prop.totalGlobalMem)
        : 1.0;
  }

  /** set memory fraction to limit maximum allocated memory **/
  void setMemoryFraction(double fraction) {
    TORCH_CHECK(
        fraction >= 0 && fraction <= 1,
        "invalid fraction:",
        fraction,
        ". Please set within [0, 1].");
    if (fraction < 1.0) {
      // Store the limit as a byte count of the device's total global memory.
      allowed_memory_maximum = static_cast<size_t>(
          fraction * static_cast<double>(device_prop.totalGlobalMem));
    } else {
      // A fraction of exactly 1 clears the limit.
      allowed_memory_maximum = std::nullopt;
    }
  }

  /** get expandable segment size for all the streams on device **/
  std::vector<StreamSegmentSize> getExpandableSegmentSizes() {
    std::lock_guard<std::recursive_mutex> lock(mutex);
    std::vector<StreamSegmentSize> sizes;
    for (auto& segment : expandable_segments_) {
      if (!segment->getStream()) {
        continue;
      }
      sizes.emplace_back(
          segment->getStream(),
          segment->getSegmentSize() == kSmallBuffer,
          segment->getMappedSize());
    }
    return sizes;
  }

  /** returns cached blocks to the system allocator **/
  // `mempool_id` is forwarded unchanged to release_cached_blocks.
  void emptyCache(MempoolId_t mempool_id) {
    // The context is gathered before the allocator mutex is taken.
    auto context = maybeGatherContext(RecordContext::ALL);
    std::lock_guard<std::recursive_mutex> lock(mutex);
    release_cached_blocks(context, mempool_id);
  }

  /** Retrieves size of largest unused block held by the memory cache **/
  void cacheInfo(size_t* largest) {
    std::lock_guard<std::recursive_mutex> lock(mutex);
    if (*largest == 0) {
      // A zero *largest asks for an initial guess: use the free device memory
      // as an optimistic starting value. The total is queried but not used.
      size_t total_bytes = 0;
      C10_CUDA_CHECK(cudaMemGetInfo(largest, &total_bytes));
    }
    // Refine the value with the default pools, then every private pool.
    cache_info_aux(large_blocks, largest);
    cache_info_aux(small_blocks, largest);
    for (const auto& entry : graph_pools) {
      const auto& private_pool = entry.second;
      cache_info_aux(private_pool->large_blocks, largest);
      cache_info_aux(private_pool->small_blocks, largest);
    }
  }

  /** Returns a copy of the memory allocator stats **/
  DeviceStats getStats() const {
    const std::lock_guard<std::recursive_mutex> guard(mutex);
    // Copy while the lock is held; the caller receives its own copy.
    DeviceStats stats_copy = stats;
    return stats_copy;
  }

  /** Resets the historical accumulation stats for the device **/
  void resetAccumulatedStats() {
    std::lock_guard<std::recursive_mutex> lock(mutex);

    // Per-stat-type counters.
    const auto num_stat_types = static_cast<size_t>(StatType::NUM_TYPES);
    for (size_t idx = 0; idx < num_stat_types; ++idx) {
      stats.allocation[idx].reset_accumulated();
      stats.segment[idx].reset_accumulated();
      stats.active[idx].reset_accumulated();
      stats.inactive_split[idx].reset_accumulated();
      stats.allocated_bytes[idx].reset_accumulated();
      stats.reserved_bytes[idx].reset_accumulated();
      stats.active_bytes[idx].reset_accumulated();
      stats.inactive_split_bytes[idx].reset_accumulated();
      stats.requested_bytes[idx].reset_accumulated();
    }

    // Oversize trackers.
    stats.oversize_allocations.reset_accumulated();
    stats.oversize_segments.reset_accumulated();

    // Plain event counters.
    stats.num_alloc_retries = 0;
    stats.num_ooms = 0;
    stats.num_sync_all_streams = 0;
    stats.num_device_alloc = 0;
    stats.num_device_free = 0;
  }

  /** Resets the historical peak stats for the device **/
  void resetPeakStats() {
    std::lock_guard<std::recursive_mutex> lock(mutex);

    // Per-stat-type counters.
    const auto num_stat_types = static_cast<size_t>(StatType::NUM_TYPES);
    for (size_t idx = 0; idx < num_stat_types; ++idx) {
      stats.allocation[idx].reset_peak();
      stats.segment[idx].reset_peak();
      stats.active[idx].reset_peak();
      stats.inactive_split[idx].reset_peak();
      stats.allocated_bytes[idx].reset_peak();
      stats.reserved_bytes[idx].reset_peak();
      stats.active_bytes[idx].reset_peak();
      stats.inactive_split_bytes[idx].reset_peak();
      stats.requested_bytes[idx].reset_peak();
    }

    // Oversize trackers.
    stats.oversize_allocations.reset_peak();
    stats.oversize_segments.reset_peak();
  }

  /* Checkpoint the state of a private pool necessary to return it to its
   * current state. Fails if `id` names a pool that is already freeable or
   * does not name a known private pool. */
  std::unique_ptr<PrivatePoolState> getCheckpointState(MempoolId_t id) {
    auto context = maybeGatherContext(RecordContext::ALL);
    std::lock_guard<std::recursive_mutex> lock(mutex);
    insert_events_deferred_until_no_capture(context);

    const auto pool_it = graph_pools.find(id);
    if (pool_it == graph_pools.end()) {
      // Unknown pool: report the more specific error first.
      if (graph_pools_freeable.count(id)) {
        TORCH_CHECK(false, "Not expected to checkpoint freeable graph");
      }
      TORCH_CHECK(false, "Could not find pool of id");
    }

    auto private_pool_head_blocks =
        get_private_pool_head_blocks(pool_it->second.get());
    return std::make_unique<PrivatePoolState>(id, private_pool_head_blocks);
  }

  /**
   * Frees every block of `private_pool` that is currently allocated and
   * appends the pointer of each freed block to `rr.allocations_freed`.
   * Afterwards re-walks the pool and checks that nothing is still allocated.
   */
  void freeBlocksAllocatedToPool(PrivatePool* private_pool, RestoreResult& rr) {
    // Snapshot the chain heads up front; freeing below changes which blocks
    // sit in the pool's containers.
    std::vector<Block*> chain_heads;
    for (Block* candidate : get_private_pool_head_blocks(private_pool)) {
      if (candidate->prev == nullptr) {
        chain_heads.push_back(candidate);
      }
    }

    for (Block* head : chain_heads) {
      // When we free a block, its pointer should never change, only its
      // adjacent blocks, so free first and read `next` afterwards.
      for (Block* cur = head; cur != nullptr; cur = cur->next) {
        if (!cur->allocated) {
          continue;
        }
        TORCH_CHECK(
            cur->event_count == 0,
            "Events should have synchronized when setting checkpointed block");
        rr.allocations_freed.push_back(cur->ptr);
        free(cur);
        TORCH_CHECK(!cur->allocated);
      }
    }

    // Verification pass over the pool's current head blocks.
    for (Block* head : get_private_pool_head_blocks(private_pool)) {
      for (Block* cur = head; cur != nullptr; cur = cur->next) {
        TORCH_CHECK(!cur->allocated);
      }
    }
  }

  // Restore one segment (an allocation that may have been split into multiple
  // blocks) to the layout recorded in `segment`.
  //
  // `block` is the first block of the segment as it currently exists in the
  // pool, `segment.blocks` is the checkpointed block list, `context` is passed
  // to alloc_found_block, and `rr.allocations_created` receives every block
  // that the checkpoint records as allocated.
  //
  // Two passes are made over the checkpointed blocks:
  //   1. allocate a block for each entry (splitting as needed) so the block
  //      boundaries match the checkpoint;
  //   2. free the entries that the checkpoint records as not allocated.
  void setSegmentStateToCheckpoint(
      Block* block,
      SegmentState& segment,
      const std::shared_ptr<GatheredContext>& context,
      RestoreResult& rr) {
    Block* curr_block = block;
    Block* last_block = block;

    TORCH_INTERNAL_ASSERT(block->pool);
    BlockPool& pool = *block->pool;
    const auto segment_len = segment.blocks.size();

    // True for a trailing block that is neither mapped nor allocated.
    auto is_unmapped_tail = [](Block* b) {
      return b->next == nullptr && !b->mapped && !b->allocated;
    };

    // allocate all blocks in the segment
    for (size_t i = 0; i < segment_len; ++i) {
      // Note [Last block when restoring checkpoint state]
      // The last block in expandable segment is one of the following cases:
      // - case 1: a unmapped tail containing the remaining amount of available
      //      unmapped virtual address space.
      // - case 2: segment has grown larger since we record the checkpoint.
      //      After we allocate all blocks in checkpoint, there should still
      //      be 1 block for extra mapped memory followed by another block
      //      for the unmapped tail.
      if (i == segment_len - 1 && curr_block->expandable_segment_) {
        bool valid_structure =
            /* case 1*/ is_unmapped_tail(curr_block) ||
            /* case 2*/
            (curr_block->mapped && !curr_block->allocated && curr_block->next &&
             is_unmapped_tail(curr_block->next));
        TORCH_CHECK(
            valid_structure,
            "Invalid expandable segment structure during checkpoint restore");
        continue;
      }

      auto& block_state = segment.blocks.at(i);
      AllocParams params(
          block_state.device,
          block_state.size,
          block_state.stream,
          &pool,
          block_state.size,
          curr_block->expandable_segment_ != nullptr);
      // Take the block out of the pool's free set before handing it to
      // alloc_found_block.
      pool.blocks.erase(curr_block);
      params.block = curr_block;
      params.stat_types = get_stat_types_for_pool(pool);

      // splitting a block depends on `max_split_size`, which may have changed
      // between when checkpoint was taken and now, so we make sure to recreate
      // the behavior from the checkpoint. Keep splitting as long as there is
      // space left in the block because the block is already the size of how it
      // appears in the segment, so any leftover space belongs to the next
      // block.
      bool split = curr_block->size > block_state.size;

      // curr_block will become next pointer if it is split, so reassign with
      // the returned value
      curr_block = alloc_found_block(params, block_state.size, context, split);

      TORCH_CHECK(curr_block->ptr == block_state.ptr);
      TORCH_CHECK(curr_block->size == block_state.size);

      last_block = curr_block;
      curr_block = curr_block->next;

      // A next block must exist exactly when more checkpoint entries remain.
      TORCH_CHECK((curr_block != nullptr) == ((i + 1) < (segment_len)));
    }

    // Rewind to the first block of the segment for the second pass.
    while (last_block->prev) {
      last_block = last_block->prev;
    }

    // free blocks that are not allocated in the checkpoint
    curr_block = last_block;

    for (size_t i = 0; i < segment_len; ++i, curr_block = curr_block->next) {
      if (i == segment_len - 1 && curr_block->expandable_segment_) {
        // See Note [Last block when restoring checkpoint state]
        bool valid_structure =
            /* case 1*/ is_unmapped_tail(curr_block) ||
            /* case 2*/
            (curr_block->mapped && !curr_block->allocated && curr_block->next &&
             is_unmapped_tail(curr_block->next));
        TORCH_CHECK(
            valid_structure,
            "Invalid expandable segment structure during checkpoint restore");
        continue;
      }

      auto& block_state = segment.blocks.at(i);
      TORCH_INTERNAL_ASSERT(curr_block != nullptr);

      // Blocks that are live in the checkpoint stay allocated and are reported
      // back to the caller.
      if (block_state.allocated) {
        rr.allocations_created.push_back(curr_block);
        continue;
      }

      free(curr_block);

      TORCH_CHECK(curr_block->ptr == block_state.ptr);
      TORCH_CHECK(curr_block->allocated == block_state.allocated);

      // See Note [Last block when restoring checkpoint state]
      // The last mapped block (i == segment_len - 2) may merge with extra
      // mapped memory from segment growth, making it larger than checkpoint
      bool is_last_mapped =
          (i == segment_len - 2) && curr_block->expandable_segment_;
      TORCH_CHECK(
          curr_block->size == block_state.size ||
          (is_last_mapped && curr_block->size >= block_state.size));
    }
  }

  /**
   * Note [Checkpointing PrivatePoolState]
   *
   * Refer above to Note [Interaction with CUDA graph capture]. Allocations made
   * during graph capture are made from a separate private pool. During graph
   * capture allocations behave as usual. During graph replay the allocator
   * state does not change even as new tensors are created. The private pool
   * will not free its blocks to the main caching allocator until cuda graph use
   * is finished to prevent an allocation from eager clobbering the memory from
   * a live but unaccounted for tensor that was created during replay.
   *
   * `make_graphed_callables`, a series of separate callables chained in
   * successive cuda graphs, can share a memory pool because after a cuda graph
   * recording the allocations in the shared private pool exactly reflect the
   * tensors that are allocated.
   *
   * We would like to extend callable chaining to support a graphed callable
   * tree. In this scenario, we have a tree of callable chains which will be
   * captured with cuda graphs. In the diagram below, we have a tree with four
   * callables, A, B, C, and D. Suppose we have captured, and subsequently
   * replayed, A, B, and C. Then on a new invocation, we replay A and B, but
   * would now like to record D. At this point the private pool will not reflect
   * any of the live tensors created during graph replay. Allocations made
   * during a new recording with the pool could overwrite those live tensors.
   *
   * In order to record a new graph capture after replaying prior callables in
   * the tree, we need the allocator to reflect the state of the live tensors.
   * We checkpoint the state of the private pool after each recording, and then
   * reapply it when we are starting a new recording chain. Additionally, we
   * must free the allocations for any tensors that died between the end of our
   * previous graph replaying and our new recording. All of the allocated
   * segments that existed in the checkpointed state must still exist in the
   * pool. There may also exist new allocated blocks.
   * (TODO : link note [live tensors between iterations] when it exists). For
   * every block that is currently allocated but not allocated in the snapshot,
   * we will return a pointer to that block.
   *
   *
   *
   *  ---------------> A ---------------> B ---------------> C
   *                                      |
   *                                      |
   *                                      |
   *                                      |
   *                                      ╰ ---------------> D
   */
  RestoreResult setCheckpointPoolState(PrivatePoolState& pps) {
    // Resetting the caching allocator state takes three steps:
    // - free all the blocks currently allocated to the pool (see [live
    //   tensors between iterations])
    // - allocate all the blocks in each checkpointed segment, whether they
    //   are live or not
    // - free the blocks in each checkpointed segment which are not live
    // This could be optimized, but it nicely reuses existing APIs, and this
    // is not on the hot path.

    // The context is gathered outside the lock because we don't know what
    // locks the recorder needs to have.
    std::shared_ptr<GatheredContext> context =
        maybeGatherContext(RecordContext::STATE);

    std::lock_guard<std::recursive_mutex> lock(mutex);

    RestoreResult rr;

    TORCH_CHECK(
        !graph_pools_freeable.count(pps.owner_id),
        "Not expected to checkpoint freeable graph");

    const auto pool_it = graph_pools.find(pps.owner_id);
    TORCH_CHECK(pool_it != graph_pools.end(), "Could not find private pool id");

    PrivatePool* private_pool = pool_it->second.get();

    freeBlocksAllocatedToPool(private_pool, rr);

    // At this point all of the blocks should be free, so they will all be in
    // the pool's block sets. Index them by address.
    std::unordered_map<void*, Block*> ptrs_to_blocks;
    const auto index_free_blocks = [&ptrs_to_blocks](const BlockPool& bp) {
      for (Block* free_block_ptr : bp.blocks) {
        ptrs_to_blocks[free_block_ptr->ptr] = free_block_ptr;
      }
    };
    index_free_blocks(private_pool->small_blocks);
    index_free_blocks(private_pool->large_blocks);

    for (auto& segment : pps.segments) {
      // Each checkpointed segment is located through its first block's address.
      auto ptr = segment.blocks.at(0).ptr;
      const auto found = ptrs_to_blocks.find(ptr);
      TORCH_CHECK(found != ptrs_to_blocks.end(), " could not find ", ptr);

      setSegmentStateToCheckpoint(found->second, segment, context, rr);
    }
    return rr;
  }

  /** Dump a complete snapshot of the memory held by the allocator. Potentially
   * VERY expensive.
   *
   * A non-zero `mempool_id` restricts the snapshot to that private pool; the
   * zero id reports every block known to the allocator. The result is sorted
   * by segment address, and a SNAPSHOT trace entry carrying the total active
   * size is recorded before returning. **/
  std::vector<SegmentInfo> snapshot(MempoolId_t mempool_id) {
    std::lock_guard<std::recursive_mutex> lock(mutex);

    std::vector<Block*> all_blocks;

    if (mempool_id.first != 0 || mempool_id.second != 0) {
      // A specific (non-zero) mempool_id was requested: find the corresponding
      // PrivatePool in graph_pools and only return the blocks from it. If no
      // such pool exists, the snapshot is empty.
      auto pool = graph_pools.find(mempool_id);
      if (pool != graph_pools.end()) {
        all_blocks = get_private_pool_head_blocks(pool->second.get());
      }
    } else {
      // When snapshot is called with the default (zero) mempool_id, we return
      // all the blocks in the CUDACachingAllocator (as returned by
      // get_all_blocks).
      all_blocks = get_all_blocks();
    }

    size_t total_active = 0;
    std::vector<SegmentInfo> result;

    for (const Block* const head_block : all_blocks) {
      // For expandable segments, we report one segment for each contiguous
      // mapped range of memory
      // A block preceded by a mapped block is not the start of such a range.
      if (head_block->prev && head_block->prev->mapped) {
        continue;
      }
      result.emplace_back();
      SegmentInfo& segment_info = result.back();
      segment_info.device = head_block->device;
      segment_info.address = reinterpret_cast<size_t>(head_block->ptr);
      segment_info.stream = reinterpret_cast<void*>(head_block->stream);
      segment_info.is_large = (!head_block->pool->is_small);
      segment_info.is_expandable = head_block->expandable_segment_;
      segment_info.context_when_allocated =
          head_block->context_when_segment_allocated;
      MempoolId_t id = head_block->pool->owner_MempoolId();
      // The owning pool id is only filled in when snapshotting everything or
      // when the block belongs to the requested pool.
      if ((mempool_id.first == 0 && mempool_id.second == 0) ||
          id == mempool_id) {
        segment_info.owner_private_pool_id = id;
      }
      segment_info.registration_counter = head_block->registration_counter;

      // Walk the chain until it ends or an unmapped block is reached.
      const Block* block = head_block;
      while (block != nullptr && block->mapped) {
        segment_info.blocks.emplace_back();
        BlockInfo& block_info = segment_info.blocks.back();

        block_info.size = block->size;
        block_info.requested_size = block->requested_size;
        block_info.allocated = block->allocated;
        // A block counts as active while it is allocated, has outstanding
        // events, or has recorded stream uses.
        block_info.active = block->allocated || (block->event_count > 0) ||
            !block->stream_uses.empty();

        segment_info.total_size += block_info.size;
        if (block_info.allocated) {
          segment_info.allocated_size += block_info.size;
        }
        if (block_info.active) {
          segment_info.active_size += block_info.size;
          segment_info.requested_size += block_info.requested_size;
        }
        block_info.context_when_allocated = block->context_when_allocated;
        block = block->next;
      }
      total_active += segment_info.active_size;
    }

    std::sort(
        result.begin(),
        result.end(),
        [](const SegmentInfo& a, const SegmentInfo& b) {
          return a.address < b.address;
        });

    record_trace(
        TraceEntry::SNAPSHOT, 0, total_active, nullptr, 0, mempool_id, nullptr);
    return result;
  }

  std::vector<TraceEntry> trace(
      const std::function<time_t(approx_time_t)>& tsc_to_us) const {
    std::lock_guard<std::recursive_mutex> lock(mutex);
    std::vector<TraceEntry> result;
    alloc_buffer.getEntries(result);

    // Convert all the timestamps from tsc to epoch time in microseconds.
    for (auto& te : result) {
      te.time_.t_ = tsc_to_us(te.time_.approx_t_);
    }
    return result;
  }

  // Rounds `size` up to the next boundary obtained by cutting the interval
  // between the enclosing powers of two into `divisions` equal steps.
  // Example: for size 1200 and 4 divisions, the enclosing powers of two are
  // 1024 and 2048; the step boundaries are 1024, 1280, 1536 and 1792, so the
  // result is 1280.
  // A `size` that is already a power of two is returned unchanged (before
  // `divisions` is validated).
  static size_t roundup_power2_next_division(size_t size, size_t divisions) {
    if (llvm::isPowerOf2_64(size)) {
      return size;
    }

    TORCH_CHECK(divisions >= 2, "Only 2 or more divisions are supported");

    // Step width = floor power of two shifted right by floor(log2(divisions)).
    const size_t floor_pow2 = llvm::PowerOf2Floor(size);
    const auto log2_divisions = 63 - llvm::countLeadingZeros(divisions);
    const size_t step = floor_pow2 >> log2_divisions;
    if (C10_UNLIKELY(step == 0)) {
      // The step collapsed to zero: fall back to the power-of-2 ceiling.
      return floor_pow2 << 1;
    }

    // Align down to a step boundary, then bump by one step unless `size`
    // already sits on a boundary.
    const size_t aligned_down = size & ~(step - 1);
    if (aligned_down == size) {
      return size;
    }
    return aligned_down + step;
  }

  // Rounds a requested size to the granularity used for blocks:
  //   - anything below kMinBlockSize becomes kMinBlockSize;
  //   - when power-of-2 divisions are configured for this size and the size
  //     exceeds kMinBlockSize * divisions, the division-based rounding is used;
  //   - otherwise the size is rounded up to a multiple of kMinBlockSize.
  static size_t round_size(size_t size) {
    if (size < kMinBlockSize) {
      return kMinBlockSize;
    }

    const auto divisions =
        AcceleratorAllocatorConfig::roundup_power2_divisions(size);
    const bool use_divisions =
        divisions > 1 && size > (kMinBlockSize * divisions);
    if (use_divisions) {
      return roundup_power2_next_division(size, divisions);
    }

    const size_t num_min_blocks = (size + kMinBlockSize - 1) / kMinBlockSize;
    return kMinBlockSize * num_min_blocks;
  }

  // Locked entry point for create_or_incref_pool: creates the PrivatePool for
  // `mempool_id` if it does not exist yet, otherwise increments its use_count.
  void createOrIncrefPool(
      MempoolId_t mempool_id,
      std::shared_ptr<CUDAAllocator> allocator) {
    const std::lock_guard<std::recursive_mutex> guard(mutex);
    create_or_incref_pool(mempool_id, std::move(allocator));
  }

  // Adds `mempool_id` to, or removes it from, the set of pools that may be
  // used as a last resort before ooming.
  void setUseOnOOM(MempoolId_t mempool_id, bool use_on_oom) {
    std::lock_guard<std::recursive_mutex> lock(mutex);
    if (!use_on_oom) {
      use_on_oom_pools.erase(mempool_id);
      return;
    }
    use_on_oom_pools.insert(mempool_id);
  }

  // Marks `mempool_id` as a pool that should not split a segment. Nothing in
  // this method removes the mark again.
  void setNoSplit(MempoolId_t mempool_id) {
    const std::lock_guard<std::recursive_mutex> guard(mutex);
    no_split_pools.insert(mempool_id);
  }

  // See Note [Interaction with CUDA graph capture]

  // Called by CUDAGraph::capture_begin
  //
  // Starts diverting allocations to the private pool `mempool_id`: the pool is
  // created (or its use_count incremented) and the (pool, filter) pair is
  // appended to captures_underway. `filter` is stored alongside the pool id.
  //
  // Fails if `mempool_id` is already being recorded to. That check runs before
  // any state is modified, so a failed call does not leave behind a use_count
  // increment that no matching releasePool would ever undo.
  void beginAllocateToPool(
      MempoolId_t mempool_id,
      std::function<bool(cudaStream_t)> filter) {
    std::lock_guard<std::recursive_mutex> lock(mutex);
    for (const auto& capture : captures_underway) {
      TORCH_CHECK(
          capture.first != mempool_id,
          "beginAllocateToPool: already recording to mempool_id");
    }
    create_or_incref_pool(mempool_id);
    captures_underway.emplace_back(mempool_id, std::move(filter));
  }

  // Called by CUDAGraph::capture_end
  //
  // Stops diverting allocations to the private pool `mempool_id` by removing
  // its entry from captures_underway. Fails if no such entry exists.
  //
  // When graph_capture_record_stream_reuse is enabled and reuse state exists,
  // the reuse context associated with this pool's capture is validated (none
  // of its visited streams may still be capturing) and then dropped.
  void endAllocateToPool(MempoolId_t mempool_id) {
    std::lock_guard<std::recursive_mutex> lock(mutex);

    if (CUDAAllocatorConfig::graph_capture_record_stream_reuse() &&
        !graph_reuse_context.empty()) {
      auto capture_id = mempool_to_capture_id[mempool_id];
      // Bind by reference: a by-value copy would duplicate every per-stream
      // visited set only to iterate over it once.
      const auto& graph_context = graph_reuse_context[capture_id];
      for (const auto& [stream, _] : graph_context.visited) {
        TORCH_INTERNAL_ASSERT(
            stream_get_capture_info(stream).status ==
                cudaStreamCaptureStatusNone,
            "This stream should not be capturing when the capture is ended");
      }
      // `graph_context` is not used past this point, so erasing is safe.
      graph_reuse_context.erase(capture_id);
      mempool_to_capture_id.erase(mempool_id);
    }

    const auto it = std::find_if(
        captures_underway.begin(),
        captures_underway.end(),
        [&mempool_id](const auto& capture) {
          return capture.first == mempool_id;
        });
    TORCH_CHECK(
        it != captures_underway.end(),
        "endAllocatePool: not currently recording to mempool_id");
    captures_underway.erase(it);
  }

  // Called by CUDAGraph::reset and MemPool::~MemPool()
  //
  // Drops one use of the private pool `mempool_id`. Once the use count hits
  // zero, the pool is registered in graph_pools_freeable.
  void releasePool(MempoolId_t mempool_id) {
    std::lock_guard<std::recursive_mutex> lock(mutex);
    // The instantiated cudaGraphExec_t has been destroyed. We can't blindly
    // delete and cudaFree the mempool its capture used, because
    //  1. other graph(s) might share the same pool
    //  2. the user might still hold references to output tensors allocated
    //  during capture.
    // To handle 1 and 2, we track the number of graphs using this particular
    // mempool. When the count reaches 0, we tell free_cached_blocks it may now
    // cudaFree blocks from this graph's pool when it discovers they're unused
    // (unsplit).
    PrivatePool* const pool = get_private_pool(mempool_id);
    const auto remaining_users = --(pool->use_count);
    TORCH_INTERNAL_ASSERT(remaining_users >= 0);
    if (remaining_users != 0) {
      // Still referenced elsewhere; nothing more to do.
      return;
    }

    // Allows free_cached_blocks to begin cudaFreeing this pool's memory,
    // and makes sure this pool wasn't somehow made freeable already.
    // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
    bool inserted = graph_pools_freeable.insert({mempool_id, pool}).second;
    TORCH_INTERNAL_ASSERT(inserted);
  }

  // Returns the current use_count of the private pool `mempool_id`. The
  // lookup asserts that the pool exists.
  int getPoolUseCount(MempoolId_t mempool_id) const {
    const std::lock_guard<std::recursive_mutex> guard(mutex);
    return get_private_pool(mempool_id)->use_count;
  }

  // Registers `dev_to_access` as a peer device and forwards it to every
  // existing expandable segment. A device that is already registered is
  // ignored.
  void addPeerAccess(c10::DeviceIndex dev_to_access) {
    std::lock_guard<std::recursive_mutex> lock(mutex);
    const auto known_end = devices_with_peer_access_.end();
    const bool already_registered =
        std::find(devices_with_peer_access_.begin(), known_end, dev_to_access) !=
        known_end;
    if (already_registered) {
      return;
    }
    devices_with_peer_access_.push_back(dev_to_access);
    for (auto& segment : expandable_segments_) {
      segment->addPeer(dev_to_access);
    }
  }
  // Returns a copy of the devices registered through addPeerAccess.
  std::vector<c10::DeviceIndex> peers() const {
    const std::lock_guard<std::recursive_mutex> guard(mutex);
    std::vector<c10::DeviceIndex> registered = devices_with_peer_access_;
    return registered;
  }

  // Reports whether at least one expandable segment currently exists.
  // NOTE(review): unlike the neighbouring public methods, this reads
  // expandable_segments_ without taking the allocator mutex — confirm that
  // callers tolerate a racy answer.
  bool hasAllocatedExpandableSegments() const {
    return !expandable_segments_.empty();
  }

 private:
  // All private methods do not acquire the allocator mutex.

  std::vector<Block*> get_all_blocks() const {
    std::vector<Block*> blocks;
    blocks.insert(
        blocks.end(), small_blocks.blocks.begin(), small_blocks.blocks.end());
    blocks.insert(
        blocks.end(), large_blocks.blocks.begin(), large_blocks.blocks.end());
    for (const auto& gp : graph_pools) {
      blocks.insert(
          blocks.end(),
          gp.second->small_blocks.blocks.begin(),
          gp.second->small_blocks.blocks.end());
      blocks.insert(
          blocks.end(),
          gp.second->large_blocks.blocks.begin(),
          gp.second->large_blocks.blocks.end());
    }
    blocks.insert(blocks.end(), active_blocks.begin(), active_blocks.end());
    return blocks;
  }

  // Returns the blocks of `pool` that have no predecessor (prev == nullptr):
  // first the matching active blocks, then the free ones from the pool's
  // small and large block sets.
  std::vector<Block*> get_private_pool_head_blocks(PrivatePool* pool) const {
    std::vector<Block*> heads;

    // Active blocks are shared across pools, so filter by owning pool too.
    for (Block* candidate : active_blocks) {
      const bool owned_by_pool = candidate->pool == &pool->small_blocks ||
          candidate->pool == &pool->large_blocks;
      if (owned_by_pool && candidate->prev == nullptr) {
        heads.push_back(candidate);
      }
    }

    const auto append_heads = [&heads](const auto& free_blocks) {
      for (Block* candidate : free_blocks) {
        if (candidate->prev == nullptr) {
          heads.push_back(candidate);
        }
      }
    };
    append_heads(pool->small_blocks.blocks);
    append_heads(pool->large_blocks.blocks);

    return heads;
  }

  // Ensures a PrivatePool exists for `mempool_id`. A new pool is constructed
  // with `allocator`; for an existing pool the use_count is incremented and
  // `allocator` must be null.
  void create_or_incref_pool(
      MempoolId_t mempool_id,
      std::shared_ptr<CUDAAllocator> allocator = nullptr) {
    const auto existing = graph_pools.find(mempool_id);
    if (existing != graph_pools.end()) {
      // mempool_id references an existing pool, which the current CUDAGraph
      // capture or torch.cuda.use_mem_pool will share. Check this pool is
      // live (at least one other capture already references it), then
      // increment the count to establish the usage.
      TORCH_INTERNAL_ASSERT(existing->second->use_count > 0);
      TORCH_INTERNAL_ASSERT(!allocator);
      existing->second->use_count++;
      return;
    }

    // mempool_id does not reference an existing pool.
    // Make a new pool for CUDAGraph capture or torch.cuda.use_mem_pool
    // usage. use_count is initially 1, which means the pool is
    // being used since somebody called createOrIncrefPool.
    graph_pools.emplace(
        mempool_id,
        std::make_unique<PrivatePool>(mempool_id, std::move(allocator)));
  }

  // Looks up the PrivatePool registered for `mempool_id`; asserts that it
  // exists. The returned pointer is non-owning.
  PrivatePool* get_private_pool(MempoolId_t mempool_id) const {
    const auto found = graph_pools.find(mempool_id);
    TORCH_INTERNAL_ASSERT(found != graph_pools.end());
    PrivatePool* pool = found->second.get();
    return pool;
  }

  // returns the smallest possible address in any segment
  // where there is enough free address space to fit size
  // may be composed of free and unmapped segments
  //
  // If no existing unmapped block for `stream` in `pool` offers enough room,
  // a new ExpandableSegment is created and a single unmapped block covering
  // it is returned. The returned block is therefore never null.
  Block* find_expandable_block(
      c10::DeviceIndex device,
      cudaStream_t stream,
      BlockPool* pool,
      size_t size) {
    // Search key for lower_bound below.
    // NOTE(review): presumably pool->unmapped is ordered so that this key
    // lands on the first block of `stream` — confirm against the comparator.
    Block key(device, stream, 0);

    // A block can contribute address space if it exists, is not allocated,
    // and has no pending events or stream uses.
    auto allocatable = [](Block* b) {
      return b && !b->allocated && b->event_count == 0 &&
          b->stream_uses.empty();
    };
    // Sums consecutive allocatable blocks starting at `b` until `size` bytes
    // are covered or a non-allocatable block (or the end) is hit.
    auto has_available_address_space = [&](Block* b) {
      size_t bytes = 0;
      while (bytes < size && allocatable(b)) {
        bytes += b->size;
        b = b->next;
      }
      return bytes >= size;
    };
    for (auto it = pool->unmapped.lower_bound(&key);
         it != pool->unmapped.end() && (*it)->stream == stream;
         ++it) {
      Block* c = *it;
      // we found the lowest address of an unmapped segment
      // but there might be a free segment we can also use
      // right before it
      if (allocatable(c->prev)) {
        c = c->prev;
      }
      if (has_available_address_space(c)) {
        return c;
      }
    }
    // Nothing suitable exists: reserve a fresh expandable segment whose
    // granularity depends on whether this is the small or the large pool.
    auto segment_size = pool->is_small
        ? kSmallBuffer
        : AcceleratorAllocatorConfig::large_segment_size();
    expandable_segments_.emplace_back(new ExpandableSegment(
        device, stream, segment_size, devices_with_peer_access_));

    // The whole segment starts out as one unmapped block.
    ExpandableSegment* es = expandable_segments_.back();
    Block* candidate = new Block(device, stream, es->size(), pool, es->ptr());
    candidate->mapped = false;
    candidate->expandable_segment_ = es;
    pool->unmapped.insert(candidate);
    return candidate;
  }

  // Maps at least `size` bytes at the start of the unmapped block `to_map`
  // and moves it into the pool's set of free blocks.
  //
  // Returns false when the expandable segment reports an empty mapped range
  // (the mapping failed); in that case `to_map` is left untouched. On success
  // any unmapped remainder is split off into its own block, `to_map` is
  // merged with its neighbours where try_merge_blocks allows it, and the
  // reserved-memory statistics and trace are updated. `ctx` is attached to
  // the SEGMENT_MAP trace entry.
  bool map_block(
      Block* to_map,
      size_t size,
      const std::shared_ptr<GatheredContext>& ctx) {
    TORCH_INTERNAL_ASSERT(!to_map->mapped && size <= to_map->size);
    TORCH_INTERNAL_ASSERT(
        !to_map->context_when_allocated); // unmapped blocks should not keep
                                          // history
    auto mapped_range =
        to_map->expandable_segment_->map(SegmentRange{to_map->ptr, size});
    // failed to map the memory
    if (mapped_range.size == 0) {
      return false;
    }
    // The mapped range must start at the block and cover the request; it may
    // be larger than `size`.
    TORCH_INTERNAL_ASSERT(
        mapped_range.ptr == to_map->ptr && mapped_range.size >= size);

    BlockPool& pool = *to_map->pool;
    pool.unmapped.erase(to_map);
    to_map->mapped = true;

    if (mapped_range.size < to_map->size) {
      // Only a prefix was mapped: carve the still-unmapped tail into a new
      // block that follows to_map.
      // to_map -> remaining -> to_map->next(?)
      Block* remaining = new Block(
          to_map->device,
          to_map->stream,
          to_map->size - mapped_range.size,
          &pool,
          static_cast<char*>(to_map->ptr) + mapped_range.size);
      remaining->mapped = false;
      remaining->expandable_segment_ = to_map->expandable_segment_;
      remaining->splice(to_map, to_map->next);
      pool.unmapped.insert(remaining);
      to_map->size = mapped_range.size;
    }

    // Merge with the neighbouring blocks where possible.
    try_merge_blocks(to_map, to_map->prev, pool);
    try_merge_blocks(to_map, to_map->next, pool);

    pool.insert_into_blocks(to_map);

    // update statistics
    total_allocated_memory += mapped_range.size;
    StatTypes stat_types = get_stat_types_for_pool(*to_map->pool);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
      stats.reserved_bytes[stat_type].increase(mapped_range.size);
    });
    auto reserved_bytes_gauge =
        STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
    reserved_bytes_gauge.record(
        stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
            .current);

    stats.num_device_alloc++;
    record_trace(
        TraceEntry::SEGMENT_MAP,
        int64_t(mapped_range.ptr),
        mapped_range.size,
        to_map->stream,
        to_map->device,
        to_map->pool->owner_MempoolId(),
        ctx);
    // Remember the context of the first mapping on the segment's head block.
    if (!to_map->prev && !to_map->context_when_segment_allocated) {
      to_map->context_when_segment_allocated = ctx;
    }

    return true;
  }

  /**
   * Produces a mapped block of at least `size` bytes from an expandable
   * segment, mapping additional memory as needed.
   *
   * @return the block, already removed from pool->blocks, or nullptr if any
   *         map_block call fails.
   */
  Block* try_allocate_expandable_block(
      c10::DeviceIndex device,
      cudaStream_t stream,
      BlockPool* pool,
      size_t size,
      const std::shared_ptr<GatheredContext>& ctx) {
    // The returned head starts a chain of free/unmapped blocks with at least
    // `size` bytes of room, in one of these shapes:
    //   unmapped -> null
    //   unmapped -> free -> *
    //   free -> unmapped -> *
    Block* head = find_expandable_block(device, stream, pool, size);

    if (!head->mapped) {
      const size_t first_chunk = std::min(head->size, size);
      if (!map_block(head, first_chunk, ctx)) {
        return nullptr;
      }
    }
    TORCH_INTERNAL_ASSERT(head->mapped);

    // invariant while too small: free -> unmapped -> *
    // Each round maps part of the unmapped successor; map_block merges the
    // free head into it, so the successor becomes the new head.
    while (head->size < size) {
      Block* successor = head->next;
      const size_t still_needed = size - head->size;
      const size_t chunk = std::min(still_needed, successor->size);
      if (!map_block(successor, chunk, ctx)) {
        return nullptr;
      }
      head = successor;
    }

    pool->blocks.erase(head);
    return head;
  }

  /**
   * Moves a block into a pool of cached free blocks.
   *
   * Records a FREE_COMPLETED trace entry, merges the block with its
   * neighbours where try_merge_blocks allows it, removes it from
   * active_blocks, inserts it into its pool, and updates the active,
   * requested and inactive-split statistics.
   *
   * @param block   Block to free. Asserted to be unallocated, with no
   *                pending events and no recorded stream uses.
   * @param context Context for the trace entry; when null, the block's
   *                context_when_allocated is used instead.
   */
  void free_block(
      Block* block,
      const std::shared_ptr<GatheredContext>& context) {
    TORCH_INTERNAL_ASSERT(
        !block->allocated && block->event_count == 0 &&
        block->stream_uses.empty());

    record_trace(
        TraceEntry::FREE_COMPLETED,
        int64_t(block->ptr),
        block->requested_size,
        block->stream,
        block->device,
        block->pool->owner_MempoolId(),
        context ? context : block->context_when_allocated);

    block->context_when_allocated = nullptr;
    // Captured before merging: merging grows block->size, but the active
    // statistics must shrink by the size this block had while active.
    size_t original_block_size = block->size;
    size_t requested_size = block->requested_size;

    auto& pool = *block->pool;
    int64_t net_change_inactive_split_blocks = 0;
    int64_t net_change_inactive_split_size = 0;

    // Each successfully merged neighbour disappears as a separate block, so
    // its count and bytes are subtracted from the inactive-split deltas.
    const std::array<Block*, 2> merge_candidates = {block->prev, block->next};
    for (Block* merge_candidate : merge_candidates) {
      const auto subsumed_size = try_merge_blocks(block, merge_candidate, pool);
      if (subsumed_size > 0) {
        net_change_inactive_split_blocks -= 1;
        net_change_inactive_split_size -= static_cast<int64_t>(subsumed_size);
      }
    }

    active_blocks.erase(block);
    // Makes sure the Block* isn't already present in the pool we're freeing it
    // back into.
    // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
    bool inserted = pool.insert_into_blocks(block).second;
    TORCH_INTERNAL_ASSERT(inserted);

    // If the (possibly merged) block is still split, it counts as one
    // inactive split block of its current size.
    if (block->is_split()) {
      net_change_inactive_split_blocks += 1;
      net_change_inactive_split_size += static_cast<int64_t>(block->size);
    }

    StatTypes stat_types = get_stat_types_for_pool(pool);

    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
      // inactive_split tries to capture the idea that blocks
      // cannot be freed when requested, but fully free pages
      // of expandable blocks can always be freed.
      // The logic to track this as statistic is pretty involved,
      // so we simply just exclude expandable segments from
      // inactive_split
      if (!block->expandable_segment_) {
        // The deltas are signed; apply them through increase/decrease with
        // a non-negative magnitude.
        if (net_change_inactive_split_blocks > 0) {
          stats.inactive_split[stat_type].increase(
              static_cast<size_t>(net_change_inactive_split_blocks));
        } else if (net_change_inactive_split_blocks < 0) {
          stats.inactive_split[stat_type].decrease(
              static_cast<size_t>(-net_change_inactive_split_blocks));
        }
        if (net_change_inactive_split_size > 0) {
          stats.inactive_split_bytes[stat_type].increase(
              static_cast<size_t>(net_change_inactive_split_size));
        } else if (net_change_inactive_split_size < 0) {
          stats.inactive_split_bytes[stat_type].decrease(
              static_cast<size_t>(-net_change_inactive_split_size));
        }
      }
      stats.active[stat_type].decrease(1);
      stats.active_bytes[stat_type].decrease(original_block_size);
      stats.requested_bytes[stat_type].decrease(requested_size);
    });
  }

  /**
   * Combine previously split blocks: absorbs `src` into its neighbour `dst`.
   *
   * The merge is refused when `src` is null, allocated, has pending events
   * or recorded stream uses, or when the two blocks differ in mapped state.
   * On success `src` is unlinked, erased from the pool set matching its
   * mapped state, and deleted; the caller must not use `src` afterwards.
   *
   * @param dst  Surviving block; its ptr/size/links are updated in place.
   * @param src  Neighbour of dst (either dst->prev or dst->next), or null.
   * @param pool Pool whose `blocks` / `unmapped` set currently holds src.
   * @return the size of the subsumed block, or 0 on failure.
   */
  size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
    if (!src || src->allocated || src->event_count > 0 ||
        !src->stream_uses.empty() || dst->mapped != src->mapped) {
      return 0;
    }

    AT_ASSERT(dst->is_split() && src->is_split());

    if (dst->prev == src) { // [src dst]
      // dst grows backwards: it takes over src's start address, src's
      // predecessor, and src's segment-allocation context.
      dst->ptr = src->ptr;
      dst->prev = src->prev;
      if (dst->prev) {
        dst->prev->next = dst;
      }
      dst->context_when_segment_allocated =
          std::move(src->context_when_segment_allocated);
    } else { // [dest src]
      // dst grows forwards: only the forward link needs patching.
      dst->next = src->next;
      if (dst->next) {
        dst->next->prev = dst;
      }
    }
    const size_t subsumed_size = src->size;
    dst->size += subsumed_size;
    // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
    auto erased =
        src->mapped ? pool.blocks.erase(src) : pool.unmapped.erase(src);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
    delete src;

    return subsumed_size;
  }

  /**
   * Chooses the block pool that should serve a request of `size` bytes on
   * `stream`: the small or large pool of a capturing private pool whose
   * filter accepts the stream, otherwise this allocator's own small or large
   * pool.
   */
  BlockPool& get_pool(size_t size, cudaStream_t stream) {
    const bool wants_small = size <= kSmallSize;

    // captures_underway is a conservative guess that the current stream may
    // be capturing. It is only non-empty while some thread has begun and not
    // yet ended a capture, so the common case skips the per-entry stream
    // check entirely.
    if (C10_UNLIKELY(!captures_underway.empty())) {
      // Walk newest-to-oldest so the most recently registered capture wins.
      for (auto entry = captures_underway.rbegin();
           entry != captures_underway.rend();
           ++entry) {
        const auto& [pool_id, accepts_stream] = *entry;
        if (!accepts_stream(stream)) {
          continue;
        }
        auto found = graph_pools.find(pool_id);
        TORCH_INTERNAL_ASSERT(found != graph_pools.end());
        return wants_small ? found->second->small_blocks
                           : found->second->large_blocks;
      }
    }

    return wants_small ? small_blocks : large_blocks;
  }

  /**
   * Builds the stat-type selection for a pool: AGGREGATE is always selected,
   * plus SMALL_POOL or LARGE_POOL depending on pool.is_small.
   */
  StatTypes get_stat_types_for_pool(const BlockPool& pool) {
    const StatType pool_kind =
        pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL;

    StatTypes selected = {false};
    selected[static_cast<size_t>(StatType::AGGREGATE)] = true;
    selected[static_cast<size_t>(pool_kind)] = true;
    return selected;
  }

  /**
   * Decides whether `block` should be split when only `size` bytes of it are
   * needed.
   *
   * @param block Candidate block (block->size is expected to be >= size).
   * @param size  Bytes that will actually be used.
   * @param is_expandable_segments_active Whether expandable segments are on
   *        for this request.
   * @return true if the remainder should become its own block.
   */
  bool should_split(
      const Block* block,
      size_t size,
      bool is_expandable_segments_active) {
    // Pools registered in no_split_pools never split a segment.
    const bool pool_forbids_split =
        no_split_pools.find(block->pool->owner_MempoolId()) !=
        no_split_pools.end();
    if (pool_forbids_split) {
      return false;
    }

    const size_t leftover = block->size - size;

    // Small pools and expandable segments split whenever the leftover can
    // form a minimum-size block.
    if (block->pool->is_small || is_expandable_segments_active) {
      return leftover >= kMinBlockSize;
    }

    // Large pools split only requests below max_split_size, and only when
    // the leftover is bigger than a small allocation.
    return (size < AcceleratorAllocatorConfig::max_split_size()) &&
        (leftover > kSmallSize);
  }

  /**
   * Maps a request size to the size of the segment to allocate for it:
   * kSmallBuffer for sizes up to kSmallSize, the configured large segment
   * size for sizes below kMinLargeAlloc, and otherwise the request rounded
   * up to a multiple of kRoundLarge.
   */
  static size_t get_allocation_size(size_t size) {
    if (size <= kSmallSize) {
      return kSmallBuffer;
    }
    if (size < kMinLargeAlloc) {
      return AcceleratorAllocatorConfig::large_segment_size();
    }
    const size_t rounded_chunks = (size + kRoundLarge - 1) / kRoundLarge;
    return kRoundLarge * rounded_chunks;
  }

  /**
   * Looks for a cached free block in p.pool that can serve the request.
   *
   * On success the chosen block is stored in p.block and removed from the
   * pool's free set.
   *
   * @param p Allocation parameters; reads pool, search_key, stream(), size()
   *          and is_expandable_segments_active, writes block.
   * @return true if a block was found and claimed, false otherwise.
   */
  bool get_free_block(AllocParams& p) {
    BlockPool& pool = *p.pool;

    if (C10_UNLIKELY(
            allowed_memory_maximum.has_value() &&
            AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) {
      // Track block reuse interval only when garbage collection is enabled.
      ++pool.get_free_blocks_call_count;
    }
    // First candidate at or after the search key in the pool's ordering; it
    // must belong to the requesting stream.
    auto it = pool.blocks.lower_bound(&p.search_key);
    if (it == pool.blocks.end() || (*it)->stream != p.stream())
      return false;

    if ((*it)->expandable_segment_) {
      if (p.is_expandable_segments_active) {
        // if we are allocated to the part of the block that is expandable
        // for the purposes of "best fit" we consider its size to be the size it
        // can expand to, not the size it currently is. This means that we
        // sometimes have to search for blocks with bigger 'size' before
        // choosing this segment.
        auto expandable_size = [](Block* b) {
          return b->size + (b->next && !b->next->mapped ? b->next->size : 0);
        };
        // Advance while the following same-stream block has a smaller
        // expandable size than the current candidate.
        auto next = it;
        next++;
        while ((*it)->expandable_segment_ && next != pool.blocks.end() &&
               (*next)->stream == p.stream() &&
               expandable_size(*next) < expandable_size(*it)) {
          it = next++;
        }
      } else {
        // Rarely expandable segments has been turned off after we have
        // already allocated some blocks as expandable. For instance,
        // since we cannot share expandable memory via IPC, someone might
        // temporarily disable it. In this case we need to honor this request
        // by only finding non-expandable blocks
        do {
          it++;
        } while (it != pool.blocks.end() && (*it)->expandable_segment_ &&
                 (*it)->stream == p.stream());
        // Ran off the end or into another stream's blocks: nothing usable.
        if (it == pool.blocks.end() || (*it)->stream != p.stream()) {
          return false;
        }
      }
    }

    // Do not return an oversized block for a large request
    if ((p.size() < AcceleratorAllocatorConfig::max_split_size()) &&
        ((*it)->size >= AcceleratorAllocatorConfig::max_split_size()))
      return false;
    // Allow oversized block size to be rounded up but within a limit
    if ((p.size() >= AcceleratorAllocatorConfig::max_split_size()) &&
        ((*it)->size >=
         p.size() + AcceleratorAllocatorConfig::max_non_split_rounding_size()))
      return false;
    // Claim the block: it leaves the free set and is handed to the caller.
    p.block = *it;
    pool.blocks.erase(it);
    return true;
  }

  /**
   * Runs every callback registered in FreeCudaMemoryCallbacksRegistry.
   *
   * @param p Unused by this function.
   * @return true if at least one callback's Execute() returned true.
   */
  bool trigger_free_memory_callbacks(AllocParams& p) {
    bool any_freed = false;
    // Every callback runs, even after one has already reported success.
    for (const auto& callback_name :
         FreeCudaMemoryCallbacksRegistry()->Keys()) {
      const bool freed =
          FreeCudaMemoryCallbacksRegistry()->Create(callback_name)->Execute();
      any_freed = any_freed || freed;
    }
    return any_freed;
  }

  void garbage_collect_cached_blocks(
      const std::shared_ptr<GatheredContext>& context) {
    // Free unused cached blocks to reclaim GPU memory.
    // Unlike release_cached_blocks(), this does not enforce synchronization and
    // therefore should be of less overheads.

    size_t gc_threshold = static_cast<size_t>(
        AcceleratorAllocatorConfig::garbage_collection_threshold() *
        static_cast<double>(allowed_memory_maximum.value()));
    // No need to trigger GC yet
    if (total_allocated_memory <= gc_threshold) {
      return;
    }
    const auto target_size = total_allocated_memory - gc_threshold;
    size_t gc_reclaimed = 0;

    // Calculate the total age of the free-able blocks. We'll use it later to
    // get "avg age" threshold.
    size_t total_age = 0.0;
    int freeable_block_count = 0;
    for (auto& b : large_blocks.blocks) {
      if (!b->is_split()) {
        total_age += b->gc_count();
        ++freeable_block_count;
      }
    }
    // No free-able blocks?
    if (freeable_block_count == 0) {
      return;
    }

    // Repeat GC until we reach reclaim > target size.
    bool block_freed = true;
    while (gc_reclaimed < target_size && block_freed == true &&
           freeable_block_count > 0) {
      // Free blocks exceeding this age threshold first.
      double age_threshold =
          static_cast<double>(total_age) / freeable_block_count;
      // Stop iteration if we can no longer free a block.
      block_freed = false;

      // Free blocks of > avg age. Don't stop upon reaching the target_size,
      // we don't want this GC to be triggered frequently.
      auto it = large_blocks.blocks.begin();
      while (it != large_blocks.blocks.end()) {
        Block* block = *it;
        ++it;
        if (!block->is_split() && !block->expandable_segment_ &&
            static_cast<double>(block->gc_count()) >= age_threshold) {
          block_freed = true;
          gc_reclaimed += block->size;
          total_age -= block->gc_count(); // Decrement the age
          freeable_block_count--; // One less block that can be freed
          release_block(block, context);
        }
      }
    }
  }

  // Obtains a new block for the request in `p`, either by mapping memory in
  // an expandable segment or by calling cudaMallocMaybeCapturing.
  //
  // This function assumes that global lock has been taken while calling into
  // this function. We do cudaMalloc sync call in this function which
  // can be expensive while holding the lock. Hence, we pass-in the lock to the
  // function to temporarily release the lock before cudaMalloc call and acquire
  // it back again after the call so that other threads dont get blocked.
  //
  // Parameters:
  //   p       - allocation parameters; alloc_size is the number of bytes to
  //             allocate. On return p.err holds the resulting CUDA status and
  //             p.block the new block on success.
  //   isRetry - true on a repeated attempt; only affects num_alloc_retries.
  //   ctx     - context recorded with the segment and in the trace entry.
  //   lock    - the caller's lock on the allocator mutex, see above.
  // Returns true if p.block was set to a new block, false on an
  // out-of-memory condition. CUDA errors other than
  // cudaErrorMemoryAllocation are raised through C10_CUDA_CHECK.
  bool alloc_block(
      AllocParams& p,
      bool isRetry,
      const std::shared_ptr<GatheredContext>& ctx,
      std::unique_lock<std::recursive_mutex>& lock) {
    // Defensively checks for preexisting CUDA error state.
    C10_CUDA_CHECK(cudaGetLastError());

    size_t size = p.alloc_size;
    void* ptr = nullptr;

    if (isRetry) {
      stats.num_alloc_retries += 1;
    }
#ifdef FBCODE_CAFFE2
    bool in_fbcode = true;
#else
    bool in_fbcode = false;
#endif

    if (allowed_memory_maximum.has_value() &&
        total_allocated_memory + size > allowed_memory_maximum.value()) {
      // The allocation would exceed the configured memory limit: report it
      // as an out-of-memory condition without calling into CUDA.
      p.err = cudaErrorMemoryAllocation;
      return false;
      // The condition below skips the expandable-segment path for private
      // pools in fbcode builds:
      // Temporarily disable checkpointing & cudagraphs internally
    } else if (
        p.is_expandable_segments_active &&
        !(in_fbcode && p.pool->owner_PrivatePool)) {
      // Expandable-segment path: this branch always returns; statistics and
      // tracing for it are handled by map_block, not by the code below.
      p.block = try_allocate_expandable_block(
          p.device(), p.stream(), p.pool, p.size(), ctx);
      if (p.block) {
        p.err = cudaSuccess;
        if (p.pool->owner_PrivatePool) {
          // The block is for a CUDA graph's PrivatePool.
          p.pool->owner_PrivatePool->cudaMalloc_count++;
        }
      } else {
        p.err = cudaErrorMemoryAllocation;
      }
      return bool(p.block);
    } else {
      if (CUDAAllocatorConfig::release_lock_on_cudamalloc()) {
        // At scope exit, acquire the lock again. This provides safety against
        // any potential exceptions in the cudaMallocMaybeCapturing function.
        auto sg = c10::make_scope_exit([&]() { lock.lock(); });
        lock.unlock();
        p.err = cudaMallocMaybeCapturing(&ptr, size, p);
      } else {
        p.err = cudaMallocMaybeCapturing(&ptr, size, p);
      }
      if (CUDAAllocatorConfig::release_lock_on_cudamalloc()) {
        TORCH_CHECK(
            lock.owns_lock(), "Failed to acquire lock after cudaMalloc");
      }

      if (p.err != cudaSuccess) {
        if (p.err == cudaErrorMemoryAllocation) {
          // If this is the first attempt (!isRetry), we can forgive and clear
          // CUDA's internal error state.
          //
          // If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH
          // will take over to throw a helpful exception. The user can choose
          // to catch the exception, free some stuff in their script, and
          // attempt the allocation again. In this case, we can also forgive and
          // clear CUDA's internal error state.
          (void)cudaGetLastError();
        } else {
          // If the error's unrelated to memory allocation, we should throw
          // immediately.
          C10_CUDA_CHECK(p.err);
        }
        return false;
      }
    }

    // Only the cudaMalloc path reaches this point with a successful
    // allocation in `ptr`.
    if (p.pool->owner_PrivatePool) {
      // The block is for a CUDA graph's PrivatePool.
      p.pool->owner_PrivatePool->cudaMalloc_count++;
    }

    total_allocated_memory += size;
    p.block = new Block(
        p.device(), p.stream(), size, p.pool, static_cast<char*>(ptr));
    for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
      stats.segment[stat_type].increase(1);
      stats.reserved_bytes[stat_type].increase(size);
    });
    if (size >= AcceleratorAllocatorConfig::max_split_size())
      stats.oversize_segments.increase(1);
    auto reserved_bytes_gauge =
        STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
    reserved_bytes_gauge.record(
        stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
            .current);

    // p.block came from new, not cudaMalloc. It should not be nullptr here.
    TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
    stats.num_device_alloc++;
    record_trace(
        TraceEntry::SEGMENT_ALLOC,
        int64_t(p.block->ptr),
        p.block->size,
        p.stream(),
        p.device(),
        p.pool->owner_MempoolId(),
        ctx);
    p.block->context_when_segment_allocated = ctx;
    return true;
  }

  /**
   * Free one or more oversize blocks to the system allocator, but only
   * enough to satisfy the target size.
   *
   * Does nothing when max_split_size is at its maximum value. Otherwise it
   * releases either one cached block that is large enough for the request,
   * or several oversize blocks of the requesting stream, largest first,
   * until their combined size reaches the target. Blocks that belong to an
   * expandable segment are never released here.
   *
   * @param p       Allocation parameters; reads pool, search_key, stream().
   * @param context Context forwarded to release_block for tracing.
   * @return true if enough memory was released, false otherwise (blocks may
   *         still have been released in the false case).
   */
  bool release_available_cached_blocks(
      const AllocParams& p,
      const std::shared_ptr<GatheredContext>& context) {
    if (AcceleratorAllocatorConfig::max_split_size() ==
        std::numeric_limits<size_t>::max())
      return false;
    BlockPool& pool = *p.pool;

    // because of std::unique_ptr, block cannot be trivially copied
    // Use constructor for search key.
    Block key(p.search_key.device, p.search_key.stream, p.search_key.size);
    // Only oversize blocks are candidates, so search from at least
    // max_split_size upwards.
    key.size = (key.size < AcceleratorAllocatorConfig::max_split_size())
        ? AcceleratorAllocatorConfig::max_split_size()
        : key.size;
    auto it = pool.blocks.lower_bound(&key);
    if (it == pool.blocks.end() || (*it)->stream != p.stream() ||
        (*it)->expandable_segment_) {
      // No single block is large enough; free multiple oversize blocks,
      // starting with the largest
      if (it == pool.blocks.begin())
        return false;
      size_t totalReleased = 0;
      --it; // Back up one item.  Now on the largest block for the correct
            // stream
      while ((totalReleased < key.size) &&
             ((*it)->size >= AcceleratorAllocatorConfig::max_split_size()) &&
             ((*it)->stream == p.stream())) {
        // Step `it` backwards before releasing: release_block erases the
        // current element from pool.blocks.
        auto cur = it;
        bool is_first = cur == pool.blocks.begin();
        if (!is_first) {
          --it;
        }
        if (!(*cur)->expandable_segment_) {
          totalReleased += (*cur)->size;
          release_block(*cur, context);
        }
        // There is nothing before the first element to continue with.
        if (is_first) {
          break;
        }
      }
      if (totalReleased < key.size)
        return false;
    } else {
      // A single non-expandable block of this stream is large enough.
      release_block(*it, context);
    }
    return true;
  }

  /**
   * If mempool_id is {0,0} (the default pool) and there are no
   * currently capturing memory pools, free the default pool's blocks
   * and also free the blocks of the freeable private pools.
   *
   * If mempool_id corresponds to a private pool that is freeable,
   * call synchronize_and_free_events() on that private pool and free
   * that pool's blocks; other freeable private pools are skipped.
   *
   * @param context    Context forwarded to the release helpers for tracing.
   * @param mempool_id Pool to release, or {0,0} for the default pool.
   * @return always true.
   */
  bool release_cached_blocks(
      const std::shared_ptr<GatheredContext>& context,
      MempoolId_t mempool_id) {
    if (mempool_id.first == 0 && mempool_id.second == 0 &&
        captures_underway.empty()) {
      // If there is no active mempool, we work on releasing *all* blocks.

      // First ensure that all blocks that can't currently be allocated due to
      // outstanding events are returned to the pool.
      synchronize_and_free_events(context);

      // Free all non-split cached blocks to system allocator
      release_blocks(large_blocks, context);
      release_blocks(small_blocks, context);
    }

    // The loop advances `it` itself: erase() returns the next iterator, and
    // every other path increments explicitly.
    for (auto it = graph_pools_freeable.begin();
         it != graph_pools_freeable.end();) {
      if (mempool_id.first != 0 || mempool_id.second != 0) {
        if (it->first == mempool_id) {
          // If there is an active mempool, we sync only the events
          // associated with the pool
          synchronize_and_free_events(context, it->second);
        } else {
          // otherwise we move on
          ++it;
          continue;
        }
      }
      // See notifyCaptureDestroy for the strategy here.
      TORCH_INTERNAL_ASSERT(it->second->use_count == 0);
      release_blocks(it->second->small_blocks, context);
      release_blocks(it->second->large_blocks, context);
      // The pool is dropped from both maps only once cudaMalloc_count has
      // reached zero.
      if (it->second->cudaMalloc_count == 0) {
        auto erase_count = graph_pools.erase(it->first);
        TORCH_INTERNAL_ASSERT(erase_count == 1);
        it = graph_pools_freeable.erase(it);
      } else {
        ++it;
      }
    }

    return true;
  }

  /**
   * Destroys an expandable segment together with the single unmapped block
   * that spans it. Asserts that the block covers the whole segment and is
   * unmapped, and that the segment is registered in expandable_segments_.
   */
  void release_expandable_segment(Block* block) {
    auto* segment = block->expandable_segment_;
    TORCH_INTERNAL_ASSERT(
        block->size == segment->size(), "block disagrees with segment");
    TORCH_INTERNAL_ASSERT(!block->mapped);

    // Unregister the segment from the allocator's list.
    auto registered = std::find(
        expandable_segments_.begin(), expandable_segments_.end(), segment);
    TORCH_INTERNAL_ASSERT(registered != expandable_segments_.end());
    expandable_segments_.erase(registered);

    // Drop the block from its pool, then free segment and block.
    block->pool->unmapped.erase(block);
    delete segment;
    delete block;
  }

  /**
   * Returns a cached, non-expandable block's memory to its allocator,
   * updates the segment/reserved statistics, removes the block from its
   * pool and deletes it.
   *
   * @param block   Block to release; asserted not to belong to an expandable
   *                segment. Invalid after the call.
   * @param context Context for the SEGMENT_FREE trace entry; when null, the
   *                block's context_when_segment_allocated is used.
   */
  void release_block(
      Block* block,
      const std::shared_ptr<GatheredContext>& context) {
    TORCH_INTERNAL_ASSERT(!block->expandable_segment_);
    stats.num_device_free++;
    record_trace(
        TraceEntry::SEGMENT_FREE,
        int64_t(block->ptr),
        block->size,
        block->stream,
        block->device,
        block->pool->owner_MempoolId(),
        context ? context : block->context_when_segment_allocated);

    auto* owning_pool = block->pool;
    const auto private_owner = owning_pool->owner_PrivatePool;

    // A private pool with its own allocator frees through that allocator;
    // everything else goes back via cudaFree.
    if (private_owner && private_owner->allocator()) {
      private_owner->allocator()->raw_delete(block->ptr);
    } else {
      C10_CUDA_CHECK(cudaFree((void*)block->ptr));
    }
    total_allocated_memory -= block->size;

    if (private_owner) {
      // The cudaFreed block belonged to a CUDA graph's PrivatePool.
      TORCH_INTERNAL_ASSERT(private_owner->cudaMalloc_count > 0);
      private_owner->cudaMalloc_count--;
    }

    const StatTypes selected_stats = get_stat_types_for_pool(*owning_pool);
    for_each_selected_stat_type(selected_stats, [&](size_t stat_type) {
      stats.segment[stat_type].decrease(1);
      stats.reserved_bytes[stat_type].decrease(block->size);
    });
    auto reserved_bytes_gauge =
        STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
    reserved_bytes_gauge.record(
        stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
            .current);

    if (block->size >= AcceleratorAllocatorConfig::max_split_size()) {
      stats.oversize_segments.decrease(1);
    }

    owning_pool->blocks.erase(block);
    delete block;
  }

  /**
   * Unmaps the memory behind a free block of an expandable segment.
   *
   * The segment reports which sub-range it actually unmapped. Any part of
   * the block before or after that range stays as a separate block inserted
   * back into the pool, while `block` itself shrinks to the unmapped range,
   * is marked unmapped, merged with neighbours where try_merge_blocks
   * allows it, and inserted into pool->unmapped. If nothing was unmapped
   * the block is left untouched.
   *
   * @param block   Block to unmap.
   * @param context Context for the SEGMENT_UNMAP trace entry; when null, the
   *                block's context_when_segment_allocated is used.
   */
  void unmap_block(
      Block* block,
      const std::shared_ptr<GatheredContext>& context) {
    auto unmapped = block->expandable_segment_->unmap(
        SegmentRange{block->ptr, block->size});
    if (unmapped.size == 0) {
      return;
    }
    block->pool->blocks.erase(block);

    // Bytes of the block that precede the unmapped range.
    ptrdiff_t before_size = unmapped.ptr - static_cast<char*>(block->ptr);
    if (before_size > 0) {
      // prev? -> before_free -> block
      Block* before_free = new Block(
          block->device, block->stream, before_size, block->pool, block->ptr);
      before_free->expandable_segment_ = block->expandable_segment_;
      before_free->splice(block->prev, block);
      block->pool->insert_into_blocks(before_free);
    }

    // Bytes of the block that follow the unmapped range.
    auto after_size = block->size - (before_size + unmapped.size);
    if (after_size > 0) {
      // block -> after_free -> next?
      Block* after_free = new Block(
          block->device,
          block->stream,
          after_size,
          block->pool,
          unmapped.ptr + unmapped.size);
      after_free->expandable_segment_ = block->expandable_segment_;
      after_free->splice(block, block->next);
      block->pool->insert_into_blocks(after_free);
    }

    // `block` now describes exactly the unmapped range.
    block->ptr = unmapped.ptr;
    block->size = unmapped.size;
    block->mapped = false;

    try_merge_blocks(block, block->prev, *block->pool);
    try_merge_blocks(block, block->next, *block->pool);
    block->pool->unmapped.insert(block);

    // update statistics
    total_allocated_memory -= unmapped.size;
    StatTypes stat_types = get_stat_types_for_pool(*block->pool);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
      stats.reserved_bytes[stat_type].decrease(unmapped.size);
    });
    auto reserved_bytes_gauge =
        STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
    reserved_bytes_gauge.record(
        stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
            .current);

    if (block->pool->owner_PrivatePool) {
      // The unmapped block belonged to a CUDA graph's PrivatePool.
      TORCH_INTERNAL_ASSERT(
          block->pool->owner_PrivatePool->cudaMalloc_count > 0);
      block->pool->owner_PrivatePool->cudaMalloc_count--;
    }

    stats.num_device_free++;
    record_trace(
        TraceEntry::SEGMENT_UNMAP,
        int64_t(unmapped.ptr),
        unmapped.size,
        block->stream,
        block->device,
        block->pool->owner_MempoolId(),
        context ? context : block->context_when_segment_allocated);
  }
  /**
   * Releases every releasable free block of `pool`: unsplit regular blocks
   * are released directly, blocks of expandable segments are unmapped, and a
   * segment is destroyed once its block has no neighbours left.
   */
  void release_blocks(
      BlockPool& pool,
      const std::shared_ptr<GatheredContext>& context) {
    // unmapping will mutate the free pool, so expandable blocks are only
    // collected during the walk and processed afterwards to avoid
    // invalidating the iterator
    std::vector<Block*> pending_unmap;

    for (auto cursor = pool.blocks.begin(); cursor != pool.blocks.end();) {
      // Advance before acting: release_block erases the current element.
      Block* current = *cursor++;
      if (current->expandable_segment_) {
        pending_unmap.push_back(current);
        continue;
      }
      const bool is_whole_segment = !current->prev && !current->next;
      if (is_whole_segment) {
        release_block(current, context);
      }
    }

    for (Block* expandable : pending_unmap) {
      unmap_block(expandable, context);
      const bool is_whole_segment = !expandable->prev && !expandable->next;
      if (is_whole_segment) {
        release_expandable_segment(expandable);
      }
    }
  }

  /** Fetches an event for device `idx` from a process-wide event pool. */
  EventPool::Event create_event_internal(c10::DeviceIndex idx) {
    // The pool is allocated once and never deleted, to avoid shutdown issues.
    static EventPool* const leaked_event_pool = new EventPool();
    return leaked_event_pool->get(idx);
  }

  /**
   * Synchronize on outstanding events and then free associated blocks.
   *
   * Must not be called while a capture is underway (asserted). Deferred
   * end-of-life events are inserted first so they are processed as well.
   *
   * @param context Context forwarded to free_block for tracing.
   * @param pool    When non-null, only events whose block belongs to this
   *                private pool are synchronized; all others stay queued.
   */
  void synchronize_and_free_events(
      const std::shared_ptr<GatheredContext>& context,
      PrivatePool* pool = nullptr) {
    stats.num_sync_all_streams++;

    // This function syncs, so capture should not be underway. Might as well
    // make sure capture-deferred end of life events get processed too.
    TORCH_INTERNAL_ASSERT(captures_underway.empty());
    insert_events_deferred_until_no_capture(context);

    for (auto it = cuda_events.begin(); it != cuda_events.end();) {
      for (auto e = it->second.begin(); e != it->second.end();) {
        Block* block = e->second;

        // If a pool was passed, only synchronize the events
        // that are associated with the pool, otherwise move on
        if (pool && block->pool->owner_PrivatePool != pool) {
          ++e;
          continue;
        }

        // Synchronize while the queue entry still owns the event. If the
        // check throws, the entry stays intact instead of being left behind
        // with a moved-from (null) event that a later pass would
        // dereference.
        C10_CUDA_CHECK(cudaEventSynchronize(*e->first));

        // The wait succeeded: take ownership so the event outlives the
        // erase below and is released at the end of this iteration.
        EventPool::Event event = std::move(e->first);

        block->event_count--;
        if (block->event_count == 0) {
          free_block(block, context);
        }
        // We are done with the event, so erase it from the deque
        e = it->second.erase(e);
      }

      // If the events deque is empty, only then erase the
      // cuda event from the events map
      if (it->second.empty()) {
        it = cuda_events.erase(it);
      } else {
        it++;
      }
    }
  }

  /**
   * Removes from block->stream_uses the streams recorded for this block in
   * block_to_cudagraph_stream_uses (i.e. stream uses added during cudagraph
   * capture), then drops the block's entry from that map. Does nothing if
   * the block has no entry.
   */
  void remove_cudagraph_stream_uses(Block* block) {
    // Look the block up once and reuse the iterator for the membership tests
    // and the final erase, instead of re-hashing the key on every access.
    auto recorded = block_to_cudagraph_stream_uses.find(block);
    if (C10_UNLIKELY(recorded != block_to_cudagraph_stream_uses.end())) {
      const auto& capture_streams = recorded->second;

      // Take the current set out of the block and re-insert only the
      // streams that were not recorded during capture.
      stream_set streams(std::move(block->stream_uses));
      AT_ASSERT(block->stream_uses.empty());
      for (auto& stream : streams) {
        if (capture_streams.find(stream) == capture_streams.end()) {
          block->stream_uses.insert(stream);
        }
      }
      block_to_cudagraph_stream_uses.erase(recorded);
    }
  }

  /**
   * Records one event on every stream listed in block->stream_uses and
   * queues each (event, block) pair in cuda_events. The block's stream_uses
   * set is consumed and its event_count grows by the number of events.
   * The current device is queried first and set again afterwards.
   */
  void insert_events(Block* block) {
    c10::DeviceIndex original_device = 0;
    C10_CUDA_CHECK(c10::cuda::GetDevice(&original_device));

    // Take the pending streams out of the block.
    stream_set pending_streams(std::move(block->stream_uses));
    AT_ASSERT(block->stream_uses.empty());

    for (auto& used_stream : pending_streams) {
      const auto stream_device = used_stream.device_index();
      C10_CUDA_CHECK(c10::cuda::SetDevice(stream_device));

      EventPool::Event marker = create_event_internal(stream_device);
      C10_CUDA_CHECK(cudaEventRecord(*marker, used_stream.stream()));

      ++block->event_count;
      cuda_events[used_stream].emplace_back(std::move(marker), block);
    }

    C10_CUDA_CHECK(c10::cuda::MaybeSetDevice(original_device));
  }

  void insert_events_deferred_until_no_capture(
      const std::shared_ptr<GatheredContext>& context) {
    if (C10_UNLIKELY(!deferred_blocks.empty())) {
      for (auto& [block, inserted_empty_nodes] : deferred_blocks) {
        TORCH_INTERNAL_ASSERT(!block->stream_uses.empty());
        // only streams recorded before cudagraph will be used to insert events
        // since we know all streams recorded during cudagraph must have
        // completed (refer to Section 3.2.8.7.3.1 Cross-stream Dependencies and
        // Events in CUDA Programming Guide).
        remove_cudagraph_stream_uses(block);
        insert_events(block);
        if (block->event_count == 0) {
          free_block(block, context);
        }
      }
      deferred_blocks.clear();
    }
  }

  /**
   * Process outstanding cudaEvents without blocking.
   *
   * Events that are completed are removed from the queue, and the
   * 'event_count' for the corresponding allocation is decremented; a block
   * whose count reaches zero is freed. We maintain a separate list of events
   * per stream to avoid head-of-line delays if one or more streams has
   * long-running operations.
   *
   * @param context Context forwarded to free_block for tracing.
   */
  void process_events(const std::shared_ptr<GatheredContext>& context) {
    insert_events_deferred_until_no_capture(context);

    // Iterate over different streams.
    for (auto it = cuda_events.begin(); it != cuda_events.end();) {
      // Iterate over this stream's (event, block) pairs.
      while (!it->second.empty()) {
        auto& e = it->second.front();
        Block* block = e.second;

        // Query while the queue entry still owns the event. On any exit
        // other than success (not ready, or an error raised below) the entry
        // stays intact, so a later pass never sees a moved-from (null)
        // event.
        cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(*e.first));
        if (err == cudaErrorNotReady) {
          // ignore and clear the error if not ready
          (void)cudaGetLastError();
          // Later events on this stream are not examined in this pass.
          break;
        } else if (err != cudaSuccess) {
          C10_CUDA_CHECK(err);
        }

        // The event completed: take ownership so it outlives pop_front and
        // is released at the end of this iteration.
        EventPool::Event event = std::move(e.first);

        block->event_count--;
        if (block->event_count == 0) {
          free_block(block, context);
        }
        it->second.pop_front();
      }

      if (it->second.empty()) {
        it = cuda_events.erase(it);
      } else {
        it++;
      }
    }
  }

  // Iterates over sizes of all memory blocks for given device in given pool
  void cache_info_aux(const BlockPool& pool, size_t* largest) {
    for (const auto& block : pool.blocks) {
      const auto blocksize = block->size;
      if (blocksize > *largest) {
        *largest = blocksize;
      }
    }
  }

  /**
   * Builds a TraceEntry for an allocator action, passes it to every
   * registered trace tracker and, when history recording is on and the
   * action is not in skip_actions_list, appends it to alloc_buffer.
   * Returns immediately when history is off and no tracker is registered.
   *
   * @param action     Kind of event being recorded.
   * @param addr       Address associated with the event.
   * @param size       Size in bytes associated with the event.
   * @param stream     Stream the event relates to.
   * @param device     Device index the event relates to.
   * @param mempool_id Id of the owning memory pool.
   * @param context    Gathered context; stored in the entry only when
   *                   record_context_ is at least RecordContext::ALLOC.
   */
  void record_trace(
      TraceEntry::Action action,
      size_t addr,
      size_t size,
      cudaStream_t stream,
      c10::DeviceIndex device,
      MempoolId_t mempool_id,
      std::shared_ptr<GatheredContext> context) {
    const bool nobody_listening = !record_history && trace_trackers_.empty();
    if (nobody_listening) {
      return;
    }

    // Label the entry with the innermost compile context, if there is one.
    std::string compile_string;
    if (compile_context.empty()) {
      compile_string = "N/A";
    } else {
      compile_string = compile_context.top();
    }

    TraceEntry te(
        action,
        device,
        addr,
        size,
        reinterpret_cast<void*>(stream),
        mempool_id,
        getApproximateTime(),
        record_context_ >= RecordContext::ALLOC ? std::move(context) : nullptr,
        compile_string,
        user_metadata);

    // Callbacks should not include any Pytorch call
    for (const auto& tracker : trace_trackers_) {
      tracker(te);
    }

    // Actions listed in skip_actions_list are kept out of the history.
    if (record_history && skip_actions_list.count(action) == 0) {
      alloc_buffer.insertEntries(te);
    }
  }
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free