Attr Class — PyTorch Architecture
Architecture documentation for the `Attr` class, defined in Attr.h in the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/mkldnn/xpu/detail/Attr.h lines 132–367
// Attr collects quantization parameters and a sequence of oneDNN post-ops
// (eltwise / sum / binary / prelu) to be fused into an XPU primitive.
// Typical usage: build up the list with the append_* methods, then call
// extract_post_ops() to obtain the dnnl::post_ops for primitive creation and
// construct_post_binary() to bind the binary operands to the executable args.
class Attr {
 public:
  Attr() : q_scale_(1.f) {}
  Attr(float q_scale, int64_t zp = 0) : q_scale_(q_scale), q_zero_point_(zp) {}

  /***** eltwise *****/
  dnnl::algorithm kind_with_relu = dnnl::algorithm::eltwise_relu;
  dnnl::algorithm kind_with_sigmoid = dnnl::algorithm::eltwise_logistic;
  dnnl::algorithm kind_with_gelu_tanh = dnnl::algorithm::eltwise_gelu_tanh;
  dnnl::algorithm kind_with_gelu_erf = dnnl::algorithm::eltwise_gelu_erf;
  dnnl::algorithm kind_with_mish = dnnl::algorithm::eltwise_mish;
  dnnl::algorithm kind_with_linear = dnnl::algorithm::eltwise_linear;
  dnnl::algorithm kind_with_swish = dnnl::algorithm::eltwise_swish;
  dnnl::algorithm kind_with_sqrt = dnnl::algorithm::eltwise_sqrt;
  dnnl::algorithm kind_with_tanh = dnnl::algorithm::eltwise_tanh;
  dnnl::algorithm kind_with_square = dnnl::algorithm::eltwise_square;
  dnnl::algorithm kind_with_abs = dnnl::algorithm::eltwise_abs;
  dnnl::algorithm kind_with_exp = dnnl::algorithm::eltwise_exp;
  dnnl::algorithm kind_with_log = dnnl::algorithm::eltwise_log;
  dnnl::algorithm kind_with_round = dnnl::algorithm::eltwise_round;
  dnnl::algorithm kind_with_hardswish = dnnl::algorithm::eltwise_hardswish;
  dnnl::algorithm kind_with_soft_relu = dnnl::algorithm::eltwise_soft_relu;
  dnnl::algorithm kind_with_elu = dnnl::algorithm::eltwise_elu;
  dnnl::algorithm kind_with_pow = dnnl::algorithm::eltwise_pow;
  dnnl::algorithm kind_with_clip = dnnl::algorithm::eltwise_clip;
  // NOTE: hardsigmoid does not appear to be fully supported by oneDNN yet —
  // confirm support on the target oneDNN version before relying on it.
  dnnl::algorithm kind_with_hardsigmoid = dnnl::algorithm::eltwise_hardsigmoid;

  /***** binary *****/
  dnnl::algorithm kind_with_binary_mul = dnnl::algorithm::binary_mul;
  dnnl::algorithm kind_with_binary_add = dnnl::algorithm::binary_add;
  dnnl::algorithm kind_with_binary_sub = dnnl::algorithm::binary_sub;
  dnnl::algorithm kind_with_binary_div = dnnl::algorithm::binary_div;
  dnnl::algorithm kind_with_binary_eq = dnnl::algorithm::binary_eq;
  dnnl::algorithm kind_with_binary_ne = dnnl::algorithm::binary_ne;
  dnnl::algorithm kind_with_binary_ge = dnnl::algorithm::binary_ge;
  dnnl::algorithm kind_with_binary_gt = dnnl::algorithm::binary_gt;
  dnnl::algorithm kind_with_binary_le = dnnl::algorithm::binary_le;
  dnnl::algorithm kind_with_binary_lt = dnnl::algorithm::binary_lt;
  dnnl::algorithm kind_with_binary_max = dnnl::algorithm::binary_max;
  dnnl::algorithm kind_with_binary_min = dnnl::algorithm::binary_min;

