fractional_max_pool2d_out_single_batch_frame Function: pytorch Architecture

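Architecture documentation for the fractional_max_pool2d_out_single_batch_frame function template in FractionalMaxPool2d.cpp from the pytorch codebase. The function runs fractional max pooling over the planes of a single sample: for each output position it scans one poolSizeH x poolSizeW window of the input plane and writes the maximum value together with its flat index within that plane.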

Source Code

aten/src/ATen/native/FractionalMaxPool2d.cpp lines 132–188

template <typename scalar_t>
void fractional_max_pool2d_out_single_batch_frame(
  const scalar_t* input,
  scalar_t* output,
  int64_t* indices,
  const scalar_t* randomSamples,
  int numPlanes,
  int inputW, int inputH,
  int outputW, int outputH,
  int poolSizeW, int poolSizeH) {
  at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
    for (const auto plane : c10::irange(start, end)) {
      /* each plane contains 2 random samples, one for W and one for H */
      const scalar_t* randomSamplesForPlane = randomSamples + plane * 2;

      /* Generate interval sequence */
      auto sequenceW = generate_intervals<scalar_t>(
          randomSamplesForPlane[0], inputW, outputW, poolSizeW);
      auto sequenceH = generate_intervals<scalar_t>(
          randomSamplesForPlane[1], inputH, outputH, poolSizeH);

      /* loop over output */
      const scalar_t* inputForPlane = input + plane * inputW * inputH;
      scalar_t* outputForPlane = output + plane * outputW * outputH;
      int64_t* indicesForPlane = indices + plane * outputW * outputH;

      for (int h = 0; h < outputH; ++h) {
        int inputHStart = sequenceH[h];

        for (int w = 0; w < outputW; ++w) {
          int inputWStart = sequenceW[w];

          int h2 = inputHStart, w2 = inputWStart;
          scalar_t maxVal = -std::numeric_limits<scalar_t>::infinity();
          int64_t maxIndex = h2 * inputW + w2;

          for (h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) {
            for (w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) {
              AT_ASSERT(h2 >= 0 && h2 < inputH);
              AT_ASSERT(w2 >= 0 && w2 < inputW);

              int planeIndex = h2 * inputW + w2;
              scalar_t val = inputForPlane[planeIndex];
              if (val > maxVal || std::isnan(val)) {
                maxVal = val;
                maxIndex = planeIndex;
              }
            }
          }

          outputForPlane[h * outputW + w] = maxVal;
          indicesForPlane[h * outputW + w] = maxIndex;
        }
      }
    }
  });
}
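
The listing relies on generate_intervals, whose body is not shown on this page, and on ATen utilities (at::parallel_for, c10::irange, AT_ASSERT). To make the loop nest easier to trace outside PyTorch, the following is a minimal single-plane, single-threaded sketch in standard C++. The names make_intervals, PoolResult, and pool_plane are hypothetical and introduced only for illustration. The interval formula in make_intervals is an assumption about how such a helper could space the window starts, not a copy of the library code.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical stand-in for generate_intervals, whose body is not shown above.
// Assumes sample is in [0, 1) and poolSize <= inputSize. Returns one window
// start per output element, with the last window flush against the far edge.
std::vector<int> make_intervals(double sample, int inputSize, int outputSize,
                                int poolSize) {
  std::vector<int> starts(outputSize);
  if (outputSize > 1) {
    const double alpha =
        static_cast<double>(inputSize - poolSize) / (outputSize - 1);
    for (int i = 0; i < outputSize - 1; ++i) {
      starts[i] = static_cast<int>((i + sample) * alpha) -
                  static_cast<int>(sample * alpha);
    }
  }
  if (outputSize > 0) {
    starts[outputSize - 1] = inputSize - poolSize;
  }
  return starts;
}

// Hypothetical result holder: pooled values and flat indices into the plane.
struct PoolResult {
  std::vector<float> values;
  std::vector<int64_t> indices;
};

// Single-plane, single-threaded version of the loop nest in the listing.
PoolResult pool_plane(const std::vector<float>& plane, int inputW, int inputH,
                      int outputW, int outputH, int poolSizeW, int poolSizeH,
                      double sampleW, double sampleH) {
  const std::vector<int> seqW =
      make_intervals(sampleW, inputW, outputW, poolSizeW);
  const std::vector<int> seqH =
      make_intervals(sampleH, inputH, outputH, poolSizeH);
  PoolResult out{std::vector<float>(outputW * outputH),
                 std::vector<int64_t>(outputW * outputH)};

  for (int h = 0; h < outputH; ++h) {
    for (int w = 0; w < outputW; ++w) {
      float maxVal = -std::numeric_limits<float>::infinity();
      int64_t maxIndex = seqH[h] * inputW + seqW[w];
      for (int h2 = seqH[h]; h2 < seqH[h] + poolSizeH; ++h2) {
        for (int w2 = seqW[w]; w2 < seqW[w] + poolSizeW; ++w2) {
          // Same bounds checks as the AT_ASSERT calls in the listing.
          assert(h2 >= 0 && h2 < inputH);
          assert(w2 >= 0 && w2 < inputW);
          const int planeIndex = h2 * inputW + w2;
          const float val = plane[planeIndex];
          // A NaN always replaces the running maximum, and no later non-NaN
          // value can displace it because `val > NaN` is always false.
          if (val > maxVal || std::isnan(val)) {
            maxVal = val;
            maxIndex = planeIndex;
          }
        }
      }
      out.values[h * outputW + w] = maxVal;
      out.indices[h * outputW + w] = maxIndex;
    }
  }
  return out;
}
```

Under these assumptions, a 4x4 plane holding the values 0 through 15, pooled with 2x2 windows into a 2x2 output with both samples at 0.5, gets window starts {0, 2} on each axis, so the pooled values and the recorded indices are both {5, 7, 13, 15}. Replacing the first element with NaN makes the first output NaN with index 0, which illustrates how the `std::isnan(val)` branch propagates NaN instead of silently skipping it.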
