Program Listing for File cinn_builder.h

Return to documentation for file (/WorkSpace/CINN/cinn/frontend/cinn_builder.h)

// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <glog/logging.h>

#include <cstdint>
#include <string>
#include <vector>

#include "cinn/common/type.h"
#include "cinn/frontend/base_builder.h"
#include "cinn/frontend/syntax.h"

// clang-format off
// X-macro listing the unary operations exposed by CinnBuilder.
// Expands to `op_macro(Name)` once per operation, in the order written below,
// with nothing between the expansions. This header uses it to declare one
// `Variable Name(const Variable& operand)` member per entry.
// The parameter name avoids a double underscore: identifiers containing `__`
// are reserved for the implementation.
#define UNARY_OP_FOREACH(op_macro)        \
    op_macro(Exp)                         \
    op_macro(Erf)                         \
    op_macro(Sqrt)                        \
    op_macro(Rsqrt)                       \
    op_macro(Log)                         \
    op_macro(Log2)                        \
    op_macro(Log10)                       \
    op_macro(Floor)                       \
    op_macro(Ceil)                        \
    op_macro(Round)                       \
    op_macro(Trunc)                       \
    op_macro(Sin)                         \
    op_macro(Cos)                         \
    op_macro(Tan)                         \
    op_macro(Sinh)                        \
    op_macro(Cosh)                        \
    op_macro(Tanh)                        \
    op_macro(Asin)                        \
    op_macro(Acos)                        \
    op_macro(Atan)                        \
    op_macro(Asinh)                       \
    op_macro(Acosh)                       \
    op_macro(Atanh)                       \
    op_macro(IsNan)                       \
    op_macro(IsFinite)                    \
    op_macro(IsInf)                       \
    op_macro(LogicalNot)                  \
    op_macro(BitwiseNot)                  \
    op_macro(Negative)                    \
    op_macro(Sign)                        \
    op_macro(Abs)                         \
    op_macro(Identity)

// X-macro listing the binary operations exposed by CinnBuilder.
// Expands to `op_macro(Name)` once per operation, in the order written below,
// with nothing between the expansions. This header uses it to declare one
// `Variable Name(const Variable& lhs, const Variable& rhs)` member per entry.
// The parameter name avoids a double underscore: identifiers containing `__`
// are reserved for the implementation.
#define BINARY_OP_FOREACH(op_macro)       \
    op_macro(Dot)                         \
    op_macro(Add)                         \
    op_macro(Sub)                         \
    op_macro(Mul)                         \
    op_macro(Div)                         \
    op_macro(FloorDiv)                    \
    op_macro(Mod)                         \
    op_macro(FloorMod)                    \
    op_macro(Max)                         \
    op_macro(Min)                         \
    op_macro(Power)                       \
    op_macro(LogicalAnd)                  \
    op_macro(LogicalOr)                   \
    op_macro(LogicalXor)                  \
    op_macro(BitwiseAnd)                  \
    op_macro(BitwiseOr)                   \
    op_macro(BitwiseXor)                  \
    op_macro(LeftShift)                   \
    op_macro(RightShift)
// clang-format on
namespace cinn {
namespace frontend {

// Selector passed to CinnBuilder::Compare to choose the comparison.
// Backed by a signed 8-bit integer; kUnk (-1) sits outside the 0-based run of
// the remaining enumerators.
// NOTE(review): the names follow the usual eq/ne/ge/gt/le/lt abbreviations —
// confirm the exact op each one maps to in the implementation.
enum class ComparisonKind : std::int8_t {
  kUnk = -1,
  kEq  = 0,
  kNe  = 1,
  kGe  = 2,
  kGt  = 3,
  kLe  = 4,
  kLt  = 5,
};

// Selector passed to CinnBuilder::Reduce to choose the reduction.
// Backed by a signed 8-bit integer; kUnk (-1) sits outside the 0-based run of
// the remaining enumerators.
// NOTE(review): the names suggest sum/product/max/min reductions — confirm the
// exact op each one maps to in the implementation.
enum class ReduceKind : std::int8_t {
  kUnk  = -1,
  kSum  = 0,
  kProd = 1,
  kMax  = 2,
  kMin  = 3,
};

// Builder derived from BaseBuilder that exposes one member function per
// supported operation. Only ConstScalar is defined in this header; every other
// member is declared here and defined elsewhere, so the notes on those
// declarations describe what the signatures show. Anything about op semantics
// should be confirmed against the implementation.
class CinnBuilder : public BaseBuilder {
 public:
  // Inherit BaseBuilder's constructors unchanged.
  using BaseBuilder::BaseBuilder;

