Addmm Class — PyTorch Architecture
Architecture documentation for the `Addmm` class in `vulkan_api_test.cpp` from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/test/vulkan_api_test.cpp lines 7173–7205
// Test operator wrapping at::addmm.
//
// run(t) returns at::addmm(b_, t, m2_, beta_, alpha_), i.e.
// beta * b_ + alpha * (t @ m2_) per PyTorch's addmm definition.
// The operands are random float tensors allocated on the CPU once, at
// construction:
//   m2_ : shape {m1W, m2W}  (right-hand matrix of the product)
//   b_  : shape {m1H, m2W}  (bias added to the scaled product)
// Every call to run() therefore reuses the same m2_ and b_.
class Addmm final : public BaseOp {
 public:
  // m1H  - rows of the bias b_ (and, presumably, of the input passed to run()).
  // m1W  - rows of m2_; presumably the input to run() is expected to be
  //        {m1H, m1W} so the product is well-formed — TODO confirm at callers.
  // m2W  - columns of both m2_ and b_.
  // beta - scale applied to the bias b_.
  // alpha - scale applied to the matrix product t @ m2_.
  Addmm(
      const int64_t m1H,
      const int64_t m1W,
      const int64_t m2W,
      const float beta,
      const float alpha)
      : BaseOp(OpType::addmm),
        m2_(at::rand(
            c10::IntArrayRef({m1W, m2W}),
            at::device(at::kCPU).dtype(at::kFloat))),
        b_(at::rand(
            c10::IntArrayRef({m1H, m2W}),
            at::device(at::kCPU).dtype(at::kFloat))),
        beta_(beta),
        alpha_(alpha) {}

  // Applies addmm to t using the stored operands.
  // The same call is made whether t is a Vulkan or a CPU tensor: b_ and m2_
  // are the CPU tensors created in the constructor in both cases, so no
  // branching on t.is_vulkan() is needed.
  at::Tensor run(at::Tensor& t) const override {
    return at::addmm(b_, t, m2_, beta_, alpha_);
  }

  // Human-readable op name.
  std::string toString() const override {
    return "addmm";
  }

 private:
  // Declaration order matches the constructor's initializer list
  // (m2_ before b_), so members are initialized in the order written.
  at::Tensor m2_;  // {m1W, m2W}, CPU float
  at::Tensor b_;   // {m1H, m2W}, CPU float
  float beta_;     // bias scale
  float alpha_;    // product scale
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free