Program Listing for File builtin.cc

Return to documentation for file (/WorkSpace/CINN/cinn/lang/builtin.cc)

// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/lang/builtin.h"

#include <cmath>
#include <limits>
#include <utility>

#include "cinn/cinn.h"
#include "cinn/common/ir_util.h"
#include "cinn/ir/ir.h"
#include "cinn/lang/buffer.h"

namespace cinn {
namespace lang {

Expr logic_and(const std::vector<Expr>& conds) {
  CHECK(!conds.empty());
  auto start = ir::And::Make(conds[0], conds[1]);
  for (int i = 2; i < conds.size(); i++) {
    start = ir::And::Make(start, conds[i]);
  }
  return start;
}

Expr logic_or(const std::vector<Expr>& conds) {
  CHECK(!conds.empty());
  auto start = ir::Or::Make(conds[0], conds[1]);
  for (int i = 2; i < conds.size(); i++) {
    start = ir::Or::Make(start, conds[i]);
  }
  return start;
}

#define EXTERN_CALL_IMP(name__, target__) \
  Expr name__(Expr e) { return ir::Call::Make(e->type(), #target__, {e}, {}, ir::CallType::Extern); }

#define EXTERN_CALL_IMP_NO_VEC(name__, target__)                                                               \
  Expr name__(Expr e) {                                                                                        \
    return ir::Call::Make(                                                                                     \
        e->type(), #target__, {e}, {}, ir::CallType::Extern, ir::FunctionRef(), 0, {{"vectorizable", false}}); \
  }

EXTERN_CALL_IMP(Exp, exp);
EXTERN_CALL_IMP_NO_VEC(Erf, erf);
EXTERN_CALL_IMP(Sqrt, sqrt);
EXTERN_CALL_IMP(Rsqrt, rsqrt);
EXTERN_CALL_IMP(Log, log);
EXTERN_CALL_IMP(Log2, log2);
EXTERN_CALL_IMP(Log10, log10);
EXTERN_CALL_IMP(Floor, floor);
EXTERN_CALL_IMP(Ceil, ceil);
EXTERN_CALL_IMP(Round, round);
EXTERN_CALL_IMP(Trunc, trunc);
EXTERN_CALL_IMP(Cos, cos);
EXTERN_CALL_IMP(Sin, sin);
EXTERN_CALL_IMP(Cosh, cosh);
EXTERN_CALL_IMP(Tan, tan);
EXTERN_CALL_IMP(Tanh, tanh);
EXTERN_CALL_IMP(Sinh, sinh);
EXTERN_CALL_IMP_NO_VEC(Acos, acos);
EXTERN_CALL_IMP_NO_VEC(Acosh, acosh);
EXTERN_CALL_IMP_NO_VEC(Asin, asin);
EXTERN_CALL_IMP_NO_VEC(Asinh, asinh);
EXTERN_CALL_IMP_NO_VEC(Atan, atan);
EXTERN_CALL_IMP_NO_VEC(Atanh, atanh);

Expr min_value(const Type& type) {
  CHECK_EQ(type.lanes(), 1);
#define FOR_CASE(type__)                                \
  if (type == type_of<type__>()) {                      \
    return Expr(std::numeric_limits<type__>::lowest()); \
  }
  FOR_CASE(int32_t)
  FOR_CASE(int64_t)
  FOR_CASE(uint32_t)
  FOR_CASE(uint64_t)
  FOR_CASE(float)
  FOR_CASE(double)
#undef FOR_CASE
  return Expr();
}

Expr max_value(const Type& type) {
  CHECK_EQ(type.lanes(), 1);

#define FOR_CASE(type__)                             \
  if (type == type_of<type__>()) {                   \
    return Expr(std::numeric_limits<type__>::max()); \
  }
  FOR_CASE(int32_t)
  FOR_CASE(int64_t)
  FOR_CASE(uint32_t)
  FOR_CASE(uint64_t)
  FOR_CASE(float)
  FOR_CASE(double)
#undef FOR_CASE

  CINN_NOT_IMPLEMENTED
  return Expr();
}

Expr Abs(Expr e) {
  Type type      = e->type();
  Type bool_type = Bool(type.lanes());
  if (type.is_uint()) {
    return e;
  } else if (type.is_int()) {
    auto node = e.As<ir::IntImm>();
    if (node) {
      return make_const(type, std::abs(node->value));
    }
    return ir::Select::Make(e > make_const(e->type(), 0), e, -e);
  } else if (type.is_float()) {
    auto node = e.As<ir::FloatImm>();
    if (node) {
      return make_const(type, std::fabs(node->value));
    }
    return CallExtern("fabs", {e});
  }
}

Expr IsNan(Expr e) {
  Type type = e->type();
  if (type.is_int() || type.is_uint()) {
    return common::make_bool(false, type.lanes());
  } else if (type.is_float()) {
    auto* node = e.As<ir::FloatImm>();
    if (node) {
      return common::make_bool(std::isnan(node->value), type.lanes());
    }
    Expr arg = e;
    if (type.bits() == 16) {
      arg = ir::Cast::Make(Float(32), std::move(e));
    }
    return CallExtern("isnan", {arg}, {{"vectorizable", false}});
  } else {
    LOG(FATAL) << type << "is not supported for isnan op.";
    return e;
  }
}

Expr Infinity(const Type& type) {
  CHECK_EQ(type.lanes(), 1U);
  if (type.is_float()) {
    if (type.bits() == 64) {
      return make_const(type, std::numeric_limits<double>::infinity());
    } else if (type.bits() == 32 || type.bits() == 16) {
      return make_const(type, std::numeric_limits<float>::infinity());
    }
  }
  LOG(FATAL) << "Cannot decide infinity for type " << type;
  return Expr();
}

Expr IsInf(Expr e) {
  Type type = e->type();
  if (type.is_int() || type.is_uint()) {
    return common::make_bool(false, type.lanes());
  } else if (type.is_float()) {
    Expr arg = e;
    return CallExtern("isinf", {arg}, {{"vectorizable", false}});
  } else {
    LOG(FATAL) << type << "is not supported for isinf op.";
    return e;
  }
}

Expr IsFinite(Expr e) { return !IsInf(e) && !IsNan(e); }

}  // namespace lang
}  // namespace cinn