Home / Class / kSpatialDim Class — pytorch Architecture

kSpatialDim Class — pytorch Architecture

Architecture documentation for the kSpatialDim entity (the non-type template parameter of the parse_conv_serialized_state function template) in conv_serialization.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/quantized/cpu/conv_serialization.h lines 88–203

// Normalizes a serialized quantized-conv state of any supported version into
// the version 3 layout: (version, config_vals, tensors).
//
// The version is inferred from the first element of the tuple held by `v`:
//   - a Tensor        -> version 1
//   - the string "2"  -> version 2
//   - the integer 3   -> version 3
//
// Version 1 tuples are read positionally as
//   (weight, bias, stride list, padding list, dilation list, groups)
// and are converted to
//   config_vals = [kSpatialDim, stride..., padding..., dilation...,
//                  output_padding (kSpatialDim zeros), groups, transpose (0)]
//
// Version 2 tuples are read positionally as
//   ("2", tensor list [config, weight, ...], list of optional tensors [bias, ...])
// where `config` is read as a 1-D int16 tensor and widened to int64.
//
// Version 3 values are converted directly via IValue::to.
//
// In the converted outputs, `tensors` is [<empty optional>, weight, bias].
//
// Fails a TORCH_INTERNAL_ASSERT if the version cannot be determined, or if a
// version 1 / version 2 tuple is shorter than the positional layout requires.
template <uint32_t kSpatialDim>
ConvParamsSerializationTypeV3 parse_conv_serialized_state(const c10::IValue& v) {

  // determine the version based on IValue contents
  int version = -1;
  if (v.isTuple()) {
    const auto& elements = v.toTupleRef().elements();
    if (!elements.empty()) {
      // Bind by reference: only the type tag and payload are inspected here,
      // so there is no need to copy the IValue.
      const auto& firstElement = elements[0];
      if (firstElement.isTensor()) {
        version = 1;
      } else if (firstElement.isString()) {
        const std::string& version_str = firstElement.toStringRef();
        // note: not parsing the string to automatically handle bad
        // inputs
        if (version_str == "2") {
          version = 2;
        }
      } else if (firstElement.isInt()) {
        auto raw_version = firstElement.toInt();
        if (raw_version == 3) {
          version = 3;
        }
      }
    }
  }
  TORCH_INTERNAL_ASSERT(version != -1, "Unable to parse serialization version");

  if (version == 1) {
    // version 1 - convert to version 3 manually

    const auto& elements = v.toTupleRef().elements();

    // Guard the fixed positional reads below (indices 0..5) against a
    // truncated tuple.
    TORCH_INTERNAL_ASSERT(
        elements.size() >= 6,
        "Malformed version 1 serialized qconv state: expected at least 6 "
        "elements, got ",
        elements.size());

    at::Tensor weight = elements[0].toTensor();
    std::optional<at::Tensor> bias = elements[1].toOptional<at::Tensor>();
    torch::List<at::Tensor> stride_x_kSpatialDim = elements[2].toTensorList();
    torch::List<at::Tensor> padding_x_kSpatialDim = elements[3].toTensorList();
    torch::List<at::Tensor> dilation_x_kSpatialDim = elements[4].toTensorList();
    at::Tensor groups = elements[5].toTensor();

    std::vector<int64_t> config_vals;
    // "+ kSpatialDim" covers the output_padding zeros; "+ 3" covers the
    // leading kSpatialDim entry, groups and transpose.
    config_vals.reserve(
        stride_x_kSpatialDim.size() + padding_x_kSpatialDim.size() +
        dilation_x_kSpatialDim.size() + kSpatialDim + 3);
    config_vals.push_back(kSpatialDim);
    for (const auto i : c10::irange(stride_x_kSpatialDim.size())) {
      auto const & stride = stride_x_kSpatialDim.get(i);
      config_vals.push_back(stride[0].item<int16_t>());
    }
    for (const auto i : c10::irange(padding_x_kSpatialDim.size())) {
      auto const &padding = padding_x_kSpatialDim.get(i);
      config_vals.push_back(padding[0].item<int16_t>());
    }
    for (const auto i : c10::irange(dilation_x_kSpatialDim.size())) {
      auto const &dilation = dilation_x_kSpatialDim.get(i);
      config_vals.push_back(dilation[0].item<int16_t>());
    }
    // output_padding does not exist in v1, so we fill in a default value
    for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) {
      config_vals.push_back(0);
    }
    config_vals.push_back(groups[0].item<int16_t>());
    // transpose does not exist in v1, so we fill in a default value
    config_vals.push_back(0);

    std::vector<std::optional<at::Tensor>> tensors;
    tensors.emplace_back();
    tensors.emplace_back(weight);
    tensors.emplace_back(bias);

    // Distinct name so the detected `version` above is not shadowed.
    int64_t output_version = 3;
    return std::tie(output_version, config_vals, tensors);
  } else if (version == 2) {
    // version 2
    const auto& elements = v.toTupleRef().elements();

    // Guard the fixed positional reads below (indices 1 and 2) against a
    // truncated tuple.
    TORCH_INTERNAL_ASSERT(
        elements.size() >= 3,
        "Malformed version 2 serialized qconv state: expected at least 3 "
        "elements, got ",
        elements.size());

    std::vector<at::Tensor> non_optional = elements[1].toTensorList().vec();
    // non_optional[0] (config) and non_optional[1] (weight) are both read
    // unconditionally further down.
    TORCH_INTERNAL_ASSERT(
        non_optional.size() >= 2,
        "Malformed version 2 serialized qconv state: expected at least 2 "
        "non-optional tensors, got ",
        non_optional.size());

    std::vector<std::optional<at::Tensor>> optional;

    // The optional slot is accepted either as a plain tensor list or as a
    // generic list whose entries are converted to optional tensors.
    if (elements[2].isTensorList()) {
      for (const auto& elem : elements[2].toTensorList()) {
        optional.emplace_back(static_cast<at::Tensor>(elem));
      }
    } else {
      for (const auto& elem : elements[2].toList()) {
        optional.emplace_back(static_cast<c10::IValue>(elem).toOptional<at::Tensor>());
      }
    }
    // create default optional value for bias
    if (optional.empty()) {
      optional.emplace_back();
    }

    // Widen the 1-D int16 config tensor into int64 config values.
    auto config_a = non_optional[0].accessor<int16_t, 1>();
    std::vector<int64_t> config_vals;
    config_vals.reserve(config_a.size(0));
    for (const auto i : c10::irange(config_a.size(0))) {
      config_vals.emplace_back(config_a[i]);
    }

    auto weight = non_optional[1];
    auto bias = optional[0];

    std::vector<std::optional<at::Tensor>> tensors;
    tensors.emplace_back();
    tensors.emplace_back(weight);
    tensors.emplace_back(bias);

    // Distinct name so the detected `version` above is not shadowed.
    int64_t output_version = 3;
    return std::tie(output_version, config_vals, tensors);
  } else if (version == 3) {
    return v.to<ConvParamsSerializationTypeV3>();
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected serialized qconv version: ",
        version);
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free