Home / Class / CuFFTConfig Class — pytorch Architecture

CuFFTConfig Class — pytorch Architecture

Architecture documentation for the CuFFTConfig class in CuFFTPlanCache.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/CuFFTPlanCache.h lines 223–355

// Holds a cuFFT plan handle together with the metadata recorded while the plan
// was built: the transform type, the value dtype, whether the caller must clone
// the input before executing the plan, and the workspace size reported by
// cufftXtMakePlanMany.
class CuFFTConfig {
public:

  // Only move semantics is enough for this class. Although we already use
  // unique_ptr for the plan, still remove copy constructor and assignment op so
  // we don't accidentally copy and take perf hit.
  //
  // NOTE(review): because the copy operations are user-declared (as deleted)
  // and no move operations are declared here, the compiler does not generate
  // implicit move operations either; objects must be constructed in place.
  CuFFTConfig(const CuFFTConfig&) = delete;
  CuFFTConfig& operator=(CuFFTConfig const&) = delete;

  // Convenience constructor: views the first `signal_ndim_ + 1` entries of the
  // stride/size arrays stored in `params` and forwards them, with the stored
  // transform type and value type, to the main constructor below.
  explicit CuFFTConfig(const CuFFTParams& params):
      CuFFTConfig(
          IntArrayRef(params.input_strides_, params.signal_ndim_ + 1),
          IntArrayRef(params.output_strides_, params.signal_ndim_ + 1),
          IntArrayRef(params.sizes_, params.signal_ndim_ + 1),
          params.fft_type_,
          params.value_type_) {}

