Program Listing for File builtin.h
↰ Return to documentation for file (/WorkSpace/CINN/cinn/lang/builtin.h
)
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "cinn/common/ir_util.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_operators.h"
namespace cinn {
namespace lang {
Expr logic_and(const std::vector<Expr>& conds);
Expr logic_or(const std::vector<Expr>& conds);
#define EXTERN_CALL_DCL(name__) Expr name__(Expr e);
EXTERN_CALL_DCL(Exp);
EXTERN_CALL_DCL(Erf);
EXTERN_CALL_DCL(Sqrt);
EXTERN_CALL_DCL(Rsqrt);
EXTERN_CALL_DCL(Log);
EXTERN_CALL_DCL(Log2);
EXTERN_CALL_DCL(Log10);
EXTERN_CALL_DCL(Floor);
EXTERN_CALL_DCL(Ceil);
EXTERN_CALL_DCL(Round);
EXTERN_CALL_DCL(Trunc);
EXTERN_CALL_DCL(Cos);
EXTERN_CALL_DCL(Cosh);
EXTERN_CALL_DCL(Tan);
EXTERN_CALL_DCL(Sin);
EXTERN_CALL_DCL(Sinh);
EXTERN_CALL_DCL(Acos);
EXTERN_CALL_DCL(Acosh);
EXTERN_CALL_DCL(Asin);
EXTERN_CALL_DCL(Asinh);
EXTERN_CALL_DCL(Atan);
EXTERN_CALL_DCL(Atanh);
EXTERN_CALL_DCL(Tanh);
inline Expr Sigmoid(Expr e) {
auto one = common::make_const(e->type(), 1);
return one / (one + Exp(-e));
}
inline Expr Sign(Expr e) {
auto zero = make_const(e->type(), 0);
auto one = make_const(e->type(), 1);
auto neg_one = make_const(e->type(), -1);
auto ret1 = ir::Select::Make(e > zero, one, zero);
auto ret2 = ir::Select::Make(e < zero, neg_one, ret1);
return ret2;
}
Expr Abs(Expr e);
inline Expr Negative(Expr e) { return -e; }
inline Expr Identity(Expr e) { return e; }
inline Expr LogicalNot(Expr e) { return !e; }
inline Expr BitwiseNot(Expr e) { return ~e; }
inline Expr BitwiseAnd(Expr a, Expr b) { return a & b; }
inline Expr BitwiseOr(Expr a, Expr b) { return a | b; }
inline Expr BitwiseXor(Expr a, Expr b) { return a ^ b; }
inline Expr LeftShift(Expr a, Expr b) { return a << b; }
inline Expr RightShift(Expr a, Expr b) { return a >> b; }
template <typename T>
inline Expr Relu(Expr e, T threshold = static_cast<T>(0)) {
return ir::Max::Make(e, make_const(e->type(), threshold));
}
template <typename T>
inline Expr Relu6(Expr e, T threshold = static_cast<T>(0)) {
return ir::Min::Make(ir::Max::Make(e, make_const(e->type(), threshold)), make_const(e->type(), 6));
}
inline Expr LeakyRelu(Expr e, double alpha) {
auto zero = make_const(e->type(), 0);
return ir::Select::Make(e > zero, e, e * make_const(e->type(), alpha));
}
inline Expr LeakyRelu(Expr e, Expr alpha) {
auto zero = make_const(e->type(), 0);
return ir::Select::Make(e > zero, e, e * alpha);
}
inline Expr ReduceSum(Expr e, const std::vector<Var>& reduce_axis, Expr initial = Expr()) {
if (!initial.defined()) {
initial = make_const(e->type(), 0.f);
}
return ir::Reduce::Make(ir::Reduce::kSum, initial, e, reduce_axis);
}
inline Expr ReduceMul(Expr e, const std::vector<Var>& reduce_axis, Expr initial = Expr()) {
if (!initial.defined()) {
initial = make_const(e->type(), 1);
}
return ir::Reduce::Make(ir::Reduce::kMul, initial, e, reduce_axis);
}
Expr min_value(const Type& type);
Expr max_value(const Type& type);
inline Expr ReduceMax(Expr e, const std::vector<Var>& reduce_axis, Expr initial = Expr()) {
if (!initial.defined()) {
initial = min_value(e.type());
}
return ir::Reduce::Make(ir::Reduce::kMax, initial, e, reduce_axis);
}
inline Expr ReduceMin(Expr e, const std::vector<Var>& reduce_axis, Expr initial = Expr()) {
if (!initial.defined()) {
initial = max_value(e.type());
}
return ir::Reduce::Make(ir::Reduce::kMin, initial, e, reduce_axis);
}
Expr IsNan(Expr e);
Expr Infinity(const Type& type);
Expr IsInf(Expr e);
Expr IsFinite(Expr e);
} // namespace lang
} // namespace cinn