Program Listing for File syntax.h
↰ Return to documentation for file (/WorkSpace/CINN/cinn/frontend/syntax.h)
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/container/flat_hash_map.h>
#include <absl/container/flat_hash_set.h>
#include <absl/strings/string_view.h>
#include <glog/logging.h>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "cinn/common/common.h"
#include "cinn/common/context.h"
#include "cinn/common/object.h"
#include "cinn/common/type.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/scope.h"
namespace cinn {
namespace frontend {
struct Program;
struct Variable;
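// Reference-counted payload behind a Variable: a unique id, the element
// type, and the static shape of one value in a Program.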
struct _Variable_ : public common::Object {
std::string id;
common::Type type;
std::vector<int> shape;
bool is_const = false;
const char* type_info() const override { return __type_info__; }
static constexpr const char* __type_info__ = "cinn_frontend_variable";
};
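// Shared handle to a _Variable_. An empty id_hint yields an auto-generated
// name (via common::Context::Global().NewName("var")).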
struct Variable : public common::Shared<_Variable_> {
explicit Variable(const std::string& id_hint = "") : common::Shared<_Variable_>(common::make_shared<_Variable_>()) {
if (!id_hint.empty()) CheckVarNameValid(id_hint);
get()->id = id_hint.empty() ? common::Context::Global().NewName("var") : id_hint;
}
void set_id(const std::string& id) { operator->()->id = id; }
void set_const(bool is_const) { operator->()->is_const = is_const; }
bool is_const() { return operator->()->is_const; }
_Variable_* operator->() { return get(); }
const _Variable_* operator->() const { return get(); }
};
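// A typed, shaped input slot of a Program. Construction allocates a fresh
// Variable; the conversion operator below lets a Placeholder be passed
// wherever a Variable is expected.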
class Placeholder {
public:
Placeholder(const common::Type& type,
const std::vector<int>& shape,
absl::string_view id_hint = "",
bool is_const = false) {
if (!id_hint.empty()) CheckVarNameValid(std::string(id_hint));
id_ = id_hint.empty() ? common::Context::Global().NewName("placeholder") : (std::string)id_hint;
var_ = Variable(id_);
var_->shape = shape;
var_->type = type;
var_->is_const = is_const;
}
explicit Placeholder(const Variable& var) {
id_ = var->id;
var_ = var;
}
const std::vector<int>& shape() const { return var_->shape; }
Type type() const { return var_->type; }
absl::string_view id() const { return id_; }
operator Variable() const;
void set_const(bool is_const) { var_->is_const = is_const; }
bool is_const() { return var_->is_const; }
Program* parent_program() { return parent_program_; }
private:
Variable var_;
std::string id_{};
Program* parent_program_{};
};
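// Reference-counted payload behind an Instruction: the op name, its
// attributes (hashed and, separately, in insertion order), the input and
// output variables, and a back-pointer to the owning Program.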
struct _Instruction_ : public common::Object {
using attr_t = hlir::framework::AttrType;
std::string op_type;
absl::flat_hash_map<std::string, attr_t> attrs;
std::vector<std::pair<std::string, attr_t>> attrs_ordered;
std::vector<Variable> inputs;
std::vector<Variable> outputs;
Program* parent_program{};
const char* type_info() const override { return __type_info__; }
std::string debug_string() const;
static constexpr const char* __type_info__ = "cinn_frontend_instruction";
};
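// Shared handle to an _Instruction_; outputs are generated from the op's
// declaration via PrepareOutputs, and attributes are stored as
// hlir::framework::AttrType values keyed by name.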
struct Instruction : public common::Shared<_Instruction_> {
explicit Instruction(absl::string_view op_type, const std::vector<Variable>& inputs = {}, Program* parent = nullptr);
void SetInputs(const std::vector<Variable>& vars) { get()->inputs = vars; }
const std::vector<Variable>& GetOutputs() const { return get()->outputs; }
const Variable& GetOutput(size_t offset) const {
CHECK_LT(offset, get()->outputs.size());
return GetOutputs()[offset];
}
template <typename T>
void SetAttr(const std::string& key, const T& v) {
get()->attrs[key] = v;
}
template <typename T>
T GetAttrs(const std::string& key) const {
auto it = get()->attrs.find(key);
CHECK(it != get()->attrs.end()) << "No attribute called [" << key << "]";
return absl::get<T>(it->second);
}
private:
// Generate outputs according to op's declaration.
void PrepareOutputs();
};
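// A Program is an ordered list of Instructions plus its input variables.
// Minimal usage sketch (assuming common::Float(32) from cinn/common/type.h):
//   Program prog;
//   Placeholder a(common::Float(32), {32, 16}, "a");
//   Placeholder b(common::Float(32), {32, 16}, "b");
//   auto c = prog.add(a, b);  // appends an "add" instruction, returns its output
//   prog.Validate();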
struct Program {
using attr_t = hlir::framework::NodeAttr::attr_t;
Program() = default;
Program(std::vector<Instruction>&& instrs, std::vector<Variable>&& inputs)
: instrs_(std::move(instrs)), inputs_(std::move(inputs)) {}
void SetInputs(const std::vector<Variable>& xs);
const std::vector<Variable>& GetInputs() const { return inputs_; }
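// Creates a scalar constant; `name` presumably becomes the id of the
// resulting output variable.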
template <typename PrimType>
Variable primitive_const_scalar(PrimType value, const std::string& name);
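// Builds a "fill_constant" instruction. A non-empty str_value takes
// precedence over float_value and may be "inf", "-inf", "nan", or a number
// parsed as int64_t or double depending on PrimType.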
template <typename PrimType>
Variable fill_constant(const std::vector<int>& shape,
float float_value,
const std::string& str_value,
bool force_cpu,
const std::string& name) {
Instruction instr("fill_constant");
PrimType value;
if (str_value.empty()) {
value = static_cast<PrimType>(float_value);
} else {
if (str_value == "inf") {
value = static_cast<PrimType>(std::numeric_limits<double>::infinity());
} else if (str_value == "-inf") {
value = static_cast<PrimType>(-std::numeric_limits<double>::infinity());
} else if (str_value == "nan") {
value = static_cast<PrimType>(std::numeric_limits<double>::quiet_NaN());
} else {
std::stringstream convert_stream(str_value);
if (std::is_same<int64_t, PrimType>::value) {
int64_t tmp_value;
convert_stream >> tmp_value;
value = static_cast<PrimType>(tmp_value);
} else {
double tmp_value;
convert_stream >> tmp_value;
value = static_cast<PrimType>(tmp_value);
}
}
}
instr.SetInputs({});
instr.SetAttr("shape", shape);
instr.SetAttr("value", value);
instr.SetAttr("force_cpu", force_cpu);
AppendInstruction(instr);
auto out = instr.GetOutput(0);
out.set_id(name);
return out;
}
Variable add(const Variable& a, const Variable& b);
Variable multiply(const Variable& a, const Variable& b);
Variable mul(const Variable& a, const Variable& b, int x_num_col_dims = 1, int y_num_col_dims = 1);
Variable matmul(const Variable& a, const Variable& b, bool trans_a = false, bool trans_b = false, float alpha = 1);
Variable reshape(const Variable& a, const std::vector<int>& shape);
Variable concat(const std::vector<Variable>& input_vars, int axis = 0);
Variable transpose(const Variable& input_vars, const std::vector<int>& axis);
Variable mulbias(
const Variable& a, const Variable& b, const Variable& c, int x_num_col_dims = 1, int y_num_col_dims = 1);
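// Each SYNTAX_PRIM_UNARY_DECL(name) below declares a unary elementwise
// primitive of the form: Variable primitive_<name>(const Variable& a);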
#define SYNTAX_PRIM_UNARY_DECL(name__) Variable primitive_##name__(const Variable& a);
SYNTAX_PRIM_UNARY_DECL(exp);
SYNTAX_PRIM_UNARY_DECL(erf);
SYNTAX_PRIM_UNARY_DECL(sqrt);
SYNTAX_PRIM_UNARY_DECL(log);
SYNTAX_PRIM_UNARY_DECL(floor);
SYNTAX_PRIM_UNARY_DECL(ceil);
SYNTAX_PRIM_UNARY_DECL(round);
SYNTAX_PRIM_UNARY_DECL(tanh);
SYNTAX_PRIM_UNARY_DECL(log2);
SYNTAX_PRIM_UNARY_DECL(log10);
SYNTAX_PRIM_UNARY_DECL(trunc);
SYNTAX_PRIM_UNARY_DECL(cos);
SYNTAX_PRIM_UNARY_DECL(sin);
SYNTAX_PRIM_UNARY_DECL(cosh);
SYNTAX_PRIM_UNARY_DECL(tan);
SYNTAX_PRIM_UNARY_DECL(sinh);
SYNTAX_PRIM_UNARY_DECL(acos);
SYNTAX_PRIM_UNARY_DECL(acosh);
SYNTAX_PRIM_UNARY_DECL(asin);
SYNTAX_PRIM_UNARY_DECL(asinh);
SYNTAX_PRIM_UNARY_DECL(atan);
SYNTAX_PRIM_UNARY_DECL(atanh);
SYNTAX_PRIM_UNARY_DECL(isnan);
SYNTAX_PRIM_UNARY_DECL(isfinite);
SYNTAX_PRIM_UNARY_DECL(isinf);
SYNTAX_PRIM_UNARY_DECL(bitwise_not);
SYNTAX_PRIM_UNARY_DECL(negative);
SYNTAX_PRIM_UNARY_DECL(identity);
SYNTAX_PRIM_UNARY_DECL(logical_not);
SYNTAX_PRIM_UNARY_DECL(sign);
SYNTAX_PRIM_UNARY_DECL(abs);
SYNTAX_PRIM_UNARY_DECL(rsqrt);
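// Each SYNTAX_PRIM_BINARY_DECL(name) below declares a binary elementwise
// primitive of the form:
//   Variable primitive_<name>(const Variable& a, const Variable& b);
// (note the "substract" spelling of the subtraction primitive).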
#define SYNTAX_PRIM_BINARY_DECL(name__) Variable primitive_##name__(const Variable& a, const Variable& b);
SYNTAX_PRIM_BINARY_DECL(substract)
SYNTAX_PRIM_BINARY_DECL(divide)
SYNTAX_PRIM_BINARY_DECL(floor_divide)
SYNTAX_PRIM_BINARY_DECL(mod)
SYNTAX_PRIM_BINARY_DECL(floor_mod)
SYNTAX_PRIM_BINARY_DECL(max)
SYNTAX_PRIM_BINARY_DECL(min)
SYNTAX_PRIM_BINARY_DECL(power)
SYNTAX_PRIM_BINARY_DECL(logical_and)
SYNTAX_PRIM_BINARY_DECL(logical_or)
SYNTAX_PRIM_BINARY_DECL(logical_xor)
SYNTAX_PRIM_BINARY_DECL(greater)
SYNTAX_PRIM_BINARY_DECL(less)
SYNTAX_PRIM_BINARY_DECL(equal)
SYNTAX_PRIM_BINARY_DECL(not_equal)
SYNTAX_PRIM_BINARY_DECL(greater_equal)
SYNTAX_PRIM_BINARY_DECL(less_equal)
SYNTAX_PRIM_BINARY_DECL(bitwise_or)
SYNTAX_PRIM_BINARY_DECL(bitwise_xor)
SYNTAX_PRIM_BINARY_DECL(bitwise_and)
SYNTAX_PRIM_BINARY_DECL(left_shift)
SYNTAX_PRIM_BINARY_DECL(right_shift)
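// Each SYNTAX_PRIM_REDUCE_DECL(name) below declares a reduction of the form:
//   Variable reduce_<name>(const Variable& a, const std::vector<int>& dim,
//                          bool keep_dim = false);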
#define SYNTAX_PRIM_REDUCE_DECL(name__) \
Variable reduce_##name__(const Variable& a, const std::vector<int>& dim, bool keep_dim = false);
SYNTAX_PRIM_REDUCE_DECL(sum)
SYNTAX_PRIM_REDUCE_DECL(prod)
SYNTAX_PRIM_REDUCE_DECL(min)
SYNTAX_PRIM_REDUCE_DECL(max)
Variable primitive_broadcast_to(const Variable& a,
const std::vector<int>& out_shape,
const std::vector<int>& broadcast_axes);
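// Elementwise binary ops with broadcasting; axis = -1 presumably follows
// the Paddle convention of aligning b's dimensions against the trailing
// dimensions of a.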
Variable elementwise_add(const Variable& a, const Variable& b, int axis = -1);
Variable elementwise_mul(const Variable& a, const Variable& b, int axis = -1);
Variable elementwise_div(const Variable& a, const Variable& b, int axis = -1);
Variable elementwise_sub(const Variable& a, const Variable& b, int axis = -1);
// copy the tensor
Variable assign(const Variable& a);
Variable relu(const Variable& a);
Variable relu6(const Variable& a);
Variable conv2d(const Variable& a, const Variable& b, const absl::flat_hash_map<std::string, attr_t>& attr_store);
Variable layout_transform(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store);
Variable conv2d_NCHWc(const Variable& a,
const Variable& b,
const absl::flat_hash_map<std::string, attr_t>& attr_store);
Variable depthwise_conv2d(const Variable& a,
const Variable& b,
const absl::flat_hash_map<std::string, attr_t>& attr_store);
Variable pool2d(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store);
Variable batchnorm(const Variable& a,
const Variable& scale,
const Variable& bias,
const Variable& mean,
const Variable& variance,
const absl::flat_hash_map<std::string, attr_t>& attr_store);
Variable fused_meta_batchnorm_inference(const Variable& a,
const Variable& scale,
const Variable& bias,
const Variable& mean,
const Variable& variance,
const absl::flat_hash_map<std::string, attr_t>& attr_store);
Variable fused_batchnorm_inference(const Variable& a,
const Variable& scale,
const Variable& bias,
const Variable& mean,
const Variable& variance,
const absl::flat_hash_map<std::string, attr_t>& attr_store);
Variable scale(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store);
Variable softmax(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store);
Variable sigmoid(const Variable& a);
Variable slice(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store);
Variable dropout_infer(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store);
Instruction& operator[](size_t i);
const Instruction& operator[](size_t i) const;
inline size_t size() const { return instrs_.size(); }
void Validate() const;
void AppendInstruction(const Instruction& other) { instrs_.push_back(other); }
private:
std::vector<Instruction> instrs_;
std::vector<Variable> inputs_;
};
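// Loads a Paddle model from model_dir into the given scope. The returned
// tuple appears to carry: the Program itself, a map from variable name to
// frontend Variable, a map from Paddle variable names to program variable
// names, and the set of fetch (output) variable names.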
std::tuple<std::unique_ptr<Program>,
absl::flat_hash_map<std::string, Variable>,
absl::flat_hash_map<std::string, std::string>,
absl::flat_hash_set<std::string>>
LoadPaddleProgram(const std::string& model_dir,
hlir::framework::Scope* scope,
bool is_combined,
const common::Target& target = common::DefaultHostTarget());
std::ostream& operator<<(std::ostream& os, const Variable& x);
std::ostream& operator<<(std::ostream& os, const Instruction& instr);
std::ostream& operator<<(std::ostream& os, const Program& program);
} // namespace frontend
} // namespace cinn