Home / Class / Func Class — pytorch Architecture

Func Class — pytorch Architecture

Architecture documentation for `Func`, the non-type template parameter (the wrapped operator) of the `grid_sample_batch_rule` function template in BatchRulesModules.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/functorch/BatchRulesModules.cpp lines 116–141

// Batching rule that wraps the callable `Func` (invoked as
// `Func(input, grid, extra_args...)`) so it can be applied when `input`
// and/or `grid` carry an extra batch dimension.
//
// Template parameters:
//   F, Func    - type and value of the wrapped callable.
//   ExtraArgs  - types of the trailing arguments passed through to `Func`.
//
// Parameters:
//   input, input_bdim - first tensor and the index of its batch dimension
//                       (empty optional when `input` is not batched).
//   grid, grid_bdim   - second tensor and the index of its batch dimension
//                       (empty optional when `grid` is not batched).
//   extra_args        - forwarded unchanged to `Func`.
//
// Returns the result tensor together with the index of the batch dimension
// in that result (1, 2 or 0 depending on which operands were batched), or
// std::nullopt when neither operand was batched.
//
// NOTE(review): the comments below assume `reshape_dim_into(src, dst, t)`
// folds dimension `src` of `t` into dimension `dst`, and
// `reshape_dim_outof(dim, size, t)` splits an extent of `size` back out of
// dimension `dim` -- both helpers are defined outside this block; confirm.
template<typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor, std::optional<int64_t>>
static grid_sample_batch_rule(const Tensor& input, std::optional<int64_t> input_bdim, const Tensor& grid, std::optional<int64_t> grid_bdim, ExtraArgs... extra_args) {
  // Neither operand is batched: plain call, no batch dimension in the output.
  if (!input_bdim && !grid_bdim) {
    return std::make_tuple(Func(input, grid, std::forward<ExtraArgs>(extra_args)...), std::nullopt);
  }

  if (input_bdim && !grid_bdim) {
    // Only `input` is batched: fold its batch dim into dim 1, run the op,
    // then split the batch back out of dim 1 of the output.
    const auto batch_size = input.sizes()[*input_bdim];
    auto new_input = reshape_dim_into(*input_bdim, 1, input);
    auto out = Func(new_input, grid, std::forward<ExtraArgs>(extra_args)...);
    out = reshape_dim_outof(1, batch_size, out);
    return std::make_tuple(std::move(out), 1);
  }

  if (!input_bdim && grid_bdim) {
    // Only `grid` is batched.
    // grid of N(BH)W2 -> NC(BH)W or grid of N(BD)HW3 -> NC(BD)HW
    const auto batch_size = grid.sizes()[*grid_bdim];
    auto new_grid = reshape_dim_into(*grid_bdim, 1, grid);
    auto out = Func(input, new_grid, std::forward<ExtraArgs>(extra_args)...);
    out = reshape_dim_outof(2, batch_size, out);
    return std::make_tuple(std::move(out), 2);
  }

  // Both operands are batched: fold each batch dim into dim 0, run the op,
  // then split the batch back out of dim 0 of the output.
  // The batch size must be read from `input` at `input`'s own batch index;
  // indexing `input` with `*grid_bdim` picks an unrelated extent whenever
  // the two batch dimensions sit at different positions.
  const auto batch_size = input.sizes()[*input_bdim];
  auto new_input = reshape_dim_into(*input_bdim, 0, input);
  auto new_grid = reshape_dim_into(*grid_bdim, 0, grid);
  auto out = Func(new_input, new_grid, std::forward<ExtraArgs>(extra_args)...);
  out = reshape_dim_outof(0, batch_size, out);
  return std::make_tuple(std::move(out), 0);
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free