
max_pool_with_indices_batch_rule_helper Class — pytorch Architecture

Architecture documentation for the max_pool_with_indices_batch_rule_helper function template in BatchRulesPooling.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/functorch/BatchRulesPooling.cpp lines 11–37

template <typename Func>
static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
max_pool_with_indices_batch_rule_helper(
  const Tensor& self, std::optional<int64_t> self_bdim,
  IntArrayRef kernel_size, IntArrayRef stride,
  IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, int64_t n, Func pooling_fn) {

  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  TORCH_INTERNAL_ASSERT(logical_rank == n + 1 || logical_rank == n + 2);
  // Tensor[B, logical_rank...] -> just call max_poolnd
  if (logical_rank == n + 1) {
    auto self_ = moveBatchDimToFront(self, self_bdim);
    auto result = pooling_fn(
        self_, kernel_size, stride, padding, dilation, ceil_mode);
    return std::make_tuple(std::move(std::get<0>(result)), 0, std::move(std::get<1>(result)), 0);
  }
  // Tensor[B, N, logical_rank...] -> Tensor[B * N, logical_rank...]
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  auto bdim_size = self.size(self_bdim.value());
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  auto self_ = reshape_dim_into(self_bdim.value(), 0, self);
  auto result = pooling_fn(
      self_, kernel_size, stride, padding, dilation, ceil_mode);
  return std::make_tuple(
      reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0,
      reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0);
}