  // Builds the cuFFT plan.
  //
  // For complex types, strides are in units of 2 * element_size(dtype)
  // sizes are for the full signal, including batch size and always two-sided
  //
  //   in_strides / out_strides: strides of the input / output
  //   sizes:    sizes[0] is the batch size, sizes[1:] are the signal sizes
  //   fft_type: which transform (R2C, C2R, ...) the plan performs
  //   dtype:    value type; only Float, Double and Half are accepted
  //
  // Fails via TORCH_CHECK when dtype is unsupported, when dtype is Half and the
  // current device is below SM_53, or when dtype is Half and a signal size is
  // not a power of two. Fails via TORCH_INTERNAL_ASSERT when the output strides
  // cannot be represented as a cuFFT embedding. cuFFT calls are wrapped in
  // CUFFT_CHECK.
  CuFFTConfig(IntArrayRef in_strides, IntArrayRef out_strides,
      IntArrayRef sizes, CuFFTTransformType fft_type, ScalarType dtype):
        fft_type_(fft_type), value_type_(dtype) {

    // signal sizes (excluding batch dim)
    CuFFTDimVector signal_sizes(sizes.begin() + 1, sizes.end());

    // input batch size
    const int64_t batch = sizes[0];
    const int64_t signal_ndim = sizes.size() - 1;

    // Since cuFFT has limited non-unit stride support and various constraints, we
    // use a flag (clone_input) to keep track throughout this function of whether
    // the caller needs to do `input = input.clone();` before running the plan.

#if defined(USE_ROCM)
    // clone input to avoid issues with hipfft clobbering the input and failing tests
    clone_input = true;
#else
    clone_input = false;
#endif

    // For half, base strides on the real part of real-to-complex and
    // complex-to-real transforms are not supported. Since our output is always
    // contiguous, only need to check real-to-complex case.
    if (dtype == ScalarType::Half) {
      // cuFFT on half requires compute capability of at least SM_53
      auto dev_prop = at::cuda::getCurrentDeviceProperties();
      TORCH_CHECK(dev_prop->major >= 5 && !(dev_prop->major == 5 && dev_prop->minor < 3),
               "cuFFT doesn't support signals of half type with compute "
               "capability less than SM_53, but the device containing input half "
               "tensor only has SM_", dev_prop->major, dev_prop->minor);
      for (const auto i : c10::irange(signal_ndim)) {
        // The message arguments are concatenated as-is, so the literal must end
        // with a space to keep the printed sizes separated from the text.
        TORCH_CHECK(is_pow_of_two(sizes[i + 1]),
            "cuFFT only supports dimensions whose sizes are powers of two when"
            " computing in half precision, but got a signal size of ",
            sizes.slice(1));
      }
      // A non-unit innermost input stride forces a clone in half precision.
      clone_input |= in_strides.back() != 1;
    }

    // If a clone is already required, the cloned input is described with the
    // simple embedding; otherwise derive the embedding from the actual strides.
    CuFFTDataLayout in_layout;
    if (clone_input) {
      in_layout = cufft_simple_embed(sizes, fft_type == CuFFTTransformType::C2R);
    } else {
      in_layout = as_cufft_embed(in_strides, sizes, fft_type == CuFFTTransformType::C2R);
    }
    auto out_layout = as_cufft_embed(out_strides, sizes, fft_type == CuFFTTransformType::R2C);
    TORCH_INTERNAL_ASSERT(!out_layout.must_clone, "Out strides cannot be represented as CuFFT embedding");
    // The input layout computation may itself report that a clone is needed.
    clone_input |= in_layout.must_clone;

    // Check if we can take advantage of simple data layout.
    //
    // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu.

    const bool simple_layout = in_layout.simple && out_layout.simple;

    // Select cuFFT input / output / execution data types from the value dtype
    // and from whether each side of the transform is complex.
    cudaDataType itype, otype, exec_type;
    const auto complex_input = cufft_complex_input(fft_type);
    const auto complex_output = cufft_complex_output(fft_type);
    if (dtype == ScalarType::Float) {
      itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
      otype = complex_output ? CUDA_C_32F : CUDA_R_32F;
      exec_type = CUDA_C_32F;
    } else if (dtype == ScalarType::Double) {
      itype = complex_input ? CUDA_C_64F : CUDA_R_64F;
      otype = complex_output ? CUDA_C_64F : CUDA_R_64F;
      exec_type = CUDA_C_64F;
    } else if (dtype == ScalarType::Half) {
      itype = complex_input ? CUDA_C_16F : CUDA_R_16F;
      otype = complex_output ? CUDA_C_16F : CUDA_R_16F;
      exec_type = CUDA_C_16F;
    } else {
      TORCH_CHECK(false, "cuFFT doesn't support tensor of type: ", dtype);
    }

    // disable auto allocation of workspace to use THC allocator
    CUFFT_CHECK(cufftSetAutoAllocation(plan(), /* autoAllocate */ 0));

    // Receives the workspace size computed by cufftXtMakePlanMany.
    size_t ws_size_t;

    // make plan
    if (simple_layout) {
      // If with unit-stride, we tell cuFFT by setting inembed == onembed == NULL.
      // In such case, cuFFT ignores istride, ostride, idist, and odist
      // by assuming istride = ostride = 1.
      //
      // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu.
      CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
        /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
        /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
        batch, &ws_size_t, exec_type));
    } else {
      CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
            in_layout.embed.data(), in_layout.stride, in_layout.dist, itype,
            out_layout.embed.data(), out_layout.stride, out_layout.dist, otype,
            batch, &ws_size_t, exec_type));
    }
    ws_size = static_cast<int64_t>(ws_size_t);
  }

  // The underlying cuFFT plan handle.
  const cufftHandle &plan() const { return plan_ptr.get(); }

  // Transform type the plan was built for.
  CuFFTTransformType transform_type() const { return fft_type_; }
  // Value dtype the plan was built for.
  ScalarType data_type() const { return value_type_; }
  // Whether the constructor determined that the input must be cloned.
  bool should_clone_input() const { return clone_input; }
  // Workspace size reported when the plan was made.
  int64_t workspace_size() const { return ws_size; }

private:
  CuFFTHandle plan_ptr;
  bool clone_input;   // assigned in the constructor body
  int64_t ws_size;    // assigned at the end of the constructor body
  CuFFTTransformType fft_type_;
  ScalarType value_type_;
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free