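parse_conv_serialized_state<kSpatialDim> Function Template — pytorch Architecture
Architecture documentation for the parse_conv_serialized_state function template (parameterized on kSpatialDim) in conv_serialization.h from the pytorch codebase.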
Entity Profile
Source Code
aten/src/ATen/native/quantized/cpu/conv_serialization.h lines 88–203
/// Parses a serialized quantized-conv state of any supported version and
/// returns it in the version 3 layout.
///
/// The version is inferred from the first element of the input tuple:
///   - a Tensor        -> version 1
///   - the string "2"  -> version 2
///   - the integer 3   -> version 3
///
/// Version 1 input, as read by this function:
///   (weight, optional bias, stride list, padding list, dilation list, groups)
/// where stride/padding/dilation are lists of tensors and groups is a tensor;
/// the first entry of each tensor is read as an int16 scalar.
///
/// Version 2 input, as read by this function:
///   ("2", [config tensor (1-D int16), weight], [optional bias, ...])
///
/// Version 3 input is converted directly to the return type.
///
/// For versions 1 and 2 the result is (3, config_vals, tensors) with
/// tensors = [empty optional, weight, bias].
/// NOTE(review): slot 0 of `tensors` is always left empty here; its purpose is
/// not visible in this function - confirm against the consumers.
///
/// @tparam kSpatialDim number of spatial dimensions of the convolution.
/// @param v serialized state, expected to be a non-empty tuple.
/// @return the state in ConvParamsSerializationTypeV3 form.
/// @note Fails a TORCH_INTERNAL_ASSERT when the version cannot be determined
///       or when the tuple has fewer entries than the detected version needs.
template <uint32_t kSpatialDim>
ConvParamsSerializationTypeV3 parse_conv_serialized_state(const c10::IValue& v) {
  // Version number written into every converted result.
  const int64_t output_version = 3;

  // determine the version based on IValue contents
  int version = -1;
  if (v.isTuple()) {
    const auto& elements = v.toTupleRef().elements();
    if (!elements.empty()) {
      // Bound by reference: the element is only inspected, so copying the
      // IValue is unnecessary.
      const auto& firstElement = elements[0];
      if (firstElement.isTensor()) {
        version = 1;
      } else if (firstElement.isString()) {
        const std::string& version_str = firstElement.toStringRef();
        // note: not parsing the string to automatically handle bad
        // inputs
        if (version_str == "2") {
          version = 2;
        }
      } else if (firstElement.isInt()) {
        auto raw_version = firstElement.toInt();
        if (raw_version == 3) {
          version = 3;
        }
      }
    }
  }
  TORCH_INTERNAL_ASSERT(version != -1, "Unable to parse serialization version");

  if (version == 1) {
    // version 1 - convert to version 3 manually
    const auto& elements = v.toTupleRef().elements();
    // Indices 0..5 are read below; reject short tuples instead of indexing
    // out of bounds.
    TORCH_INTERNAL_ASSERT(
        elements.size() >= 6,
        "Serialized qconv v1 state must have at least 6 elements, got ",
        elements.size());

    at::Tensor weight = elements[0].toTensor();
    std::optional<at::Tensor> bias = elements[1].toOptional<at::Tensor>();
    torch::List<at::Tensor> stride_x_kSpatialDim = elements[2].toTensorList();
    torch::List<at::Tensor> padding_x_kSpatialDim = elements[3].toTensorList();
    torch::List<at::Tensor> dilation_x_kSpatialDim = elements[4].toTensorList();
    at::Tensor groups = elements[5].toTensor();

    // Layout: [kSpatialDim, stride..., padding..., dilation...,
    //          output_padding (kSpatialDim entries), groups, transpose]
    std::vector<int64_t> config_vals;
    config_vals.reserve(
        stride_x_kSpatialDim.size() + padding_x_kSpatialDim.size() +
        dilation_x_kSpatialDim.size() + kSpatialDim + 3);
    config_vals.push_back(kSpatialDim);
    for (const auto i : c10::irange(stride_x_kSpatialDim.size())) {
      auto const& stride = stride_x_kSpatialDim.get(i);
      config_vals.push_back(stride[0].item<int16_t>());
    }
    for (const auto i : c10::irange(padding_x_kSpatialDim.size())) {
      auto const& padding = padding_x_kSpatialDim.get(i);
      config_vals.push_back(padding[0].item<int16_t>());
    }
    for (const auto i : c10::irange(dilation_x_kSpatialDim.size())) {
      auto const& dilation = dilation_x_kSpatialDim.get(i);
      config_vals.push_back(dilation[0].item<int16_t>());
    }
    // output_padding does not exist in v1, so we fill in a default value
    for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) {
      config_vals.push_back(0);
    }
    config_vals.push_back(groups[0].item<int16_t>());
    // transpose does not exist in v1, so we fill in a default value
    config_vals.push_back(0);

    std::vector<std::optional<at::Tensor>> tensors;
    tensors.reserve(3);
    tensors.emplace_back();
    tensors.emplace_back(std::move(weight));
    tensors.emplace_back(std::move(bias));

    // make_tuple + move: std::tie would yield a tuple of lvalue references
    // and force a deep copy of both vectors on conversion to the return type.
    return std::make_tuple(
        output_version, std::move(config_vals), std::move(tensors));
  } else if (version == 2) {
    // version 2
    const auto& elements = v.toTupleRef().elements();
    // Indices 1 and 2 are read below.
    TORCH_INTERNAL_ASSERT(
        elements.size() >= 3,
        "Serialized qconv v2 state must have at least 3 elements, got ",
        elements.size());

    std::vector<at::Tensor> non_optional = elements[1].toTensorList().vec();
    // non_optional[0] is the config tensor and non_optional[1] the weight.
    TORCH_INTERNAL_ASSERT(
        non_optional.size() >= 2,
        "Serialized qconv v2 state must carry a config and a weight tensor, got ",
        non_optional.size(),
        " tensor(s)");

    std::vector<std::optional<at::Tensor>> optional;
    if (elements[2].isTensorList()) {
      for (const auto& elem : elements[2].toTensorList()) {
        optional.emplace_back(static_cast<at::Tensor>(elem));
      }
    } else {
      for (const auto& elem : elements[2].toList()) {
        optional.emplace_back(
            static_cast<c10::IValue>(elem).toOptional<at::Tensor>());
      }
    }
    // create default optional value for bias
    if (optional.empty()) {
      optional.emplace_back();
    }

    // Widen the int16 config entries to the int64 values used by v3.
    auto config_a = non_optional[0].accessor<int16_t, 1>();
    std::vector<int64_t> config_vals;
    config_vals.reserve(config_a.size(0));
    for (const auto i : c10::irange(config_a.size(0))) {
      config_vals.emplace_back(config_a[i]);
    }

    std::vector<std::optional<at::Tensor>> tensors;
    tensors.reserve(3);
    tensors.emplace_back();
    // The source vectors are not used again, so hand the tensors over
    // instead of copying them.
    tensors.emplace_back(std::move(non_optional[1]));
    tensors.emplace_back(std::move(optional[0]));

    return std::make_tuple(
        output_version, std::move(config_vals), std::move(tensors));
  } else if (version == 3) {
    return v.to<ConvParamsSerializationTypeV3>();
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected serialized qconv version: ",
        version);
  }
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free