Home / Function / boxed_all_tensors_have_optional_bdim — pytorch Architecture

boxed_all_tensors_have_optional_bdim Function Template — pytorch Architecture

Architecture documentation for the boxed_all_tensors_have_optional_bdim function template (parameterized by the non-type template parameter feature_rank) in BatchRulesHelper.h from the pytorch codebase. Note: feature_rank is a template parameter, not a class.

Entity Profile

Source Code

aten/src/ATen/functorch/BatchRulesHelper.h lines 291–361

// Boxed (stack-based) vmap batching rule for operators whose tensor inputs
// all share a fixed "feature rank": each tensor's logical rank (its rank
// excluding any batch dim) must be either `feature_rank` (the no-batch-dim
// case) or `feature_rank + 1`. The case is decided from the FIRST tensor
// input and then asserted for every subsequent one.
//
// Template parameters:
//   feature_rank        - expected logical rank of each tensor input in the
//                         no-batch-dim case.
//   contig_tensor_index - if >= 0, the tensor input at this position is made
//                         contiguous before the op runs; the default of -1
//                         disables this (it can never equal a loop index).
template <int64_t feature_rank, int64_t contig_tensor_index=-1>
inline void boxed_all_tensors_have_optional_bdim(
    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();

  // Exclude the batched dispatch key so the inner callBoxed below does not
  // re-enter this batching rule.
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_all_tensors_have_optional_bdim");
  int64_t cur_level = maybe_layer->layerId();

  // Fast path: if no argument is batched at the current vmap level, run the
  // op unchanged.
  const auto arguments = torch::jit::last(stack, num_arguments);
  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
    op.callBoxed(stack);
    return;
  }

  // Stack index of the first argument slot; argument slots are rewritten
  // in place below.
  int64_t args_begin = static_cast<int64_t>(stack->size() - num_arguments);
  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
  SmallVector<int64_t, 5> tensor_pos;
  int64_t batch_size = 0;

  // Gather every tensor argument (unwrapped at cur_level), its position
  // among the arguments, and the common batch size.
  find_and_unpack_tensors(
      stack, static_cast<int64_t>(num_arguments), cur_level,
      &tensor_inputs, &tensor_pos, &batch_size);

  // true  -> logical rank == feature_rank      (batch dim moved to front)
  // false -> logical rank == feature_rank + 1  (batch dim folded into dim 0)
  // Set on the first iteration; dereferenced unconditionally afterwards,
  // which is safe because the fast path above guarantees at least one
  // batched tensor input here.
  std::optional<bool> is_no_batch_dim_case;

  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
    const auto logical_rank = rankWithoutBatchDim(value, bdim);

    if (!is_no_batch_dim_case.has_value()) {
      is_no_batch_dim_case = (logical_rank == feature_rank);
    }
    // Give bdim-less tensors an explicit batch dim of size batch_size so
    // every input agrees on the batch dimension.
    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
    if (!bdim.has_value()) {
      bdim = 0;
    }
    if (*is_no_batch_dim_case) {
      TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
      // Batch dim becomes dim 0, so the op sees rank feature_rank + 1 and
      // presumably treats dim 0 as its usual leading dimension.
      value_ = moveBatchDimToFront(value_, bdim);
      if (tensor_idx == contig_tensor_index) {
        value_ = value_.contiguous();
      }
      (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
      continue;
    }
    TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
    // Fold the batch dim into dim 0 so the op still sees rank
    // feature_rank + 1; it is split back out of the outputs below.
    value_ = reshape_dim_into(*bdim, 0, value_);
    if (tensor_idx == contig_tensor_index) {
      value_ = value_.contiguous();
    }
    (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
  }

  op.callBoxed(stack);

  // The boxed call pops its arguments and pushes num_returns results, so
  // the results now occupy [args_begin, args_begin + num_returns).
  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
    const auto& ret = (*stack)[idx];
    TORCH_INTERNAL_ASSERT(ret.isTensor(),
        "This boxed batching rule does not currently support ops that return non-tensor values");
    if (*is_no_batch_dim_case) {
      // Output's dim 0 is the batch dim; re-wrap as a BatchedTensor.
      (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level);
    } else {
      // Split the folded batch back out of dim 0 before re-wrapping.
      (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
    }
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free