  // Append a sum post op (dst += scale * src).
  // The effective scale is sum_scale * sum_q_scale; zp is the zero point
  // (asymmetric sum zp is not supported on GPU — see extract_post_ops).
  Attr& append_post_sum(
      float sum_scale,
      float sum_q_scale = 1.f,
      int64_t zp = 0) {
    ops_params_.push_back(
        PostOpParam(/*scale_sum*/ sum_scale * sum_q_scale, zp, kind_t::sum));
    return *this;
  }

  // Append an eltwise post op with the given algorithm and its alpha/beta
  // parameters. `scale` is recorded but applied by oneDNN per its eltwise
  // post-op semantics.
  Attr& append_post_eltwise(
      float scale,
      float alpha,
      float beta,
      dnnl::algorithm algo) {
    ops_params_.push_back(
        PostOpParam(scale, alpha, beta, algo, kind_t::eltwise));
    return *this;
  }

  // Append a binary post op (e.g. dst = dst + other, dst = dst * other).
  // A quantized operand is dequantized first. For non-matmul primitives the
  // operand is made contiguous unless it is already channels-last, and the
  // expected memory desc uses format_tag::any so oneDNN can pick the fastest
  // layout; for matmul the original desc is kept as the expected desc.
  template <bool is_matmul = false>
  Attr& append_post_binary(dnnl::algorithm algo, const at::Tensor& binary) {
    auto binary_ = binary.is_quantized() ? at::dequantize(binary) : binary;
    bool binary_is_channels_last =
        (binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
         binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d);
    if constexpr (!is_matmul) {
      binary_ = binary_is_channels_last ? binary_ : binary_.contiguous();
    }
    dnnl::memory::desc md = get_onednn_md(binary_);
    auto expected_md = dnnl::memory::desc(
        md.get_dims(), md.get_data_type(), dnnl::memory::format_tag::any);
    if constexpr (is_matmul) {
      ops_params_.push_back(PostOpParam(binary_, md, md, algo, kind_t::binary));
    } else {
      ops_params_.push_back(
          PostOpParam(binary_, md, expected_md, algo, kind_t::binary));
    }
    return *this;
  }

  // Append a binary post op whose effective scale is scale * sum_q_scale.
  // NOTE(review): `binary` was previously taken by value; pass by const
  // reference for consistency with the other append_* methods and to avoid
  // an extra Tensor refcount bump. `zp` is accepted but not forwarded here —
  // presumably handled elsewhere; confirm against callers.
  Attr& append_scale_binary(
      dnnl::algorithm algo,
      const at::Tensor& binary,
      float scale,
      float sum_q_scale = 1.f,
      int64_t zp = 0) {
    ops_params_.push_back(PostOpParam(
        binary, /*scale_sum*/ scale * sum_q_scale, algo, kind_t::binary));
    return *this;
  }

  // Append bias with the binary_add method (only used for QConv now).
  // In PyTorch, bias is in shape of [OC]; we expand its shape according to
  // the Conv dimension:
  // Conv1d [OC, 1, 1], Conv2d [1, OC, 1, 1], Conv3d [1, OC, 1, 1, 1].
  // The bias is assumed to be f32 here (see hard-coded data_type below).
  Attr& append_bias(const at::Tensor& binary, const int ndimension) {
    at::Tensor binary_ = binary.contiguous();
    dnnl::memory::desc binary_md;
    switch (ndimension) {
      case 1:
        binary_md = dnnl::memory::desc(
            {binary.size(0), 1, 1},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::abc);
        break;
      case 2:
        binary_md = dnnl::memory::desc(
            {1, binary.size(0), 1, 1},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::abcd);
        break;
      case 3:
        binary_md = dnnl::memory::desc(
            {1, binary.size(0), 1, 1, 1},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::abcde);
        break;
      default:
        TORCH_INTERNAL_ASSERT(
            0, "XPU only supports append_bias for Conv1d, Conv2d and Conv3d.");
    }
    // In this case, expected_md = binary_md (bias layout is fixed, so there is
    // no benefit to format_tag::any here).
    ops_params_.push_back(PostOpParam(
        binary_, binary_md, binary_md, kind_with_binary_add, kind_t::binary));
    return *this;
  }

