Home / Class / dim_arg_pos Class — pytorch Architecture

dim_arg_pos Class — pytorch Architecture

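Architecture documentation for `dim_arg_pos` in BatchRulesReduceOps.cpp from the PyTorch codebase. Although this page indexes it as a class, `dim_arg_pos` is the first `int` template parameter of the `boxed_reduction_batch_rule` function template shown below.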

Entity Profile

Source Code

aten/src/ATen/functorch/BatchRulesReduceOps.cpp lines 100–235

// Boxed vmap batching rule shared by reduction-style operators.
//
// Template parameters:
//   dim_arg_pos           - schema index of the `dim` argument; the argument may
//                           hold an int, an int list, or None.
//   keepdim_case          - KEEPDIM_CASE_FALSE, KEEPDIM_CASE_TRUE or
//                           KEEPDIM_CASE_VARIABLE. Only consulted for the
//                           scalar-tensor edge case described in the NOTE below.
//   maybe_keepdim_arg_pos - schema index of the `keepdim` bool argument; only read
//                           when keepdim_case == KEEPDIM_CASE_VARIABLE.
//
// If no argument participates in the current vmap level, the op is forwarded
// unchanged. Otherwise argument 0 is unwrapped, its batch dim is moved to the
// front, the logical `dim` value is rewritten into physical dims, the op is
// called, and every returned tensor is re-wrapped as batched at dim 0 of the
// current level. Non-tensor returns trip an internal assert.
//
// NOTE(review): the rule appears to assume that arguments[0] is a tensor batched
// at the current level whenever any argument is (outputs are always re-wrapped
// with bdim 0, and the dim=None list path uses self.dim() - 1) — confirm this
// holds for any newly registered op.
template<
  int dim_arg_pos,
  int keepdim_case,
  // optional cannot be used in a template, otherwise we would use it here.
  int maybe_keepdim_arg_pos
>
static void boxed_reduction_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  // An unsupported keepdim_case would otherwise leave maybe_keepdim empty and only
  // fail at runtime (std::bad_optional_access on the scalar path); reject it at
  // compile time instead.
  static_assert(
      keepdim_case == KEEPDIM_CASE_FALSE ||
          keepdim_case == KEEPDIM_CASE_TRUE ||
          keepdim_case == KEEPDIM_CASE_VARIABLE,
      "boxed_reduction_batch_rule: keepdim_case must be KEEPDIM_CASE_FALSE, "
      "KEEPDIM_CASE_TRUE or KEEPDIM_CASE_VARIABLE");

  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();

  // FuncTorchBatched stays excluded for the rest of this function, so both
  // callBoxed() calls below run with that key excluded.
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_reduction_batch_rule");
  int64_t cur_level = maybe_layer->layerId();

  // Fast path: nothing participates in the current level, so call the op as-is.
  // `guard` above already excludes FuncTorchBatched here.
  auto orig_arguments = torch::jit::last(*stack, num_arguments);
  if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) {
    op.callBoxed(stack);
    return;
  }

  auto arguments = torch::jit::pop(*stack, num_arguments);

  TORCH_INTERNAL_ASSERT(arguments[0].isTensor());
  auto [self, self_bdim] = unwrapTensorAtLevel(arguments[0].toTensor(), cur_level);

  self = moveBatchDimToFront(self, self_bdim);

  // Rank of `self` excluding the batch dim.
  auto logical_dim = rankWithoutBatchDim(self, self_bdim);

  // Translate the `dim` argument into a list of logical dims, remembering whether
  // the op expects it back as a single int (Dim) or as an int list (DimArray).
  std::vector<int64_t> dims;
  ReductionCase reduction_case{};
  if (arguments[dim_arg_pos].isIntList()) {
    reduction_case = ReductionCase::DimArray;
    dims = arguments[dim_arg_pos].toIntList().vec();
    if (dims.empty()) {
      // An empty list is expanded to every logical dim ({0} for a logical scalar).
      auto all_dims = range(0, std::max(static_cast<int64_t>(1), logical_dim));
      dims = std::vector<int64_t>(all_dims.begin(), all_dims.end());
    }
  } else if (arguments[dim_arg_pos].isInt()) {
    reduction_case = ReductionCase::Dim;
    dims = {arguments[dim_arg_pos].toInt()};
  } else if (arguments[dim_arg_pos].isNone())  {
    // dim=None: the schema's declared element type decides which form to pass on.
    auto param_type = schema.arguments()[dim_arg_pos].type()->expect<OptionalType>()->getElementType();
    if (param_type->kind() == IntType::Kind) {
      // Optional int: flatten every dim after the (front) batch dim into one and
      // reduce over logical dim 0.
      reduction_case = ReductionCase::Dim;
      if (self.dim() > 1) {
        self = self.flatten(1);
      }
      dims = {0};
    } else if (param_type->kind() == ListType::Kind) {
      // Optional int list: reduce over every logical dim.
      reduction_case = ReductionCase::DimArray;
      if (logical_dim == 0) {
        dims = {0};
      } else {
        auto all_dims = range(0, self.dim() - 1);
        dims = std::vector<int64_t>(all_dims.begin(), all_dims.end());
      }
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unexpected dtype found at dims");
    }
  } else{
    TORCH_INTERNAL_ASSERT(false, "Unexpected dtype found at dims");
  }

  // Logical dims -> physical dims of `self` (accounting for the batch dim).
  VmapDimVector new_dims;
  new_dims.reserve(dims.size());
  for (auto dim: dims) {
    new_dims.push_back(getPhysicalDim(self, self_bdim.has_value(), dim));
  }
  bool is_scalar_case = logical_dim == 0 && dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0]);
  std::optional<bool> maybe_keepdim;
  if (is_scalar_case) {
    // NOTE: [boxed_reduction_batch_rule scalar tensor handling]
    // Reduction operations in PyTorch have an edge case where they allow
    // dim=0 and dim=-1 if the tensor has shape [].
    //
    // This can come up if we do something like
    // vmap(lambda x: x.sum(0))(torch.tensor([10.])),
    //
    // In order to handle this edge case, we unsqueeze a dimension on the Tensor,
    // run the operation (with dim=1 instead), and then process the output tensor.
    // There are two cases:
    // - keepdim = True
    //     unsqueeze   op      squeeze
    //   [B] -> [B, 1] -> [B, 1] -> [B]
    // - keepdim = False
    //     unsqueeze   op     no need to squeeze
    //   [B] -> [B, 1] -> [B]
    // if keepdim is True, then we need to squeeze the dimension of size 1.

    // Determine the value of keepdim. keepdim_case is a compile-time constant,
    // validated by the static_assert at the top of this function.
    if constexpr (keepdim_case == KEEPDIM_CASE_VARIABLE) {
      TORCH_INTERNAL_ASSERT(maybe_keepdim_arg_pos >= 0);
      maybe_keepdim = arguments[maybe_keepdim_arg_pos].toBool();
    } else {
      maybe_keepdim = (keepdim_case == KEEPDIM_CASE_TRUE);
    }
    self = self.unsqueeze(-1);
    new_dims = {1};
  }

  // Write the rewritten tensor and dims back in the form the schema expects.
  arguments[0] = std::move(self);
  if (reduction_case == ReductionCase::DimArray) {
    arguments[dim_arg_pos] = std::vector<int64_t>(new_dims.begin(), new_dims.end());
  } else if (reduction_case == ReductionCase::Dim) {
    arguments[dim_arg_pos] = new_dims[0];
  }
  // `arguments` is not used after this loop, so move each IValue onto the stack
  // instead of copying it (avoids a refcount bump per tensor argument).
  for (auto& arg : arguments) {
    torch::jit::push(stack, std::move(arg));
  }
  op.callBoxed(stack);

  auto returns = torch::jit::pop(*stack, num_returns);
  for (auto& ret : returns) {
    if (ret.isTensor()) {
      auto res = std::move(ret).toTensor();
      // see NOTE: [boxed_reduction_batch_rule scalar tensor handling]
      if (is_scalar_case && maybe_keepdim.value()) {
        // squeeze(-1) is a no-op if the shape of the dim is not 1.
        // To make it safer, we internal assert here.
        TORCH_INTERNAL_ASSERT(res.size(-1) == 1);
        res = res.squeeze(-1);
      }
      // The batch dim was moved to the front of the input, so outputs carry it at 0.
      torch::jit::push(stack, makeBatched(std::move(res), 0, cur_level));
    } else {
      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
    }
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free