Home / Class/ slow_conv3d_backward_update_grad_input_frame Class — pytorch Architecture

slow_conv3d_backward_update_grad_input_frame Class — pytorch Architecture

Architecture documentation for the slow_conv3d_backward_update_grad_input_frame function template in ConvolutionMM3d.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/ConvolutionMM3d.cpp lines 337–394

// Fills `grad_input` in two stages:
//   1. A batched GEMM writes `fgrad_input` from `weight` and `grad_output`.
//   2. Unfold3dAccCPU is called with `fgrad_input` as its source buffer and
//      `grad_input` as its destination buffer.
//
// The kernel_*, stride_* and pad_* arguments are not used by the GEMM; they
// are forwarded unchanged to Unfold3dAccCPU. `groups` is the leading count
// passed to the batched GEMM and also divides weight.size(0).
//
// NOTE(review): judging by the name, this handles one batch element ("frame")
// of the slow conv3d backward pass and stage 2 accumulates the unfolded
// columns into `grad_input` -- confirm against the caller and Unfold3dAccCPU.
template <typename scalar_t>
void slow_conv3d_backward_update_grad_input_frame(
    TensorAccessor<scalar_t, 4> grad_input,
    TensorAccessor<const scalar_t, 4> grad_output,
    TensorAccessor<const scalar_t, 2> weight,
    TensorAccessor<scalar_t, 2> fgrad_input,
    int64_t kernel_depth,
    int64_t kernel_height,
    int64_t kernel_width,
    int64_t stride_depth,
    int64_t stride_height,
    int64_t stride_width,
    int64_t pad_depth,
    int64_t pad_height,
    int64_t pad_width,
    int64_t groups) {
  // GEMM extents. The product of grad_output's three trailing dimensions is
  // the "m" extent, weight's second dimension is "n", and weight's first
  // dimension split across the groups is "k".
  // Presumably weight.size(0) is a multiple of `groups`; the integer division
  // would silently truncate otherwise -- verify against the caller.
  const int64_t flat_output_size =
      grad_output.size(1) * grad_output.size(2) * grad_output.size(3);
  const int64_t weight_cols = weight.size(1);
  const int64_t weight_rows_per_group = weight.size(0) / groups;

  // Element offset between consecutive groups inside each operand.
  const int64_t grad_output_group_stride =
      grad_output.stride(0) * weight_rows_per_group;
  const int64_t weight_group_stride = weight.stride(0) * weight_rows_per_group;
  const int64_t fgrad_input_group_stride = fgrad_input.stride(0) * weight_cols;

  // Scale factor 1 on the product and 0 on the existing contents of
  // fgrad_input (standard GEMM alpha/beta convention -- TODO confirm for
  // this wrapper).
  const scalar_t alpha = static_cast<scalar_t>(1);
  const scalar_t beta = static_cast<scalar_t>(0);

  // Logically:
  //   fgrad_input = weight.T * grad_output.reshape({grad_output.shape(0), -1})
  // gemm works in Fortran (column-major) order, so each of the three matrices
  // is seen transposed. Because C == AB <=> T(C) == T(B)T(A), handing the
  // operands over in swapped order (grad_output first, weight second) undoes
  // that transposition.
  at::native::cpublas::gemm_batched_with_stride(
      TransposeType::NoTranspose,
      TransposeType::Transpose,
      groups,
      flat_output_size,
      weight_cols,
      weight_rows_per_group,
      alpha,
      grad_output.data(),
      /*lda=*/flat_output_size,
      grad_output_group_stride,
      weight.data(),
      /*ldb=*/weight_cols,
      weight_group_stride,
      beta,
      fgrad_input.data(),
      /*ldc=*/flat_output_size,
      fgrad_input_group_stride);

  // Hand the GEMM result to the unfold routine together with all four
  // grad_input extents, the three trailing grad_output extents, and the
  // kernel / stride / padding triples; grad_input's buffer is the last
  // argument.
  Unfold3dAccCPU(
      c10::CppTypeToScalarType<scalar_t>::value,
      fgrad_input.data(),
      grad_input.size(0),
      grad_input.size(1),
      grad_input.size(2),
      grad_input.size(3),
      grad_output.size(1),
      grad_output.size(2),
      grad_output.size(3),
      kernel_depth,
      kernel_height,
      kernel_width,
      stride_depth,
      stride_height,
      stride_width,
      pad_depth,
      pad_height,
      pad_width,
      grad_input.data());
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free