Program Listing for File cinn_builder.cc

Return to documentation for file (/WorkSpace/CINN/cinn/frontend/cinn_builder.cc)

// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/frontend/cinn_builder.h"

#include <glog/logging.h>

#include <string>
#include <utility>
#include <vector>

#include "cinn/frontend/syntax.h"

namespace cinn {
namespace frontend {

// Generates a one-argument CinnBuilder method that forwards to UnaryOp, passing the
// stringified second macro argument as the op type (e.g. Exp(x) -> UnaryOp("exp", x)).
#define UNARY_OP_DEF(method__, op_name__)                 \
  Variable CinnBuilder::method__(const Variable& operand) { \
    return UnaryOp(#op_name__, operand);                  \
  }
UNARY_OP_DEF(Exp, exp)
UNARY_OP_DEF(Erf, erf)
UNARY_OP_DEF(Sqrt, sqrt)
UNARY_OP_DEF(Rsqrt, rsqrt)
UNARY_OP_DEF(Log, log)
UNARY_OP_DEF(Log2, log2)
UNARY_OP_DEF(Log10, log10)
UNARY_OP_DEF(Floor, floor)
UNARY_OP_DEF(Ceil, ceil)
UNARY_OP_DEF(Round, round)
UNARY_OP_DEF(Trunc, trunc)
UNARY_OP_DEF(Sin, sin)
UNARY_OP_DEF(Cos, cos)
UNARY_OP_DEF(Tan, tan)
UNARY_OP_DEF(Sinh, sinh)
UNARY_OP_DEF(Cosh, cosh)
UNARY_OP_DEF(Tanh, tanh)
UNARY_OP_DEF(Asin, asin)
UNARY_OP_DEF(Acos, acos)
UNARY_OP_DEF(Atan, atan)
UNARY_OP_DEF(Asinh, asinh)
UNARY_OP_DEF(Acosh, acosh)
UNARY_OP_DEF(Atanh, atanh)
UNARY_OP_DEF(IsNan, isnan)
UNARY_OP_DEF(IsFinite, isfinite)
UNARY_OP_DEF(IsInf, isinf)
UNARY_OP_DEF(LogicalNot, logical_not)
UNARY_OP_DEF(BitwiseNot, bitwise_not)
UNARY_OP_DEF(Negative, negative)
UNARY_OP_DEF(Sign, sign)
UNARY_OP_DEF(Abs, abs)
UNARY_OP_DEF(Identity, identity)
#undef UNARY_OP_DEF

// Generates a two-argument CinnBuilder method that forwards to BinaryOp, passing the
// stringified second macro argument as the op type (e.g. Add(a, b) -> BinaryOp("elementwise_add", a, b)).
// NOTE(review): the op names are forwarded verbatim, including the "substract" spelling used
// for Sub; presumably they must match the registered op names, so verify before renaming any.
#define BINARY_OP_DEF(method__, op_name__)                                    \
  Variable CinnBuilder::method__(const Variable& lhs, const Variable& rhs) { \
    return BinaryOp(#op_name__, lhs, rhs);                                   \
  }
BINARY_OP_DEF(Dot, matmul)
BINARY_OP_DEF(Add, elementwise_add)
BINARY_OP_DEF(Sub, substract)
BINARY_OP_DEF(Mul, elementwise_mul)
BINARY_OP_DEF(Div, divide)
BINARY_OP_DEF(FloorDiv, floor_divide)
BINARY_OP_DEF(Mod, mod)
BINARY_OP_DEF(FloorMod, floor_mod)
BINARY_OP_DEF(Max, max)
BINARY_OP_DEF(Min, min)
BINARY_OP_DEF(Power, power)
BINARY_OP_DEF(LogicalAnd, logical_and)
BINARY_OP_DEF(LogicalOr, logical_or)
BINARY_OP_DEF(LogicalXor, logical_xor)
BINARY_OP_DEF(BitwiseAnd, bitwise_and)
BINARY_OP_DEF(BitwiseOr, bitwise_or)
BINARY_OP_DEF(BitwiseXor, bitwise_xor)
BINARY_OP_DEF(LeftShift, left_shift)
BINARY_OP_DEF(RightShift, right_shift)
#undef BINARY_OP_DEF

// Emits a "concat" instruction over all of `input_vars`, recording `axis` as its "axis"
// attribute, and returns the instruction's first output.
Variable CinnBuilder::Concat(const std::vector<Variable>& input_vars, int axis) {
  Instruction concat_op("concat", input_vars);
  concat_op.SetAttr("axis", axis);

  InferShape(concat_op);
  AppendInstruction(concat_op);

  Variable result = concat_op.GetOutput(0);
  return result;
}

// Emits a "conv2d" instruction with inputs {lhs, rhs}. Every configuration argument is
// attached as an attribute under the key shown below (note the singular "stride", "padding"
// and "dilation" keys). Returns the instruction's first output.
Variable CinnBuilder::Conv(const Variable& lhs,
                           const Variable& rhs,
                           const std::vector<int>& strides,
                           const std::vector<int>& paddings,
                           const std::vector<int>& dilations,
                           int groups,
                           const std::string& conv_type,
                           const std::string& data_format,
                           const std::string& padding_algorithm,
                           const std::vector<int>& output_shape) {
  Instruction conv_op("conv2d");
  conv_op.SetInputs({lhs, rhs});

  // Spatial configuration.
  conv_op.SetAttr("stride", strides);
  conv_op.SetAttr("padding", paddings);
  conv_op.SetAttr("dilation", dilations);
  conv_op.SetAttr("groups", groups);

  // Kind / layout configuration.
  conv_op.SetAttr("conv_type", conv_type);
  conv_op.SetAttr("data_format", data_format);
  conv_op.SetAttr("padding_algorithm", padding_algorithm);
  conv_op.SetAttr("output_shape", output_shape);

  InferShape(conv_op);
  AppendInstruction(conv_op);

  Variable result = conv_op.GetOutput(0);
  return result;
}

// Emits an element-wise comparison of `lhs` and `rhs` selected by `kind` and returns its result.
// kEq/kNe/kGe/kGt/kLe/kLt map to the "equal"/"not_equal"/"greater_equal"/"greater"/"less_equal"/"less"
// binary ops. An unrecognised kind is reported through LOG(FATAL).
Variable CinnBuilder::Compare(const Variable& lhs, const Variable& rhs, ComparisonKind kind) {
  // Resolve the op name first, then build the instruction on a single return path. Before this
  // change, the default branch fell off the end of this non-void function after LOG(FATAL),
  // which relied on the compiler treating the fatal log as non-returning.
  std::string op_type;
  switch (kind) {
    case ComparisonKind::kEq:
      op_type = "equal";
      break;
    case ComparisonKind::kNe:
      op_type = "not_equal";
      break;
    case ComparisonKind::kGe:
      op_type = "greater_equal";
      break;
    case ComparisonKind::kGt:
      op_type = "greater";
      break;
    case ComparisonKind::kLe:
      op_type = "less_equal";
      break;
    case ComparisonKind::kLt:
      op_type = "less";
      break;
    default:
      LOG(FATAL) << "unknown comparison kind";
  }
  return BinaryOp(op_type, lhs, rhs);
}

// Emits a reduction of `operand` selected by `kind` and returns its first output.
// kSum/kProd/kMax/kMin map to "reduce_sum"/"reduce_prod"/"reduce_max"/"reduce_min".
// `dim` and `keep_dim` are attached as the "dim" and "keep_dim" attributes.
// An unrecognised kind is reported through LOG(FATAL).
Variable CinnBuilder::Reduce(const Variable& operand, ReduceKind kind, const std::vector<int>& dim, bool keep_dim) {
  // Resolve the op name first so the function has one return path. Before this change, the
  // default branch fell off the end of this non-void function after LOG(FATAL).
  std::string op_type;
  switch (kind) {
    case ReduceKind::kSum:
      op_type = "reduce_sum";
      break;
    case ReduceKind::kProd:
      op_type = "reduce_prod";
      break;
    case ReduceKind::kMax:
      op_type = "reduce_max";
      break;
    case ReduceKind::kMin:
      op_type = "reduce_min";
      break;
    default:
      LOG(FATAL) << "unknown reduction kind";
  }

  Instruction instr(op_type, {operand});
  instr.SetAttr("dim", dim);
  instr.SetAttr("keep_dim", keep_dim);
  InferShape(instr);
  AppendInstruction(instr);
  return instr.GetOutput(0);
}

// Emits a "broadcast_to" instruction on `operand`, recording the target shape ("out_shape")
// and the axis mapping ("broadcast_axes"), and returns its first output.
Variable CinnBuilder::BroadcastTo(const Variable& operand,
                                  const std::vector<int>& out_shape,
                                  const std::vector<int>& broadcast_axes) {
  const std::vector<Variable> inputs{operand};
  Instruction broadcast_op("broadcast_to", inputs);
  broadcast_op.SetAttr("out_shape", out_shape);
  broadcast_op.SetAttr("broadcast_axes", broadcast_axes);

  InferShape(broadcast_op);
  AppendInstruction(broadcast_op);

  Variable result = broadcast_op.GetOutput(0);
  return result;
}

// Emits a "reshape" instruction on `operand`, recording the requested `shape` as the
// "shape" attribute, and returns its first output.
Variable CinnBuilder::Reshape(const Variable& operand, const std::vector<int>& shape) {
  const std::vector<Variable> inputs{operand};
  Instruction reshape_op("reshape", inputs);
  reshape_op.SetAttr("shape", shape);

  InferShape(reshape_op);
  AppendInstruction(reshape_op);

  Variable result = reshape_op.GetOutput(0);
  return result;
}

// Emits a "transpose" instruction on `operand`, recording the permutation `axis` as the
// "axis" attribute, and returns its first output.
Variable CinnBuilder::Transpose(const Variable& operand, const std::vector<int>& axis) {
  const std::vector<Variable> inputs{operand};
  Instruction transpose_op("transpose", inputs);
  transpose_op.SetAttr("axis", axis);

  InferShape(transpose_op);
  AppendInstruction(transpose_op);

  Variable result = transpose_op.GetOutput(0);
  return result;
}

// Emits a "slice" instruction on `operand`. `axes`, `starts` and `ends` are attached as
// the attributes of the same names; returns the instruction's first output.
Variable CinnBuilder::Slice(const Variable& operand,
                            const std::vector<int>& axes,
                            const std::vector<int>& starts,
                            const std::vector<int>& ends) {
  const std::vector<Variable> inputs{operand};
  Instruction slice_op("slice", inputs);
  slice_op.SetAttr("axes", axes);
  slice_op.SetAttr("starts", starts);
  slice_op.SetAttr("ends", ends);

  InferShape(slice_op);
  AppendInstruction(slice_op);

  Variable result = slice_op.GetOutput(0);
  return result;
}

// Emits a "select" instruction with inputs (condition, true_value, false_value), in that
// order, and returns its first output.
Variable CinnBuilder::Select(const Variable& condition, const Variable& true_value, const Variable& false_value) {
  const std::vector<Variable> inputs{condition, true_value, false_value};
  Instruction select_op("select", inputs);

  InferShape(select_op);
  AppendInstruction(select_op);

  Variable result = select_op.GetOutput(0);
  return result;
}

// Emits a "reverse" instruction on `operand`, recording the axes to reverse as the "axis"
// attribute, and returns its first output.
Variable CinnBuilder::Reverse(const Variable& operand, const std::vector<int>& axis) {
  const std::vector<Variable> inputs{operand};
  Instruction reverse_op("reverse", inputs);
  reverse_op.SetAttr("axis", axis);

  InferShape(reverse_op);
  AppendInstruction(reverse_op);

  Variable result = reverse_op.GetOutput(0);
  return result;
}

std::vector<Variable> CinnBuilder::BnMeanVariance(const Variable& x) {
  Instruction instr("bn_mean_variance", {x});
  // optimize bn forward reduce computation, set reduce dimension(NCHW suppport only, to be deprecated).
  instr.SetAttr("dim", std::vector<int>{0, 2, 3});
  instr.SetAttr("keep_dim", false);
  InferShape(instr);
  AppendInstruction(instr);
  return instr.GetOutputs();
}

std::vector<Variable> CinnBuilder::BnGradBiasScale(const Variable& x, const Variable& x_mean, const Variable& y_grad) {
  Instruction instr("bn_grad_bias_scale", {x, x_mean, y_grad});
  // optimize bn backward reduce computation, set reduce dimension(NCHW suppport only, to be deprecated).
  instr.SetAttr("dim", std::vector<int>{0, 2, 3});
  instr.SetAttr("keep_dim", false);
  InferShape(instr);
  AppendInstruction(instr);
  return instr.GetOutputs();
}

// Shared helper behind the generated unary methods: emits an `op_type` instruction with the
// single input `operand`, runs shape inference, appends it, and returns its first output.
Variable CinnBuilder::UnaryOp(const std::string& op_type, const Variable& operand) {
  const std::vector<Variable> inputs{operand};
  Instruction unary_op(op_type, inputs);

  InferShape(unary_op);
  AppendInstruction(unary_op);

  Variable result = unary_op.GetOutput(0);
  return result;
}

// Shared helper behind the generated binary methods and Compare: emits an `op_type`
// instruction with inputs (lhs, rhs), runs shape inference, appends it, and returns its
// first output.
Variable CinnBuilder::BinaryOp(const std::string& op_type, const Variable& lhs, const Variable& rhs) {
  const std::vector<Variable> inputs{lhs, rhs};
  Instruction binary_op(op_type, inputs);

  InferShape(binary_op);
  AppendInstruction(binary_op);

  Variable result = binary_op.GetOutput(0);
  return result;
}

}  // namespace frontend
}  // namespace cinn