Program Listing for File syntax.cc
↰ Return to documentation for file (/WorkSpace/CINN/cinn/frontend/syntax.cc)
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cinn/frontend/syntax.h"
#include <absl/types/variant.h>
#include <iomanip>
#include <memory>
#include <sstream>
#include <tuple>
#include <type_traits>
#include <utility>
#include "cinn/frontend/paddle/model_parser.h"
#include "cinn/frontend/paddle_model_to_program.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
#include "cinn/utils/string.h"
namespace cinn {
namespace frontend {
using hlir::framework::Scope;
// Looks up this instruction's op type in the global operator registry and appends one
// freshly constructed Variable to `outputs` for every output the registered operator
// declares. The CHECK fails when no operator is registered under the op type.
void Instruction::PrepareOutputs() {
  auto* node   = get();
  auto* op_def = hlir::framework::OpRegistry::Global()->Find(node->op_type);
  CHECK(op_def) << "No operator called [" << node->op_type << "]";
  // Construct a distinct Variable per slot rather than copying a single one.
  for (int i = 0; i < op_def->num_outputs; ++i) {
    node->outputs.emplace_back();
  }
}
// Creates the shared _Instruction_ payload, records the op type, the input variables and
// the `parent` program pointer on it, then creates the output variables through
// PrepareOutputs().
Instruction::Instruction(absl::string_view op_type, const std::vector<Variable>& inputs, Program* parent)
    : common::Shared<_Instruction_>(common::make_shared<_Instruction_>()) {
  auto* node           = get();
  node->inputs         = inputs;
  node->parent_program = parent;
  // The op type has to be stored before PrepareOutputs(), which reads it for the
  // registry lookup.
  node->op_type = std::string(op_type);
  PrepareOutputs();
}
// Conversion operator: yields the placeholder's `var_` member by value as a Variable.
Placeholder::operator Variable() const { return var_; }
// Appends a "conv2d" instruction with inputs (a, b) to this program.
// Every (name, value) pair in `attr_store` is copied onto the instruction as an attribute.
// Returns output 0 of the appended instruction.
Variable Program::conv2d(const Variable& a,
                         const Variable& b,
                         const absl::flat_hash_map<std::string, attr_t>& attr_store) {
  Instruction conv_instr("conv2d");
  conv_instr.SetInputs({a, b});
  for (const auto& kv : attr_store) {
    conv_instr.SetAttr(kv.first, kv.second);
  }
  AppendInstruction(conv_instr);
  return conv_instr.GetOutput(0);
}
// Appends a "layout_transform" instruction whose only input is `a`, forwarding all
// entries of `attr_store` as instruction attributes. Returns the instruction's output 0.
Variable Program::layout_transform(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
  Instruction transform_instr("layout_transform");
  transform_instr.SetInputs({a});
  for (auto it = attr_store.begin(); it != attr_store.end(); ++it) {
    transform_instr.SetAttr(it->first, it->second);
  }
  AppendInstruction(transform_instr);
  return transform_instr.GetOutput(0);
}
// Appends a "conv2d_NCHWc" instruction with inputs (a, b) to this program.
// All entries of `attr_store` are set on the instruction as attributes.
// Returns output 0 of the appended instruction.
Variable Program::conv2d_NCHWc(const Variable& a,
                               const Variable& b,
                               const absl::flat_hash_map<std::string, attr_t>& attr_store) {
  Instruction conv_instr("conv2d_NCHWc");
  conv_instr.SetInputs({a, b});
  for (const auto& name_and_value : attr_store) {
    conv_instr.SetAttr(name_and_value.first, name_and_value.second);
  }
  AppendInstruction(conv_instr);
  return conv_instr.GetOutput(0);
}
// Appends a "depthwise_conv2d" instruction with inputs (a, b), copying each entry of
// `attr_store` onto it as an attribute. Returns the instruction's output 0.
Variable Program::depthwise_conv2d(const Variable& a,
                                   const Variable& b,
                                   const absl::flat_hash_map<std::string, attr_t>& attr_store) {
  Instruction dw_instr("depthwise_conv2d");
  dw_instr.SetInputs({a, b});
  for (auto it = attr_store.begin(); it != attr_store.end(); ++it) {
    dw_instr.SetAttr(it->first, it->second);
  }
  AppendInstruction(dw_instr);
  return dw_instr.GetOutput(0);
}
// Appends a "pool2d" instruction whose only input is `a`; every entry of `attr_store`
// becomes an attribute of the instruction. Returns the instruction's output 0.
Variable Program::pool2d(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
  Instruction pool_instr("pool2d");
  pool_instr.SetInputs({a});
  for (const auto& kv : attr_store) {
    pool_instr.SetAttr(kv.first, kv.second);
  }
  AppendInstruction(pool_instr);
  return pool_instr.GetOutput(0);
}
// Appends a "batchnorm" instruction to this program.
// Inputs are passed in the order (a, scale, bias, mean, variance); all entries of
// `attr_store` are set on the instruction as attributes.
// Returns output 0 of the appended instruction.
Variable Program::batchnorm(const Variable& a,
                            const Variable& scale,
                            const Variable& bias,
                            const Variable& mean,
                            const Variable& variance,
                            const absl::flat_hash_map<std::string, attr_t>& attr_store) {
  Instruction bn_instr("batchnorm");
  bn_instr.SetInputs({a, scale, bias, mean, variance});
  for (auto it = attr_store.begin(); it != attr_store.end(); ++it) {
    bn_instr.SetAttr(it->first, it->second);
  }
  AppendInstruction(bn_instr);
  return bn_instr.GetOutput(0);
}
// Appends an input-less "const_scalar" instruction whose "value" attribute is `value`.
// The instruction's output 0 is given the id `name`, the type corresponding to PrimType,
// and is flagged as constant before being returned.
// The CHECK fails unless that type is a float, int or bool type.
template <typename PrimType>
Variable Program::primitive_const_scalar(PrimType value, const std::string& name) {
  Instruction scalar_instr("const_scalar");
  scalar_instr.SetInputs({});
  scalar_instr.SetAttr("value", value);
  AppendInstruction(scalar_instr);

  auto result = scalar_instr.GetOutput(0);
  result.set_id(name);

  auto out_type = type_of<PrimType>();
  CHECK(out_type.is_float() || out_type.is_int() || out_type.is_bool()) << "no supported type: " << out_type;
  result->type = out_type;
  result.set_const(true);
  return result;
}
// Appends a "broadcast_to" instruction whose only input is `a`, with the attributes
// "out_shape" and "broadcast_axes" taken from the arguments of the same name.
// Returns output 0 of the appended instruction.
Variable Program::primitive_broadcast_to(const Variable& a,
                                         const std::vector<int>& out_shape,
                                         const std::vector<int>& broadcast_axes) {
  Instruction broadcast_instr("broadcast_to");
  broadcast_instr.SetInputs({a});
  broadcast_instr.SetAttr("out_shape", out_shape);
  broadcast_instr.SetAttr("broadcast_axes", broadcast_axes);
  AppendInstruction(broadcast_instr);
  return broadcast_instr.GetOutput(0);
}
// Expands inference-time batch normalization into primitive instructions:
//   new_scale  = rsqrt(variance + epsilon) * scale
//   shift_bias = new_scale * (-mean) + bias
//   result     = broadcast(new_scale) * a + broadcast(shift_bias)
// `epsilon` is read from attr_store["epsilon"] (which must hold a float) and defaults to
// 1e-5 when the key is absent. The CHECKs fail when `scale` or `a` has an empty shape.
// Returns the output variable of the final add instruction.
Variable Program::fused_meta_batchnorm_inference(const Variable& a,
                                                 const Variable& scale,
                                                 const Variable& bias,
                                                 const Variable& mean,
                                                 const Variable& variance,
                                                 const absl::flat_hash_map<std::string, attr_t>& attr_store) {
  float epsilon = 0.00001f;
  // Single lookup: reuse the iterator instead of calling find() and then at().
  const auto eps_it = attr_store.find("epsilon");
  if (eps_it != attr_store.end()) {
    epsilon = absl::get<float>(eps_it->second);
  }
  auto eps_var = primitive_const_scalar<float>(epsilon, common::UniqName("epsilon"));
  CHECK(!scale->shape.empty()) << "scale's shape is empty.";
  // Expand the scalar epsilon to scale's shape (broadcast axes {0}) so it can be added to
  // the variance.
  auto broadcast_eps = primitive_broadcast_to(eps_var, scale->shape, {0});
  auto var_add_eps   = add(variance, broadcast_eps);
  auto rsqrt_var     = primitive_rsqrt(var_add_eps);
  auto new_scale     = multiply(rsqrt_var, scale);
  auto neg_mean      = primitive_negative(mean);
  auto new_shift     = multiply(new_scale, neg_mean);
  auto shift_bias    = add(new_shift, bias);
  CHECK(!a->shape.empty()) << "variable a's shape is empty.";
  // Expand the per-statistic factors to a's shape with broadcast axes {1}.
  // NOTE(review): presumably axis 1 is the channel axis of `a` - confirm against callers.
  auto broadcast_new_scale  = primitive_broadcast_to(new_scale, a->shape, {1});
  auto broadcast_shift_bias = primitive_broadcast_to(shift_bias, a->shape, {1});
  auto temp_out             = multiply(broadcast_new_scale, a);
  auto bn_out               = add(temp_out, broadcast_shift_bias);
  return bn_out;
}
// Expands inference-time batch normalization into elementwise instructions:
//   new_scale  = rsqrt(variance + epsilon) * scale
//   shift_bias = new_scale * (-mean) + bias
//   result     = a * new_scale + shift_bias      (both with the "axis" attribute set to 1)
// `epsilon` is read from attr_store["epsilon"] (which must hold a float) and defaults to
// 1e-5 when the key is absent. The CHECK fails when `scale` has an empty shape.
// Returns the output variable of the final elementwise_add instruction.
Variable Program::fused_batchnorm_inference(const Variable& a,
                                            const Variable& scale,
                                            const Variable& bias,
                                            const Variable& mean,
                                            const Variable& variance,
                                            const absl::flat_hash_map<std::string, attr_t>& attr_store) {
  float epsilon = 0.00001f;
  // Single lookup: reuse the iterator instead of calling find() and then at().
  const auto eps_it = attr_store.find("epsilon");
  if (eps_it != attr_store.end()) {
    epsilon = absl::get<float>(eps_it->second);
  }
  auto eps_var = primitive_const_scalar<float>(epsilon, common::UniqName("epsilon"));
  CHECK(!scale->shape.empty()) << "scale's shape is empty.";
  auto var_add_eps = elementwise_add(variance, eps_var);
  auto rsqrt_var   = primitive_rsqrt(var_add_eps);
  auto new_scale   = elementwise_mul(rsqrt_var, scale);
  auto neg_mean    = primitive_negative(mean);
  auto new_shift   = elementwise_mul(new_scale, neg_mean);
  auto shift_bias  = elementwise_add(new_shift, bias);
  // The last two instructions combine `a` with the per-statistic factors using axis 1.
  // NOTE(review): presumably axis 1 is the channel axis of `a` - confirm against callers.
  auto temp_out = elementwise_mul(a, new_scale, 1);
  auto bn_out   = elementwise_add(temp_out, shift_bias, 1);
  return bn_out;
}
// Appends a "scale" instruction over the single input `a`, forwarding all entries of
// `attr_store` as attributes. Returns the instruction's output 0.
Variable Program::scale(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
  const std::vector<Variable> operands{a};
  Instruction scale_instr("scale", operands);
  for (const auto& kv : attr_store) {
    scale_instr.SetAttr(kv.first, kv.second);
  }
  AppendInstruction(scale_instr);
  return scale_instr.GetOutput(0);
}
// Appends a "softmax" instruction over the single input `a`; each entry of `attr_store`
// is set on the instruction as an attribute. Returns the instruction's output 0.
Variable Program::softmax(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
  Instruction softmax_instr("softmax", {a});
  for (auto it = attr_store.begin(); it != attr_store.end(); ++it) {
    softmax_instr.SetAttr(it->first, it->second);
  }
  AppendInstruction(softmax_instr);
  return softmax_instr.GetOutput(0);
}
// Appends a "sigmoid" instruction over the single input `a` and returns its output 0.
Variable Program::sigmoid(const Variable& a) {
  const std::vector<Variable> operands{a};
  Instruction sigmoid_instr("sigmoid", operands);
  AppendInstruction(sigmoid_instr);
  return sigmoid_instr.GetOutput(0);
}
// Appends a "slice" instruction over the single input `a`, copying every entry of
// `attr_store` onto it as an attribute. Returns the instruction's output 0.
Variable Program::slice(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
  Instruction slice_instr("slice", {a});
  for (const auto& name_and_value : attr_store) {
    slice_instr.SetAttr(name_and_value.first, name_and_value.second);
  }
  AppendInstruction(slice_instr);
  return slice_instr.GetOutput(0);
}
// Appends a "dropout_infer" instruction over the single input `a`; all entries of
// `attr_store` become attributes of the instruction. Returns the instruction's output 0.
Variable Program::dropout_infer(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
  const std::vector<Variable> operands{a};
  Instruction dropout_instr("dropout_infer", operands);
  for (auto it = attr_store.begin(); it != attr_store.end(); ++it) {
    dropout_instr.SetAttr(it->first, it->second);
  }
  AppendInstruction(dropout_instr);
  return dropout_instr.GetOutput(0);
}
// Returns a mutable reference to the i-th instruction stored in `instrs_`.
// The CHECK_LT fails when `i` is not smaller than the number of stored instructions.
Instruction& Program::operator[](size_t i) {
CHECK_LT(i, instrs_.size());
return instrs_[i];
}
// Returns a read-only reference to the i-th instruction stored in `instrs_`.
// The CHECK_LT fails when `i` is not smaller than the number of stored instructions.
const Instruction& Program::operator[](size_t i) const {
CHECK_LT(i, instrs_.size());
return instrs_[i];
}
// Streams a variable as "Var(<id>)" and returns the stream.
std::ostream& operator<<(std::ostream& os, const Variable& x) {
  return os << "Var(" << x->id << ")";
}
// Streams an instruction using the text produced by its debug_string() and returns the
// stream.
std::ostream& operator<<(std::ostream& os, const Instruction& instr) {
  return os << instr->debug_string();
}
// Loads the Paddle model stored at `model_dir` through a PaddleModelToProgram converter
// constructed from `scope` and `target`.
// Returns a tuple of: the converted Program, the converter's var_map(), its
// var_model_to_program_map(), and its fetch_names().
std::tuple<std::unique_ptr<Program>,
           absl::flat_hash_map<std::string, Variable>,
           absl::flat_hash_map<std::string, std::string>,
           absl::flat_hash_set<std::string>>
LoadPaddleProgram(const std::string& model_dir, Scope* scope, bool is_combined, const common::Target& target) {
  VLOG(1) << "Loading Paddle model from " << model_dir;
  PaddleModelToProgram program(scope, target);
  // Run the conversion in its own statement: the order in which function arguments are
  // evaluated is unspecified, so invoking the converter inside the make_tuple() argument
  // list would not guarantee that it runs before the accessors below are evaluated.
  auto loaded_program = program(model_dir, is_combined);
  return std::make_tuple(
      std::move(loaded_program), program.var_map(), program.var_model_to_program_map(), program.fetch_names());
}
void Program::SetInputs(const std::vector<Variable>& xs) {
CHECK(!xs.empty()) << "At least one input is needed for a program!";
for (int i = 0; i < xs.size(); i++) {
CHECK(!xs[i]->shape.empty()) << "Found " << i << "-th input's shape is not set yet";
CHECK(!xs[i]->type.is_unk()) << "Found " << i << "-th input's type is not set yet";
inputs_.push_back(xs[i]);
}
}
// Sanity-checks the program: the CHECKs fail when no input has been registered in
// `inputs_` or when `instrs_` holds no instruction.
void Program::Validate() const {
CHECK(!inputs_.empty()) << "Inputs of the program is not set yet";
CHECK(!instrs_.empty()) << "No instruction is added yet";
}
// Appends an "elementwise_add" instruction over (a, b) without setting any attribute and
// returns its output 0.
Variable Program::add(const Variable& a, const Variable& b) {
  const std::vector<Variable> operands{a, b};
  Instruction add_instr("elementwise_add", operands);
  AppendInstruction(add_instr);
  return add_instr.GetOutput(0);
}
// Appends an "elementwise_mul" instruction over (a, b) without setting any attribute and
// returns its output 0.
Variable Program::multiply(const Variable& a, const Variable& b) {
  const std::vector<Variable> operands{a, b};
  Instruction mul_instr("elementwise_mul", operands);
  AppendInstruction(mul_instr);
  return mul_instr.GetOutput(0);
}
// Defines Program::primitive_<name>(a): appends an instruction whose op type is the
// stringified <name> with the single input `a` and returns the instruction's output 0.
// The expansions below carry no trailing ';' because each one is a complete function
// definition; the macro is undefined once the list ends so it does not leak further.
#define SYNTAX_PRIM_UNARY_IMPL(name__)                      \
  Variable Program::primitive_##name__(const Variable& a) { \
    Instruction instr(#name__, {a});                        \
    AppendInstruction(instr);                               \
    return instr.GetOutput(0);                              \
  }
SYNTAX_PRIM_UNARY_IMPL(exp)
SYNTAX_PRIM_UNARY_IMPL(erf)
SYNTAX_PRIM_UNARY_IMPL(sqrt)
SYNTAX_PRIM_UNARY_IMPL(log)
SYNTAX_PRIM_UNARY_IMPL(floor)
SYNTAX_PRIM_UNARY_IMPL(ceil)
SYNTAX_PRIM_UNARY_IMPL(round)
SYNTAX_PRIM_UNARY_IMPL(tanh)
SYNTAX_PRIM_UNARY_IMPL(log2)
SYNTAX_PRIM_UNARY_IMPL(log10)
SYNTAX_PRIM_UNARY_IMPL(trunc)
SYNTAX_PRIM_UNARY_IMPL(cos)
SYNTAX_PRIM_UNARY_IMPL(sin)
SYNTAX_PRIM_UNARY_IMPL(cosh)
SYNTAX_PRIM_UNARY_IMPL(tan)
SYNTAX_PRIM_UNARY_IMPL(sinh)
SYNTAX_PRIM_UNARY_IMPL(acos)
SYNTAX_PRIM_UNARY_IMPL(acosh)
SYNTAX_PRIM_UNARY_IMPL(asin)
SYNTAX_PRIM_UNARY_IMPL(asinh)
SYNTAX_PRIM_UNARY_IMPL(atan)
SYNTAX_PRIM_UNARY_IMPL(atanh)
SYNTAX_PRIM_UNARY_IMPL(isnan)
SYNTAX_PRIM_UNARY_IMPL(isfinite)
SYNTAX_PRIM_UNARY_IMPL(isinf)
SYNTAX_PRIM_UNARY_IMPL(bitwise_not)
SYNTAX_PRIM_UNARY_IMPL(negative)
SYNTAX_PRIM_UNARY_IMPL(identity)
SYNTAX_PRIM_UNARY_IMPL(logical_not)
SYNTAX_PRIM_UNARY_IMPL(sign)
SYNTAX_PRIM_UNARY_IMPL(abs)
SYNTAX_PRIM_UNARY_IMPL(rsqrt)
#undef SYNTAX_PRIM_UNARY_IMPL
// Defines Program::primitive_<name>(a, b): appends an instruction whose op type is the
// stringified <name> with inputs (a, b) and returns the instruction's output 0.
// The macro is undefined once the list ends so it does not leak further.
#define SYNTAX_PRIM_BINARY_IMPL(name__)                                        \
  Variable Program::primitive_##name__(const Variable& a, const Variable& b) { \
    Instruction instr(#name__, {a, b});                                        \
    AppendInstruction(instr);                                                  \
    return instr.GetOutput(0);                                                 \
  }
SYNTAX_PRIM_BINARY_IMPL(substract)
SYNTAX_PRIM_BINARY_IMPL(divide)
SYNTAX_PRIM_BINARY_IMPL(floor_divide)
SYNTAX_PRIM_BINARY_IMPL(mod)
SYNTAX_PRIM_BINARY_IMPL(floor_mod)
SYNTAX_PRIM_BINARY_IMPL(max)
SYNTAX_PRIM_BINARY_IMPL(min)
SYNTAX_PRIM_BINARY_IMPL(power)
SYNTAX_PRIM_BINARY_IMPL(logical_and)
SYNTAX_PRIM_BINARY_IMPL(logical_or)
SYNTAX_PRIM_BINARY_IMPL(logical_xor)
SYNTAX_PRIM_BINARY_IMPL(greater)
SYNTAX_PRIM_BINARY_IMPL(less)
SYNTAX_PRIM_BINARY_IMPL(equal)
SYNTAX_PRIM_BINARY_IMPL(not_equal)
SYNTAX_PRIM_BINARY_IMPL(greater_equal)
SYNTAX_PRIM_BINARY_IMPL(less_equal)
SYNTAX_PRIM_BINARY_IMPL(bitwise_or)
SYNTAX_PRIM_BINARY_IMPL(bitwise_xor)
SYNTAX_PRIM_BINARY_IMPL(bitwise_and)
SYNTAX_PRIM_BINARY_IMPL(left_shift)
SYNTAX_PRIM_BINARY_IMPL(right_shift)
#undef SYNTAX_PRIM_BINARY_IMPL
// Appends an "elementwise_add" instruction over (a, b) carrying the integer attribute
// "axis", and returns the instruction's output 0.
Variable Program::elementwise_add(const Variable& a, const Variable& b, int axis) {
  const std::vector<Variable> operands{a, b};
  Instruction add_instr("elementwise_add", operands);
  add_instr.SetAttr("axis", axis);
  AppendInstruction(add_instr);
  return add_instr.GetOutput(0);
}
// Appends an "elementwise_mul" instruction over (a, b) carrying the integer attribute
// "axis", and returns the instruction's output 0.
Variable Program::elementwise_mul(const Variable& a, const Variable& b, int axis) {
  const std::vector<Variable> operands{a, b};
  Instruction mul_instr("elementwise_mul", operands);
  mul_instr.SetAttr("axis", axis);
  AppendInstruction(mul_instr);
  return mul_instr.GetOutput(0);
}
// Appends a "divide" instruction over (a, b) carrying the integer attribute "axis", and
// returns the instruction's output 0.
Variable Program::elementwise_div(const Variable& a, const Variable& b, int axis) {
  const std::vector<Variable> operands{a, b};
  Instruction div_instr("divide", operands);
  div_instr.SetAttr("axis", axis);
  AppendInstruction(div_instr);
  return div_instr.GetOutput(0);
}
// Appends a "substract" instruction (op name spelled as in the primitive list) over
// (a, b) carrying the integer attribute "axis", and returns the instruction's output 0.
Variable Program::elementwise_sub(const Variable& a, const Variable& b, int axis) {
  const std::vector<Variable> operands{a, b};
  Instruction sub_instr("substract", operands);
  sub_instr.SetAttr("axis", axis);
  AppendInstruction(sub_instr);
  return sub_instr.GetOutput(0);
}
// Defines Program::reduce_<name>(a, dim, keep_dim): appends a "reduce_<name>" instruction
// over the single input `a` with the attributes "dim" and "keep_dim" taken from the
// arguments of the same name, and returns the instruction's output 0.
// The macro is undefined once the list ends so it does not leak further.
#define SYNTAX_PRIM_REDUCE_IMPL(name__)                                                              \
  Variable Program::reduce_##name__(const Variable& a, const std::vector<int>& dim, bool keep_dim) { \
    Instruction instr("reduce_" #name__, {a});                                                       \
    instr.SetAttr("dim", dim);                                                                       \
    instr.SetAttr("keep_dim", keep_dim);                                                             \
    AppendInstruction(instr);                                                                        \
    return instr.GetOutput(0);                                                                       \
  }
SYNTAX_PRIM_REDUCE_IMPL(sum)
SYNTAX_PRIM_REDUCE_IMPL(prod)
SYNTAX_PRIM_REDUCE_IMPL(min)
SYNTAX_PRIM_REDUCE_IMPL(max)
#undef SYNTAX_PRIM_REDUCE_IMPL
// Appends an "identity" instruction over the single input `a` and returns its output 0.
Variable Program::assign(const Variable& a) {
  const std::vector<Variable> operands{a};
  Instruction identity_instr("identity", operands);
  AppendInstruction(identity_instr);
  return identity_instr.GetOutput(0);
}
// Appends a "relu" instruction over the single input `a` and returns its output 0.
Variable Program::relu(const Variable& a) {
  const std::vector<Variable> operands{a};
  Instruction relu_instr("relu", operands);
  AppendInstruction(relu_instr);
  return relu_instr.GetOutput(0);
}
// Appends a "relu6" instruction over the single input `a` and returns its output 0.
Variable Program::relu6(const Variable& a) {
  const std::vector<Variable> operands{a};
  Instruction relu6_instr("relu6", operands);
  AppendInstruction(relu6_instr);
  return relu6_instr.GetOutput(0);
}
// Appends a "mul" instruction over (a, b) with the integer attributes "x_num_col_dims"
// and "y_num_col_dims" taken from the arguments of the same name.
// Returns the instruction's output 0.
Variable Program::mul(const Variable& a, const Variable& b, int x_num_col_dims, int y_num_col_dims) {
  const std::vector<Variable> operands{a, b};
  Instruction mul_instr("mul", operands);
  mul_instr.SetAttr("x_num_col_dims", x_num_col_dims);
  mul_instr.SetAttr("y_num_col_dims", y_num_col_dims);
  AppendInstruction(mul_instr);
  return mul_instr.GetOutput(0);
}
// Appends a "matmul" instruction over (a, b) with the attributes "trans_a", "trans_b"
// and "alpha" taken from the arguments of the same name.
// Returns the instruction's output 0.
Variable Program::matmul(const Variable& a, const Variable& b, bool trans_a, bool trans_b, float alpha) {
  const std::vector<Variable> operands{a, b};
  Instruction matmul_instr("matmul", operands);
  matmul_instr.SetAttr("trans_a", trans_a);
  matmul_instr.SetAttr("trans_b", trans_b);
  matmul_instr.SetAttr("alpha", alpha);
  AppendInstruction(matmul_instr);
  return matmul_instr.GetOutput(0);
}
// Appends a "reshape" instruction over the single input `a` with the attribute "shape"
// set to `shape`, and returns the instruction's output 0.
Variable Program::reshape(const Variable& a, const std::vector<int>& shape) {
  const std::vector<Variable> operands{a};
  Instruction reshape_instr("reshape", operands);
  reshape_instr.SetAttr("shape", shape);
  AppendInstruction(reshape_instr);
  return reshape_instr.GetOutput(0);
}
// Appends a "concat" instruction whose inputs are all of `input_vars`, with the integer
// attribute "axis", and returns the instruction's output 0.
Variable Program::concat(const std::vector<Variable>& input_vars, int axis) {
  Instruction concat_instr("concat", input_vars);
  concat_instr.SetAttr("axis", axis);
  AppendInstruction(concat_instr);
  return concat_instr.GetOutput(0);
}
// Appends a "transpose" instruction over the single input `a` with the attribute "axis"
// set to the vector `axis`, and returns the instruction's output 0.
Variable Program::transpose(const Variable& a, const std::vector<int>& axis) {
  const std::vector<Variable> operands{a};
  Instruction transpose_instr("transpose", operands);
  transpose_instr.SetAttr("axis", axis);
  AppendInstruction(transpose_instr);
  return transpose_instr.GetOutput(0);
}
// Appends a "mulbias" instruction over (a, b, c) with the integer attributes
// "x_num_col_dims" and "y_num_col_dims" taken from the arguments of the same name.
// Returns the instruction's output at index 1.
// NOTE(review): the sibling builders in this file return output 0; presumably the
// "mulbias" operator registers at least two outputs and index 1 is the final result -
// confirm against the operator registration.
Variable Program::mulbias(
    const Variable& a, const Variable& b, const Variable& c, int x_num_col_dims, int y_num_col_dims) {
  const std::vector<Variable> operands{a, b, c};
  Instruction mulbias_instr("mulbias", operands);
  mulbias_instr.SetAttr("x_num_col_dims", x_num_col_dims);
  mulbias_instr.SetAttr("y_num_col_dims", y_num_col_dims);
  AppendInstruction(mulbias_instr);
  return mulbias_instr.GetOutput(1);
}
std::string _Instruction_::debug_string() const {
struct Visit {
std::stringstream& s_;
explicit Visit(std::stringstream& s) : s_(s) {}
void operator()(int x) { s_ << x; }
void operator()(float x) { s_ << x; }
void operator()(bool x) { s_ << (x ? "true" : "false"); }
void operator()(const std::string& x) { s_ << x; }
void operator()(const std::vector<int>& x) { s_ << "[" + utils::Join(x, ",") + "]"; }
void operator()(const std::vector<float>& x) { s_ << "[" + utils::Join(x, ",") + "]"; }
void operator()(const std::vector<bool>& x) { s_ << "[" + utils::Join(x, ",") + "]"; }
void operator()(const std::vector<std::string>& x) { s_ << "[" + utils::Join(x, ",") + "]"; }
};
std::stringstream ss;
std::vector<std::string> input_names, output_names;
std::transform(
inputs.begin(), inputs.end(), std::back_inserter(input_names), [](const Variable& x) { return x->id; });
std::transform(
outputs.begin(), outputs.end(), std::back_inserter(output_names), [](const Variable& x) { return x->id; });
ss << utils::Join(output_names, ", ");
ss << " = ";
ss << op_type;
ss << "(";
ss << utils::Join(input_names, ", ");
if (!attrs.empty() && !input_names.empty()) ss << ", ";
std::vector<std::string> attr_strs;
for (auto& attr : attrs) {
std::stringstream iss;
iss << attr.first << "=";
absl::visit(Visit{iss}, attr.second);
attr_strs.push_back(iss.str());
}
ss << utils::Join(attr_strs, ", ");
ss << ")";
return ss.str();
}
std::ostream& operator<<(std::ostream& os, const Program& program) {
os << "Program {\n";
for (int i = 0; i < program.size(); i++) {
os << program[i] << "\n";
}
os << "}\n";
return os;
}
} // namespace frontend
} // namespace cinn