  // Appends a "const_scalar" instruction holding `value` and returns its output.
  //
  // @tparam T   Scalar type; type_of<T>() must report a float, int or bool
  //             type, otherwise the CHECK below fails.
  // @param value  Stored on the instruction as the "value" attribute.
  // @param name   Id assigned to the returned output variable.
  // @return The instruction's first output, with its type set to type_of<T>().
  template <typename T>
  Variable ConstScalar(T value, const std::string& name) {
    // Reject unsupported scalar types before touching builder state, so a
    // failed check never leaves a half-configured instruction appended.
    auto out_type = type_of<T>();
    CHECK(out_type.is_float() || out_type.is_int() || out_type.is_bool()) << "no supported type: " << out_type;

    Instruction instr("const_scalar");
    instr.SetInputs({});  // a constant has no input variables
    instr.SetAttr<T>("value", value);
    InferShape(instr);
    AppendInstruction(instr);

    auto out = instr.GetOutput(0);
    out.set_id(name);
    out->type = out_type;
    return out;
  }

  // One `Variable Name(const Variable& operand)` declaration per entry of
  // UNARY_OP_FOREACH.
#define UNARY_OP_DECL(func_name) Variable func_name(const Variable& operand);
  UNARY_OP_FOREACH(UNARY_OP_DECL)
#undef UNARY_OP_DECL

  // One `Variable Name(const Variable& lhs, const Variable& rhs)` declaration
  // per entry of BINARY_OP_FOREACH.
#define BINARY_OP_DECL(func_name) Variable func_name(const Variable& lhs, const Variable& rhs);
  BINARY_OP_FOREACH(BINARY_OP_DECL)
#undef BINARY_OP_DECL

  // Takes a list of variables and an axis (default 0); returns one Variable.
  // Presumably concatenates `input_vars` along `axis` — confirm in the implementation.
  Variable Concat(const std::vector<Variable>& input_vars, int axis = 0);

  // Convolution-style op over `lhs` and `rhs`. Defaults visible here: strides {1, 1},
  // paddings {0, 0}, dilations {1, 1}, groups 1, conv_type "forward", data_format "NCHW",
  // padding_algorithm "EXPLICIT", empty output_shape.
  // NOTE(review): which operand is the input and which the filter is not stated here — confirm.
  Variable Conv(const Variable& lhs,
                const Variable& rhs,
                const std::vector<int>& strides      = {1, 1},
                const std::vector<int>& paddings     = {0, 0},
                const std::vector<int>& dilations    = {1, 1},
                int groups                           = 1,
                const std::string& conv_type         = "forward",
                const std::string& data_format       = "NCHW",
                const std::string& padding_algorithm = "EXPLICIT",
                const std::vector<int>& output_shape = {});

  // Compares `lhs` with `rhs` using the comparison selected by `kind`.
  Variable Compare(const Variable& lhs, const Variable& rhs, ComparisonKind kind);

  // Reduces `operand` with the reduction selected by `kind`. `dim` lists the dimensions
  // involved; `keep_dim` defaults to false.
  Variable Reduce(const Variable& operand, ReduceKind kind, const std::vector<int>& dim, bool keep_dim = false);

  // Takes an operand, a target shape and a list of broadcast axes.
  // NOTE(review): how `broadcast_axes` maps operand dims onto `out_shape` is not visible here — confirm.
  Variable BroadcastTo(const Variable& operand,
                       const std::vector<int>& out_shape,
                       const std::vector<int>& broadcast_axes);

  // Takes an operand and a target `shape`; returns one Variable.
  Variable Reshape(const Variable& operand, const std::vector<int>& shape);

  // Takes an operand and an `axis` list; presumably the permutation to apply — confirm.
  Variable Transpose(const Variable& operand, const std::vector<int>& axis);

  // Takes an operand plus `axes`, `starts` and `ends`; the latter two default to empty vectors.
  Variable Slice(const Variable& operand,
                 const std::vector<int>& axes,
                 const std::vector<int>& starts = {},
                 const std::vector<int>& ends   = {});

  // Takes a condition and two candidate values; presumably picks between them per `condition` — confirm.
  Variable Select(const Variable& condition, const Variable& true_value, const Variable& false_value);

  // Takes an operand and an `axis` list; returns one Variable.
  Variable Reverse(const Variable& operand, const std::vector<int>& axis);

  // Returns several Variables computed from `x`.
  // NOTE(review): the name suggests batch-norm mean/variance — confirm the order of the outputs.
  std::vector<Variable> BnMeanVariance(const Variable& x);

  // Returns several Variables computed from `x`, `x_mean` and `y_grad`.
  // NOTE(review): the name suggests batch-norm bias/scale gradients — confirm the order of the outputs.
  std::vector<Variable> BnGradBiasScale(const Variable& x, const Variable& x_mean, const Variable& y_grad);

 private:
  // Helper keyed by an op-type string with a single operand.
  // Presumably the shared implementation behind the generated unary members — confirm.
  Variable UnaryOp(const std::string& op_type, const Variable& operand);

  // Helper keyed by an op-type string with two operands.
  // Presumably the shared implementation behind the generated binary members — confirm.
  Variable BinaryOp(const std::string& op_type, const Variable& lhs, const Variable& rhs);
};

}  // namespace frontend
}  // namespace cinn