
sgd_fused_step_impl Class — pytorch Architecture

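Architecture documentation for sgd_fused_step_impl, a function template in FusedSGDKernel.cpp from the PyTorch codebase. It performs one fused SGD step on the CPU for a single parameter tensor. It divides the tensor into cache-line-sized work units and runs them through at::parallel_for. It then passes each chunk to sgd_math, along with the optimizer hyperparameters (weight decay, momentum, learning rate, dampening, nesterov, maximize), the is_first_step flag, and an optional gradient-scale pointer. The momentum buffer is accessed only when momentum is nonzero.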


Source Code

aten/src/ATen/native/cpu/FusedSGDKernel.cpp lines 184–233

template <typename scalar_t>
void sgd_fused_step_impl(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& momentum_buffer,
    const double weight_decay,
    const double momentum,
    const double lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const float* grad_scale_ptr) {
  using opmath_t = at::opmath_type<scalar_t>;
  scalar_t* param_data = param.data_ptr<scalar_t>();
  scalar_t* grad_data = grad.data_ptr<scalar_t>();
  bool has_momentum_buffer = momentum != 0.0;
  scalar_t* momentum_buffer_data = has_momentum_buffer ? momentum_buffer.data_ptr<scalar_t>() : nullptr;

  constexpr size_t cache_line_size = 64;
  constexpr int64_t cache_line_aligned_task_unit = cache_line_size / sizeof(scalar_t);
  size_t num_units = divup(param.numel(), cache_line_aligned_task_unit);

  auto sgd_fn = [&](int64_t begin, int64_t end) {
        // local pointers
        begin *= cache_line_aligned_task_unit;
        end = std::min(end * cache_line_aligned_task_unit, param.numel());
        scalar_t* param_ptr = param_data + begin;
        scalar_t* grad_ptr = grad_data + begin;
        scalar_t* momentum_buffer_ptr = has_momentum_buffer ? momentum_buffer_data + begin : nullptr;

        const int64_t size = end - begin;
        sgd_math<scalar_t, opmath_t>(
          param_ptr,
          grad_ptr,
          momentum_buffer_ptr,
          weight_decay,
          momentum,
          lr,
          dampening,
          nesterov,
          maximize,
          is_first_step,
          grad_scale_ptr,
          size
        );
      };
  at::parallel_for(
      0, num_units, 0, sgd_fn);
}
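The work split in the listing follows a simple scheme. One work unit is one 64-byte cache line of elements (64 / sizeof(scalar_t)). The tensor is divided into divup(numel, unit) units. Inside sgd_fn, each unit range is converted back to an element range, and the last range is clamped to numel. The sketch below reproduces that index arithmetic with the standard library only. The names divup_sketch, parallel_for_sketch and element_ranges are hypothetical. parallel_for_sketch is a serial stand-in, not the real at::parallel_for, which chooses its own split and runs sub-ranges on worker threads.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

// Ceiling division, the role divup plays in sgd_fused_step_impl.
constexpr int64_t divup_sketch(int64_t x, int64_t y) { return (x + y - 1) / y; }

// Hypothetical serial stand-in for at::parallel_for: cuts [begin, end) into at
// most num_chunks contiguous sub-ranges and calls f on each one in order.
void parallel_for_sketch(int64_t begin, int64_t end, int64_t num_chunks,
                         const std::function<void(int64_t, int64_t)>& f) {
  assert(num_chunks > 0 && "num_chunks must be positive");
  const int64_t step = divup_sketch(end - begin, num_chunks);
  for (int64_t b = begin; b < end; b += step) {
    f(b, std::min(b + step, end));
  }
}

// Collects the element range [begin, end) that each sgd_fn-style call covers
// for a tensor with numel elements of type scalar_t.
template <typename scalar_t>
std::vector<std::pair<int64_t, int64_t>> element_ranges(int64_t numel,
                                                        int64_t num_chunks) {
  constexpr size_t cache_line_size = 64;
  // Elements per 64-byte cache line: 16 for float, 8 for double.
  constexpr int64_t unit = cache_line_size / sizeof(scalar_t);
  const int64_t num_units = divup_sketch(numel, unit);

  std::vector<std::pair<int64_t, int64_t>> ranges;
  parallel_for_sketch(0, num_units, num_chunks, [&](int64_t begin, int64_t end) {
    // Same arithmetic as sgd_fn: scale unit indices to element indices and
    // clamp the final range, which may cover only part of a cache line.
    begin *= unit;
    end = std::min(end * unit, numel);
    ranges.emplace_back(begin, end);
  });
  return ranges;
}
```

For example, element_ranges<float>(100, 3) uses 16-element units. That gives 7 units and the element ranges [0, 48), [48, 96) and [96, 100). Every range starts on a multiple of the unit size. So if the tensor data begins on a 64-byte boundary (an assumption here), no two tasks write to the same cache line, which is presumably the point of the "cache line aligned" unit: it avoids false sharing between threads. The real call passes a grain size of 0, so at::parallel_for imposes no minimum chunk size and may split the range down to single units.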

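The listing only forwards each chunk to sgd_math, whose body is not shown on this page. The following scalar sketch shows the per-element update that sgd_math is assumed to perform, modeled on the algorithm documented for torch.optim.SGD. It is not the PyTorch source. The name sgd_math_reference is hypothetical, and arithmetic is done in double rather than opmath_t, without the vectorization the real kernel uses. Three behaviours are assumptions: dividing the gradient by *grad_scale_ptr and writing it back, negating the gradient for maximize before adding weight decay, and seeding the momentum buffer with the gradient on the first step.

```cpp
#include <cstdint>

// Hypothetical scalar reference for the update sgd_math is assumed to apply
// to each element of a chunk (modeled on the torch.optim.SGD algorithm).
template <typename scalar_t>
void sgd_math_reference(scalar_t* param, scalar_t* grad, scalar_t* momentum_buf,
                        double weight_decay, double momentum, double lr,
                        double dampening, bool nesterov, bool maximize,
                        bool is_first_step, const float* grad_scale_ptr,
                        int64_t size) {
  for (int64_t i = 0; i < size; ++i) {
    double g = static_cast<double>(grad[i]);
    const double p = static_cast<double>(param[i]);
    if (grad_scale_ptr != nullptr) {
      // Assumed: divide out a loss scale and store the unscaled gradient.
      g /= static_cast<double>(*grad_scale_ptr);
      grad[i] = static_cast<scalar_t>(g);
    }
    if (maximize) g = -g;                            // step uphill instead of downhill
    if (weight_decay != 0.0) g += weight_decay * p;  // L2 penalty term
    if (momentum != 0.0) {
      // momentum_buf is only touched when momentum is nonzero.
      const double buf =
          is_first_step
              ? g  // first step: the buffer starts as the gradient itself
              : momentum * static_cast<double>(momentum_buf[i]) +
                    (1.0 - dampening) * g;
      momentum_buf[i] = static_cast<scalar_t>(buf);
      g = nesterov ? g + momentum * buf : buf;
    }
    param[i] = static_cast<scalar_t>(p - lr * g);
  }
}
```

Because the buffer is read and written only when momentum is nonzero, passing nullptr for it is safe in that case. This mirrors how sgd_fused_step_impl sets momentum_buffer_ptr to nullptr when has_momentum_buffer is false.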