  // Append a prelu post op; `mask` selects per-channel vs shared slope per
  // oneDNN's prelu post-op convention.
  Attr& append_post_prelu(int mask) {
    ops_params_.push_back(PostOpParam(mask, kind_t::prelu));
    return *this;
  }

  // Translate the recorded ops_params_ into a dnnl::post_ops object.
  // NOTE(review): `dst` is unused in this chunk — possibly consumed in a part
  // of the file not shown here; confirm before removing the parameter.
  // NOTE(review): kind_t::prelu falls into the silent default branch and is
  // never appended to dnnl_post_ops_ here — verify whether prelu is handled
  // elsewhere. Calling this twice would also append the ops twice.
  dnnl::post_ops extract_post_ops(const at::Tensor& dst) {
    for (const auto& param : ops_params_) {
      switch (param.kind_) {
        case kind_t::eltwise: {
          dnnl_post_ops_.append_eltwise(param.algo_, param.alpha_, param.beta_);
          break;
        }
        case kind_t::sum: {
          // TODO [Asymmetric]:
          // Post-sum zp for gpu is not supported currently
          dnnl_post_ops_.append_sum(param.scale_, param.zero_point_);
          break;
        }
        case kind_t::binary: {
          // The user may create the src1 memory descriptor with
          // format_tag::any or a specific tag; in the latter case a tag that
          // mismatches dst yields suboptimal performance. We therefore append
          // the expected_md (created with format_tag::any) so the fastest
          // layout can be selected when the primitive descriptor is created.
          dnnl_post_ops_.append_binary(param.algo_, param.expected_meta_);
          break;
        }
        default:
          break;
      }
    }
    return dnnl_post_ops_;
  }

  // True if any recorded post op is a sum.
  bool with_sum() const {
    for (const auto& param : ops_params_) {
      if (param.kind_ == kind_t::sum) {
        return true;
      }
    }
    return false;
  }

  // True if any recorded post op is a binary op.
  bool with_binary() const {
    for (const auto& param : ops_params_) {
      if (param.kind_ == kind_t::binary) {
        return true;
      }
    }
    return false;
  }

  // Create oneDNN memories for every binary post-op operand and insert them
  // into `args` under DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1.
  // According to the oneDNN doc, the binary tensor can be in shape of
  //   [1, 1, 1, 1]  tensor broadcast
  //   [1, C, 1, 1]  channel broadcast
  //   [dst.shape]   no broadcast, element-wise binary operations on dst
  void construct_post_binary(
      dnnl::primitive_desc& pd,
      std::unordered_map<int, dnnl::memory>& args) {
    auto& engine = GpuEngineManager::Instance().get_engine();
    // Index loop kept: i feeds the DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) macro.
    for (size_t i = 0; i < ops_params_.size(); ++i) {
      if (ops_params_[i].kind_ == kind_t::binary) {
        auto binary = ops_params_[i].binary_;
        auto md = ops_params_[i].meta_;
        // Query the expected md for peak performance.
        // NOTE(review): expected_md is queried but the memory below is built
        // from the original md — presumably a reorder is unnecessary or done
        // elsewhere; confirm this is intentional.
        auto expected_md = pd.query_md(
            dnnl::query::exec_arg_md,
            DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1);
        dnnl::memory binary_m = at::native::onednn::make_onednn_memory(
            md, engine, binary.data_ptr());
        args.insert(
            {DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1, binary_m});
      }
    }
  }

  // Scale used to quantize the fused result from fp32 to int8; only
  // meaningful for the int8 path. (Float literal for the float member.)
  float q_scale_ = 1.0f;
  int64_t q_zero_point_ = 0;
  std::vector<PostOpParam> ops_params_; // series of post ops
  dnnl::post_ops dnnl_post_ops_;
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free