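Func (template parameter) — PyTorch Architecture
Architecture documentation for `Func`, the non-type template parameter of the `grid_sample_batch_rule` function template in BatchRulesModules.cpp from the PyTorch codebase. `Func` is not a class. It is the unbatched grid-sample kernel, a callable value of type `F`, which the batching rule invokes as `Func(input, grid, extra_args...)`.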
Entity Profile
Source Code
aten/src/ATen/functorch/BatchRulesModules.cpp lines 116–141
// Batching rule shared by grid-sample style ops under vmap.
//
// `Func` is the unbatched kernel (a callable value of type `F`), invoked as
// Func(input, grid, extra_args...); `extra_args` are forwarded unchanged.
//
// @param input       tensor to sample from; `input_bdim` is the index of its
//                    vmap dimension, or nullopt if it is not batched.
// @param grid        sampling grid; `grid_bdim` is the index of its vmap
//                    dimension, or nullopt if it is not batched.
// @param extra_args  remaining kernel arguments, passed straight to Func.
// @return the kernel output and the index of the dimension that carries the
//         vmap dimension B in it (nullopt when neither argument is batched).
//
// Strategy: fold B into a dimension the kernel already processes
// independently, call Func once, then split B back out of the output:
//   - only input batched: B merged into input dim 1 -> output bdim 1
//   - only grid batched:  B merged into grid dim 1  -> output bdim 2
//   - both batched:       B merged into dim 0 of both -> output bdim 0
template<typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor, std::optional<int64_t>>
static grid_sample_batch_rule(const Tensor& input, std::optional<int64_t> input_bdim, const Tensor& grid, std::optional<int64_t> grid_bdim, ExtraArgs... extra_args) {
  if (input_bdim && !grid_bdim) {
    // Only the input is batched. Merge B into input dim 1 (presumably the
    // channel dim of an NC... input — TODO confirm), so the output's dim 1
    // holds (B*C) and can be split back into B.
    auto new_input = reshape_dim_into(*input_bdim, 1, input);
    auto out = Func(new_input, grid, std::forward<ExtraArgs>(extra_args)...);
    out = reshape_dim_outof(1, input.sizes()[*input_bdim], out);
    return std::make_tuple(std::move(out), 1);
  }

  if (!input_bdim && grid_bdim) {
    // Only the grid is batched.
    // grid of N(BH)W2 -> NC(BH)W or grid of N(BD)HW3 -> NC(BD)HW
    auto new_grid = reshape_dim_into(*grid_bdim, 1, grid);
    auto out = Func(input, new_grid, std::forward<ExtraArgs>(extra_args)...);
    out = reshape_dim_outof(2, grid.sizes()[*grid_bdim], out);
    return std::make_tuple(std::move(out), 2);
  }

  if (input_bdim && grid_bdim) {
    // Both are batched: merge B into the leading dim of each so the kernel
    // sees B*N independent samples, then split dim 0 of the output.
    auto new_input = reshape_dim_into(*input_bdim, 0, input);
    auto new_grid = reshape_dim_into(*grid_bdim, 0, grid);
    auto out = Func(new_input, new_grid, std::forward<ExtraArgs>(extra_args)...);
    // B must be read through input's own bdim. Indexing input.sizes() with
    // grid_bdim selects an unrelated dimension whenever the bdims differ.
    out = reshape_dim_outof(0, input.sizes()[*input_bdim], out);
    return std::make_tuple(std::move(out), 0);
  }

  // Neither argument is batched: plain kernel call, no vmap dimension.
  return std::make_tuple(Func(input, grid, std::forward<ExtraArgs>(extra_args)...), std::nullopt);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free