Home / Class / AlgoIterator Class — pytorch Architecture

AlgoIterator Class — pytorch Architecture

Architecture documentation for the AlgoIterator class in Conv_v7.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cudnn/Conv_v7.cpp lines 526–583

template <typename perf_t>
class AlgoIterator {
  using search = algorithm_search<perf_t>;
  const ConvolutionArgs& args;
  bool benchmark;

 public:
  AlgoIterator(const ConvolutionArgs& args, bool benchmark)
      : args(args), benchmark(benchmark) {}

  static std::vector<perf_t> onlyDefaultAlgorithm(const ConvolutionArgs& args) {
    std::vector<perf_t> perfResults(1);
    perfResults[0].algo = search::DEFAULT_ALGO;
    if (args.params.dataType == CUDNN_DATA_HALF) {
      perfResults[0].mathType = CUDNN_TENSOR_OP_MATH;
    } else {
      perfResults[0].mathType = CUDNN_DEFAULT_MATH;
      if (args.params.dataType == CUDNN_DATA_FLOAT && !args.params.allow_tf32) {
        perfResults[0].mathType = CUDNN_FMA_MATH;
      }
    }
    search::getWorkspaceSize(
        args, perfResults[0].algo, &(perfResults[0].memory));
    return perfResults;
  }

  void try_all(std::function<void(const perf_t& perf)> f) {
    bool only_use_default = args.params.deterministic && !benchmark;

    auto& cache = search::cache();
    perf_t algoPerf;
    if (!only_use_default && cache.find(args.params, &algoPerf)) {
      try {
        f(algoPerf);
        return;
      } catch (c10::OutOfMemoryError&) {
        std::ignore = cudaGetLastError(); // clear CUDA error
      }
    }

    auto perfResults = only_use_default
        ? onlyDefaultAlgorithm(args)
        : search::findAlgorithms(args, benchmark);
    for (auto& algoPerf : perfResults) {
      try {
        f(algoPerf);
        cache.insert(args.params, algoPerf);
        return;
      } catch (c10::OutOfMemoryError&) {
        std::ignore = cudaGetLastError(); // clear CUDA error
      } catch (c10::CuDNNError&) {
        std::ignore = cudaGetLastError(); // clear CUDA error
      }
    }
    TORCH_CHECK(
        false, "Unable to find a valid cuDNN algorithm to run convolution");
  }
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free