Home / Class / Func — PyTorch Architecture

Func — PyTorch Architecture

Architecture documentation for the `Func` template parameter as used by `batch_norm_batch_rule` in BatchRulesNorm.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/functorch/BatchRulesNorm.cpp lines 42–123

// vmap batch rule for batch_norm-style operators.
//
// `Func` is the underlying kernel this template is instantiated with; from the
// call sites below it is expected to have the shape
//   (input, weight, bias, running_mean, running_var, training, momentum, eps)
//     -> (output, save_mean, save_rstd)
// (presumably native_batch_norm or the cudnn/miopen variants — TODO confirm
// against the instantiations elsewhere in this file).
//
// Strategy: the kernel is always invoked with dummy (ones/zeros) weight and
// bias, and the *real* weight/bias are applied afterwards as an explicit
// multiply/add, so that a vmap-batched weight or bias can simply broadcast
// against the (possibly batched) normalized output. When input or the running
// stats carry a vmap dim, that dim is folded into the channel dim before the
// kernel call and split back out afterwards.
template<typename F, F Func>
static
std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
batch_norm_batch_rule(
    const Tensor& input, std::optional<int64_t> input_bdim,
    const std::optional<Tensor>& weight_opt, std::optional<int64_t> weight_bdim,
    const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,
    const std::optional<Tensor>& running_mean_opt, std::optional<int64_t> running_mean_bdim,
    const std::optional<Tensor>& running_var_opt, std::optional<int64_t> running_var_bdim,
    bool training, double momentum, double eps) {
  // Unwrap the optional tensors; each unwrapped Tensor is undefined when the
  // corresponding optional was nullopt (checked via .defined() below).
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const auto& running_mean = *running_mean_maybe_owned;
  c10::MaybeOwned<Tensor> running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt);
  const auto& running_var = *running_var_maybe_owned;
  // In training mode the kernel updates running_mean/running_var in place.
  // A batched input with *unbatched* running stats would make every vmap
  // sample write its own statistics into the one shared buffer, so reject
  // that combination up front.
  TORCH_CHECK(!training || (!input_bdim || ((!running_mean.defined() || running_mean_bdim) && (!running_var.defined() || running_var_bdim))),
      "Batch norm got a batched tensor as input while the running_mean or running_var, which will be updated in place, ",
      "were not batched.\nIf you are using a module and do not need eval mode, please set `track_running_stats` to be False.",
      "If you are using a prebuilt module and do not need eval mode, please see the functorch website for resources on ",
      "how to patch your module to work with vmap");
  std::optional<int64_t> bdim_size;
  Tensor result0;
  Tensor mean;
  Tensor rstd;
  if (!input_bdim && !running_mean_bdim && !running_var_bdim) {
    // Case 1: nothing that feeds the normalization itself is batched (weight
    // and/or bias may still be batched — they are applied after the kernel
    // call). Call the kernel directly on the unmodified input.
    const auto dummy_weight = at::ones(input.size(1), input.options());  // cudnn and miopen require a weight
    const auto dummy_bias = at::zeros(input.size(1), input.options());   // without this, get "strides() called on undefined Tensor" on cuda
    auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps);
    // Move channels to the front so the weight/bias broadcasting below can
    // treat both cases uniformly (channels-first layout).
    result0 = std::get<0>(result).transpose(0, 1);          // [C, B, *]
    mean = std::move(std::get<1>(result));
    rstd = std::move(std::get<2>(result));
  } else {
    // Case 2: input and/or running stats carry a vmap dim B0. Fold B0 into
    // the channel dim (giving B0*C "channels") so the plain kernel computes
    // independent statistics per (vmap sample, channel) pair.
    // bdim_size is taken from whichever of the three tensors is batched.
    bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim);
    auto input_ = moveBatchDimToFront(input, input_bdim);
    input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size.value());
    input_ = reshape_dim_into(0, /*channels dim*/1, input_);

    // Batch the running stats the same way: bring B0 to the front (expanding
    // if a stat was not batched) and flatten to a [B0*C] buffer. contiguous()
    // because the kernel writes these in place during training — presumably it
    // requires a dense buffer; TODO confirm against the kernel's requirements.
    std::optional<Tensor> running_mean_;
    std::optional<Tensor> running_var_;
    if (running_mean.defined()) {
      running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim);
      running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size.value());
      running_mean_ = reshape_dim_into(0, 0, *running_mean_).contiguous();
    }
    if (running_var.defined()) {
      running_var_ = moveBatchDimToFront(running_var, running_var_bdim);
      running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size.value());
      running_var_ = reshape_dim_into(0, 0, *running_var_).contiguous();
    }

    const auto dummy_weight = at::ones(input_.size(1), input_.options());  // cudnn and miopen require a weight
    const auto dummy_bias = at::zeros(input_.size(1), input_.options());   // without this, get "strides() called on undefined Tensor" on cuda
    auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps);
    result0 = std::get<0>(result).transpose(0, 1);                // [(B0, C), B, *]
    mean = std::move(std::get<1>(result));
    rstd = std::move(std::get<2>(result));
    // Split the folded B0*C dim back into separate vmap and channel dims.
    result0 = reshape_dim_outof(0, bdim_size.value(), result0);   // [B0, C, B, *]
    mean = reshape_dim_outof(0, bdim_size.value(), mean);         // [B0, C]
    rstd = reshape_dim_outof(0, bdim_size.value(), rstd);         // [B0, C]
  }

  // Batch dim of the returned mean/rstd: 0 when they were computed per vmap
  // sample above, otherwise unbatched (exact rule lives in compute_stat_bdim).
  const auto stats_bdim = compute_stat_bdim(bdim_size, mean);
  if (weight.defined()) {
    // Apply the real weight. padRight appends singleton dims so that a
    // [C]-shaped (or [B0, C]-shaped, after moveBatchDimToFront) weight
    // broadcasts over the trailing batch/spatial dims of the channels-first
    // result0.
    const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
    auto weight_ = moveBatchDimToFront(weight, weight_bdim);
    weight_ = padRight(weight_, weight_bdim, input_logical_rank);
    result0 = result0 * weight_;
  }
  if (bias.defined()) {
    // Same broadcasting trick for bias. result0's logical rank must discount
    // a leading vmap dim, which exists if the kernel path was batched OR the
    // weight multiply above introduced one via a batched weight.
    const auto result_logical_rank = rankWithoutBatchDim(
        result0,
        bdim_size.has_value() || weight_bdim.has_value() ? std::optional<int64_t>(0) : std::optional<int64_t>(std::nullopt));
    auto bias_ = moveBatchDimToFront(bias, bias_bdim);
    bias_ = padRight(bias_, bias_bdim, result_logical_rank);
    result0 = result0 + bias_;
  }
  result0 = result0.transpose(1, 2);  // [B0, B, C, *], because some arg must have been batched, the output must be batched
  // Output carries its vmap dim at index 0; mean/rstd use stats_bdim.
  return std::make_tuple(std::move(result0), 0, std::move(mean), stats_bdim, std::move(rstd), stats_bdim);
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free