Home / Class / miopenConvolutionBwdWeightsAlgoGEMM Class — pytorch Architecture

miopenConvolutionBwdWeightsAlgoGEMM Class — pytorch Architecture

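Architecture documentation for the `algorithm_search<miopenConvBwdWeightsAlgorithm_t>` specialization in Conv_miopen.cpp from the pytorch codebase. Despite the page title, miopenConvolutionBwdWeightsAlgoGEMM is not a class: it is the MIOpen enum value this specialization uses as its default backward-weights algorithm (DEFAULT_ALGO).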

Entity Profile

Source Code

aten/src/ATen/native/miopen/Conv_miopen.cpp lines 492–563

// Algorithm-selection traits for the MIOpen convolution backward-weights
// (filter gradient) pass. Provides the default algorithm, accessors for the
// benchmark caches, a Find-based search, and an immediate-mode solution query.
template<>
struct algorithm_search<miopenConvBwdWeightsAlgorithm_t> {
  using perf_t = miopenConvAlgoPerf_t;
  using algo_t = miopenConvBwdWeightsAlgorithm_t;

  // Algorithm used to size the Find workspace and preferred by getSolution
  // when force_default is set.
  static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdWeightsAlgoGEMM;
  // Accessors for the backward-filter caches (bwd_filter_algos /
  // bwd_filter_wssizes), which are declared outside this struct.
  static BenchmarkCache<algo_t>& cache() { return bwd_filter_algos; }
  static BenchmarkCache<size_t>& wsscache() { return bwd_filter_wssizes; }

  // Runs MIOpen's Find for the backward-weights convolution described by
  // `args` and returns the single best result. The workspace is sized from
  // DEFAULT_ALGO's requirement. `benchmark` is forwarded as MIOpen's final
  // (exhaustive-search) argument.
  // Throws (via MIOPEN_CHECK / TORCH_CHECK) if MIOpen fails or reports no result.
  static perf_t findAlgorithm(const ConvolutionArgs& args, bool benchmark) {
    int perf_count = 0;
    perf_t perf_results;
    size_t max_ws_size = getWorkspaceSize(args, DEFAULT_ALGO);
    Workspace ws(max_ws_size);
    MIOPEN_CHECK(miopenFindConvolutionBackwardWeightsAlgorithm(
        args.handle,
        args.odesc.desc(), args.output.const_data_ptr(),
        args.idesc.desc(), args.input.const_data_ptr(),
        args.cdesc.desc(),
        args.wdesc.desc(), args.weight.data_ptr(),
        1,      // just return the fastest
        &perf_count,
        &perf_results,
        ws.data,
        ws.size,
        benchmark));
    // With zero reported results perf_results was never filled in; returning
    // it would hand an uninitialized struct back to the caller.
    TORCH_CHECK(perf_count > 0,
        "miopenConvBwdWeightsAlgorithm_t findAlgorithm returned no results");
    return perf_results;
  }

  // Queries MIOpen's immediate-mode solutions for the convolution in `args`.
  // If force_default is set, returns the DEFAULT_ALGO solution when present,
  // otherwise the first solution with no workspace requirement. In every other
  // case it returns the first solution MIOpen listed.
  // Throws (via MIOPEN_CHECK / TORCH_CHECK) on MIOpen failure, when more than
  // AT_MIOPEN_MAX_SOLUTIONS solutions exist, or when none exist.
  static miopenConvSolution_t getSolution(const ConvolutionArgs& args, bool force_default) {
    size_t max_solution_count = 0;
    size_t solution_count = 0;
    miopenConvSolution_t solutions[AT_MIOPEN_MAX_SOLUTIONS];
    MIOPEN_CHECK(miopenConvolutionBackwardWeightsGetSolutionCount(
        args.handle,
        args.odesc.desc(),
        args.idesc.desc(),
        args.cdesc.desc(),
        args.wdesc.desc(),
        &max_solution_count));
    // The solutions buffer is fixed-size; refuse to let MIOpen overrun it.
    TORCH_CHECK(max_solution_count <= AT_MIOPEN_MAX_SOLUTIONS,
        "miopenConvBwdWeightsAlgorithm_t getSolution max_solution_count > AT_MIOPEN_MAX_SOLUTIONS");
    MIOPEN_CHECK(miopenConvolutionBackwardWeightsGetSolution(
        args.handle,
        args.odesc.desc(),
        args.idesc.desc(),
        args.cdesc.desc(),
        args.wdesc.desc(),
        max_solution_count,
        &solution_count,
        solutions));
    // Every return below indexes `solutions`; with no entries filled in,
    // solutions[0] would be uninitialized.
    TORCH_CHECK(solution_count > 0,
        "miopenConvBwdWeightsAlgorithm_t getSolution found no solutions");

    if (force_default) {
      // Prefer the default algorithm if MIOpen offers it.
      for (size_t i = 0; i < solution_count; ++i) {
        if (solutions[i].algorithm == static_cast<miopenConvAlgorithm_t>(DEFAULT_ALGO)) {
          return solutions[i];
        }
      }
      // Default algo was not found: select the first algo without a workspace requirement.
      for (size_t i = 0; i < solution_count; ++i) {
        if (solutions[i].workspace_size == 0) {
          return solutions[i];
        }
      }
      // Neither found: fall through to the first solution, as in the non-forced path.
    }

    return solutions[0];
  }
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free