Home / Class / _select_conv_backend Class — pytorch Architecture

_select_conv_backend Class — pytorch Architecture

Architecture documentation for _select_conv_backend, a static function template in aten/src/ATen/native/Convolution.cpp from the pytorch codebase, which chooses the convolution backend for a given input.

Entity Profile

Source Code

aten/src/ATen/native/Convolution.cpp lines 1189–1296

// Chooses which ConvBackend should run a convolution for the given
// input/weight/bias and convolution parameters.
//
// The order of the if / else-if chain below IS the backend priority:
//   1. empty inputs (zero-sized dim 0 or dim 1) -> MkldnnEmpty / Empty
//   2. depthwise convs -> cuDNN depthwise, MIOpen depthwise, or the native
//      CudaDepthwise2d / CudaDepthwise3d kernels (4-D / 5-D input only)
//   3. cuDNN, then MIOpen, then MKLDNN (each with a transposed variant)
//   4. inference-only paths (need_backward == false): XNNPACK 2d, then the
//      CPU 3x3 depthwise Winograd kernel
//   5. CPU, non-transposed, non-dilated 5-D input -> Slow3d
//   6. generic CPU / CUDA "slow" kernels, selected by transposed / ndim /
//      dilation (NNPACK for eligible non-dilated 4-D input)
//   7. MPS (plain or transposed)
//   8. anything else -> Overrideable
// Reordering branches changes which backend wins, so keep the order intact.
//
// Template parameter T: the size type consumed by the at::symint helpers
// (presumably int64_t or c10::SymInt -- confirm against the instantiations).
//
// @param input          input tensor; its dim 0 / dim 1 sizes, numel, ndim,
//                       device and layout drive the selection.
// @param weight         weight tensor, forwarded to the params.use_* probes.
// @param bias           optional bias; only consulted by the Winograd check.
// @param bias_sizes_opt optional bias sizes. use_miopen only sees whether
//                       they are present; use_xnnpack receives them as-is.
// @param need_backward  when true, the inference-only XNNPACK and Winograd
//                       backends are skipped.
// @param params         convolution parameters (transposed flag, dilation,
//                       and the per-backend eligibility predicates).
// @return the selected ConvBackend.
// @throws via TORCH_CHECK if the input has zero elements without a zero
//         dim 0 / dim 1, or if no backend supports the configuration.
template <typename T>
static ConvBackend _select_conv_backend(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const at::OptionalArrayRef<T> bias_sizes_opt,
    const bool need_backward,
    const ConvParams<T>& params) {

  // Don't send empty inputs through backends. A zero-sized dim 0 or dim 1
  // (zero batch / zero channels, per the error message below) gets the
  // dedicated empty backend. Any other zero-element shape is rejected.
  if (at::symint::size<T>(input, 0) == 0 || at::symint::size<T>(input, 1) == 0) {
    return input.is_mkldnn() ? ConvBackend::MkldnnEmpty : ConvBackend::Empty;
  } else if (at::symint::numel<T>(input) == 0) {
    TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", at::symint::sizes<T>(input));
  }

  if (params.is_depthwise(input, weight)) {
    // Depthwise convs never fall through to the generic backends below.
    // If neither cuDNN nor MIOpen accepts them and the input is not 4-D
    // or 5-D, the final TORCH_CHECK fires.
    if (params.use_cudnn_depthwise(input, weight)) {
      return ConvBackend::Cudnn;
    } else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) {
      return ConvBackend::MiopenDepthwise;
    } else {
      if (input.ndimension() == 4) {
        return ConvBackend::CudaDepthwise2d;
      } else if (input.ndimension() == 5) {
        return ConvBackend::CudaDepthwise3d;
      } else {
        // unsupported: falls through to the TORCH_CHECK at the end.
      }
    }
  } else if (params.use_cudnn(input, weight)) {
    if (params.transposed) {
      return ConvBackend::CudnnTranspose;
    } else {
      return ConvBackend::Cudnn;
    }
  } else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) {
    if (params.transposed) {
      return ConvBackend::MiopenTranspose;
    } else {
      return ConvBackend::Miopen;
    }
  } else if (params.use_mkldnn(input, weight)) {
    if (params.transposed) {
      return ConvBackend::MkldnnTranspose;
    } else {
      return ConvBackend::Mkldnn;
    }
  } else if (!need_backward && params.use_xnnpack(input, weight, bias_sizes_opt)) {
    // Using prepacked conv is preferred, but XNNPACK is still the fastest
    // option for NHWC. Only chosen when no backward pass is needed.
    return ConvBackend::Xnnpack2d;
  } else if (!need_backward && params.use_cpu_depthwise3x3_winograd(input, weight, bias)) {
    // The 3x3 depthwise convolution implementation is inference only,
    // hence the !need_backward guard.
    return ConvBackend::Winograd3x3Depthwise;
  } else if (
      !params.transposed && (input.ndimension() == 5) &&
      (input.device().is_cpu()) &&
      !params.is_dilated()) {
    // Fast path for grouped conv3d. It catches every non-transposed,
    // non-dilated 5-D CPU input before the generic slow paths below.
    return ConvBackend::Slow3d;
  } else if (input.device().is_cpu() || input.is_cuda()) {
    // backends without support for groups
    if (params.transposed) {
      if (input.ndimension() == 4) {
        return ConvBackend::SlowTranspose2d;
      } else if (input.ndimension() == 5) {
        return ConvBackend::SlowTranspose3d;
      } else {
        // unsupported: falls through to the TORCH_CHECK at the end.
      }
    } else {  /* Not transposed */
      if (input.ndimension() == 4) {
        if (params.is_dilated()) {
          return ConvBackend::SlowDilated2d;
        } else {  /* dim == 4, non-dilated */
          if (params.use_nnpack(input, weight)) {
            return ConvBackend::NnpackSpatial;
          } else {
            /* CPU implementation has specialized MM kernels
               for non-dilated case here */
            return ConvBackend::Slow2d;
          }
        }
      } else if (input.ndimension() == 5 && (input.is_cuda() || params.is_dilated())) {
        return ConvBackend::SlowDilated3d;
      } else if (input.ndimension() == 5) { /* dim == 5, CPU, non-dilated */
        // NOTE(review): this arm appears unreachable. A non-transposed,
        // non-dilated 5-D CPU input already returned Slow3d from the
        // grouped-conv3d fast path above. Kept as-is; confirm before
        // removing.
        /* CPU implementation has specialized MM kernels
           for non-dilated case here */
        return ConvBackend::Slow3d;
      } else {
        // unsupported: falls through to the TORCH_CHECK at the end.
      }
    }
  } else if (params.use_mps(input, weight)) {
    if (params.transposed) {
      return ConvBackend::MpsTranspose;
    } else {
      return ConvBackend::Mps;
    }
  } else {
    // Reached only when none of the backends above apply. Per the original
    // note, this means the input belongs to a backend with an out-of-source
    // implementation.
    return ConvBackend::Overrideable;
  }

  // Error out if no suitable backend was found. Only the "unsupported"
  // arms above (depthwise or slow paths with an unhandled ndim) get here.
  TORCH_CHECK(false, "unsupported ConvNd parameters");
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free