batch_norm_batch_rule — pytorch Architecture
Architecture documentation for the batch_norm_batch_rule function template in BatchRulesNorm.cpp from the pytorch codebase. (Its non-type template parameter `Func` — a function pointer — selects the underlying batch-norm kernel; `Func` is a template parameter, not a class.)
Entity Profile
Source Code
aten/src/ATen/functorch/BatchRulesNorm.cpp lines 42–123
// Batch rule (vmap support) for batch_norm-style operators.
//
// `Func` is a non-type template parameter: a pointer to the underlying
// native kernel (instantiated elsewhere in this file — presumably
// native_batch_norm / cudnn_batch_norm / miopen_batch_norm; confirm at the
// instantiation sites) returning (output, save_mean, save_rstd).
//
// Each (Tensor, std::optional<int64_t>) argument pair is a value plus the
// index of its vmap batch dimension (nullopt = not batched at this level).
//
// Strategy: the kernel is always called with dummy (ones/zeros) weight and
// bias, and the real — possibly batched — weight and bias are applied
// manually afterwards so that their batch dims are handled by ordinary
// broadcasting.  When the input or the running stats carry a vmap batch dim
// of size B0, that dim is folded into the channel dim so the kernel sees
// B0*C independent channels.
//
// Returns (output, out_bdim, save_mean, mean_bdim, save_rstd, rstd_bdim).
template<typename F, F Func>
static
std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
batch_norm_batch_rule(
    const Tensor& input, std::optional<int64_t> input_bdim,
    const std::optional<Tensor>& weight_opt, std::optional<int64_t> weight_bdim,
    const std::optional<Tensor>& bias_opt, std::optional<int64_t> bias_bdim,
    const std::optional<Tensor>& running_mean_opt, std::optional<int64_t> running_mean_bdim,
    const std::optional<Tensor>& running_var_opt, std::optional<int64_t> running_var_bdim,
    bool training, double momentum, double eps) {
  // Borrow the optional tensors without copying; an absent optional becomes
  // an undefined Tensor (tested below with .defined()).
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const auto& running_mean = *running_mean_maybe_owned;
  c10::MaybeOwned<Tensor> running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt);
  const auto& running_var = *running_var_maybe_owned;
  // In training mode the kernel updates running_mean/running_var in place.
  // A batched input with unbatched running stats would make every vmapped
  // batch element write into the same stats buffer, so reject that case.
  TORCH_CHECK(!training || (!input_bdim || ((!running_mean.defined() || running_mean_bdim) && (!running_var.defined() || running_var_bdim))),
      "Batch norm got a batched tensor as input while the running_mean or running_var, which will be updated in place, ",
      "were not batched.\nIf you are using a module and do not need eval mode, please set `track_running_stats` to be False.",
      "If you are using a prebuilt module and do not need eval mode, please see the functorch website for resources on ",
      "how to patch your module to work with vmap");
  std::optional<int64_t> bdim_size;  // size of the folded-in vmap dim; set only in the else branch below
  Tensor result0;                    // normalized output; kept channel-first while weight/bias are applied
  Tensor mean;                       // save_mean returned by the kernel
  Tensor rstd;                       // save_rstd returned by the kernel
  if (!input_bdim && !running_mean_bdim && !running_var_bdim) {
    // Nothing the kernel reads or writes per batch element is batched, so it
    // can be called on the unmodified input.  (weight/bias may still be
    // batched; they are applied by broadcasting further below.)
    const auto dummy_weight = at::ones(input.size(1), input.options()); // cudnn and miopen require a weight
    const auto dummy_bias = at::zeros(input.size(1), input.options()); // without this, get "strides() called on undefined Tensor" on cuda
    auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps);
    result0 = std::get<0>(result).transpose(0, 1); // [C, B, *]
    mean = std::move(std::get<1>(result));
    rstd = std::move(std::get<2>(result));
  } else {
    // At least one of input/running_mean/running_var is batched; they must
    // all agree on the vmap batch size B0.
    bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim);
    auto input_ = moveBatchDimToFront(input, input_bdim);
    input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size.value());
    // Fold B0 into the channel dim: the kernel then normalizes B0*C
    // independent channels (see the [(B0, C), ...] shape comments below).
    input_ = reshape_dim_into(0, /*channels dim*/1, input_);
    std::optional<Tensor> running_mean_;
    std::optional<Tensor> running_var_;
    if (running_mean.defined()) {
      running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim);
      running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size.value());
      // .contiguous(): the kernel updates the running stats in place, so
      // they must be real (non-view) memory after the reshape.
      running_mean_ = reshape_dim_into(0, 0, *running_mean_).contiguous();
    }
    if (running_var.defined()) {
      running_var_ = moveBatchDimToFront(running_var, running_var_bdim);
      running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size.value());
      // same in-place-update reasoning as running_mean_ above
      running_var_ = reshape_dim_into(0, 0, *running_var_).contiguous();
    }
    const auto dummy_weight = at::ones(input_.size(1), input_.options()); // cudnn and miopen require a weight
    const auto dummy_bias = at::zeros(input_.size(1), input_.options()); // without this, get "strides() called on undefined Tensor" on cuda
    auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps);
    result0 = std::get<0>(result).transpose(0, 1); // [(B0, C), B, *]
    mean = std::move(std::get<1>(result));
    rstd = std::move(std::get<2>(result));
    // Split the folded vmap dim back out of the channel dim.
    result0 = reshape_dim_outof(0, bdim_size.value(), result0); // [B0, C, B, *]
    mean = reshape_dim_outof(0, bdim_size.value(), mean); // [B0, C]
    rstd = reshape_dim_outof(0, bdim_size.value(), rstd); // [B0, C]
  }
  // mean/rstd carry a leading batch dim iff a vmap dim was folded in above.
  const auto stats_bdim = compute_stat_bdim(bdim_size, mean);
  if (weight.defined()) {
    // Apply the (possibly batched) weight manually: pad with trailing
    // size-1 dims so it broadcasts over the spatial dims of result0.
    const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
    auto weight_ = moveBatchDimToFront(weight, weight_bdim);
    weight_ = padRight(weight_, weight_bdim, input_logical_rank);
    result0 = result0 * weight_;
  }
  if (bias.defined()) {
    // result0 has a leading batch dim iff the kernel path folded one in
    // (bdim_size) or the weight multiply broadcast one in (weight_bdim).
    const auto result_logical_rank = rankWithoutBatchDim(
        result0,
        bdim_size.has_value() || weight_bdim.has_value() ? std::optional<int64_t>(0) : std::optional<int64_t>(std::nullopt));
    auto bias_ = moveBatchDimToFront(bias, bias_bdim);
    bias_ = padRight(bias_, bias_bdim, result_logical_rank);
    result0 = result0 + bias_;
  }
  result0 = result0.transpose(1, 2); // [B0, B, C, *], because some arg must have been batched, the output must be batched
  // Output bdim is 0; mean/rstd share stats_bdim (nullopt when unbatched).
  return std::make_tuple(std::move(result0), 0, std::move(mean), stats_bdim, std::move(rstd), stats_bdim);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free