Program Listing for File compute.cc

Return to documentation for file (/WorkSpace/CINN/cinn/lang/compute.cc)

// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/lang/compute.h"

#include "cinn/backends/extern_func_protos.h"
#include "cinn/common/common.h"
#include "cinn/ir/operation.h"
#include "cinn/optim/ir_simplify.h"
#include "cinn/poly/dim.h"
#include "cinn/poly/domain.h"
#include "cinn/poly/stage.h"
#include "cinn/runtime/use_extern_funcs.h"

namespace cinn {
namespace lang {

// Compute overload for a body callback that takes no index arguments.
//
// The callback is wrapped into the vector-based callback form and the work is
// delegated to the Compute overload that accepts
// std::function<Expr(const std::vector<Expr> &)>.
//
// @param domain Forwarded unchanged to the vector-based overload.
// @param fn     Body callback; invoked with no arguments.
// @param name   Forwarded unchanged to the vector-based overload.
// @param shape  Forwarded unchanged to the vector-based overload.
// @return The tensor produced by the vector-based overload.
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr()> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape) {
  using AxisVectorFn = std::function<Expr(const std::vector<Expr> &)>;

  // The incoming axis list is ignored and its size is intentionally not
  // checked (the original size check is disabled).
  AxisVectorFn adapter = [fn](const std::vector<Expr> &) -> Expr { return fn(); };
  return Compute(domain, std::move(adapter), name, shape);
}

// Compute overload for a body callback that takes exactly one index
// expression.
//
// The callback is wrapped into the vector-based callback form and the work is
// delegated to the Compute overload that accepts
// std::function<Expr(const std::vector<Expr> &)>.
//
// @param domain Forwarded unchanged to the vector-based overload.
// @param fn     Body callback; receives axis[0].
// @param name   Forwarded unchanged to the vector-based overload.
// @param shape  Forwarded unchanged to the vector-based overload.
// @return The tensor produced by the vector-based overload.
// A CHECK fails if the adapter is invoked with an axis count other than 1.
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(Expr)> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape) {
  using AxisVectorFn = std::function<Expr(const std::vector<Expr> &)>;

  AxisVectorFn adapter = [fn](const std::vector<Expr> &axis) -> Expr {
    CHECK_EQ(axis.size(), 1);
    return fn(axis[0]);
  };
  return Compute(domain, std::move(adapter), name, shape);
}

// Compute overload for a body callback that takes exactly two index
// expressions.
//
// The callback is wrapped into the vector-based callback form and the work is
// delegated to the Compute overload that accepts
// std::function<Expr(const std::vector<Expr> &)>.
//
// @param domain Forwarded unchanged to the vector-based overload.
// @param fn     Body callback; receives axis[0] and axis[1], in that order.
// @param name   Forwarded unchanged to the vector-based overload.
// @param shape  Forwarded unchanged to the vector-based overload.
// @return The tensor produced by the vector-based overload.
// A CHECK fails if the adapter is invoked with an axis count other than 2.
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(Expr, Expr)> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape) {
  using AxisVectorFn = std::function<Expr(const std::vector<Expr> &)>;

  AxisVectorFn adapter = [fn](const std::vector<Expr> &axis) -> Expr {
    CHECK_EQ(axis.size(), 2);
    return fn(axis[0], axis[1]);
  };
  return Compute(domain, std::move(adapter), name, shape);
}

// Compute overload for a body callback that takes exactly three index
// expressions.
//
// The callback is wrapped into the vector-based callback form and the work is
// delegated to the Compute overload that accepts
// std::function<Expr(const std::vector<Expr> &)>.
//
// @param domain Forwarded unchanged to the vector-based overload.
// @param fn     Body callback; receives axis[0..2], in order.
// @param name   Forwarded unchanged to the vector-based overload.
// @param shape  Forwarded unchanged to the vector-based overload.
// @return The tensor produced by the vector-based overload.
// A CHECK fails if the adapter is invoked with an axis count other than 3.
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(Expr, Expr, Expr)> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape) {
  using AxisVectorFn = std::function<Expr(const std::vector<Expr> &)>;

  AxisVectorFn adapter = [fn](const std::vector<Expr> &axis) -> Expr {
    CHECK_EQ(axis.size(), 3);
    return fn(axis[0], axis[1], axis[2]);
  };
  return Compute(domain, std::move(adapter), name, shape);
}

// Compute overload for a body callback that takes exactly four index
// expressions.
//
// The callback is wrapped into the vector-based callback form and the work is
// delegated to the Compute overload that accepts
// std::function<Expr(const std::vector<Expr> &)>.
//
// @param domain Forwarded unchanged to the vector-based overload.
// @param fn     Body callback; receives axis[0..3], in order.
// @param name   Forwarded unchanged to the vector-based overload.
// @param shape  Forwarded unchanged to the vector-based overload.
// @return The tensor produced by the vector-based overload.
// A CHECK fails if the adapter is invoked with an axis count other than 4.
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(Expr, Expr, Expr, Expr)> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape) {
  using AxisVectorFn = std::function<Expr(const std::vector<Expr> &)>;

  AxisVectorFn adapter = [fn](const std::vector<Expr> &axis) -> Expr {
    CHECK_EQ(axis.size(), 4);
    return fn(axis[0], axis[1], axis[2], axis[3]);
  };
  return Compute(domain, std::move(adapter), name, shape);
}

// Compute overload for a body callback that takes exactly five index
// expressions.
//
// The callback is wrapped into the vector-based callback form and the work is
// delegated to the Compute overload that accepts
// std::function<Expr(const std::vector<Expr> &)>.
//
// @param domain Forwarded unchanged to the vector-based overload.
// @param fn     Body callback; receives axis[0..4], in order.
// @param name   Forwarded unchanged to the vector-based overload.
// @param shape  Forwarded unchanged to the vector-based overload.
// @return The tensor produced by the vector-based overload.
// A CHECK fails if the adapter is invoked with an axis count other than 5.
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(Expr, Expr, Expr, Expr, Expr)> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape) {
  using AxisVectorFn = std::function<Expr(const std::vector<Expr> &)>;

  AxisVectorFn adapter = [fn](const std::vector<Expr> &axis) -> Expr {
    CHECK_EQ(axis.size(), 5);
    return fn(axis[0], axis[1], axis[2], axis[3], axis[4]);
  };
  return Compute(domain, std::move(adapter), name, shape);
}

// Compute overload for a body callback that takes exactly six index
// expressions.
//
// The callback is wrapped into the vector-based callback form and the work is
// delegated to the Compute overload that accepts
// std::function<Expr(const std::vector<Expr> &)>.
//
// @param domain Forwarded unchanged to the vector-based overload.
// @param fn     Body callback; receives axis[0..5], in order.
// @param name   Forwarded unchanged to the vector-based overload.
// @param shape  Forwarded unchanged to the vector-based overload.
// @return The tensor produced by the vector-based overload.
// A CHECK fails if the adapter is invoked with an axis count other than 6.
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(Expr, Expr, Expr, Expr, Expr, Expr)> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape) {
  using AxisVectorFn = std::function<Expr(const std::vector<Expr> &)>;

  AxisVectorFn adapter = [fn](const std::vector<Expr> &axis) -> Expr {
    CHECK_EQ(axis.size(), 6);
    return fn(axis[0], axis[1], axis[2], axis[3], axis[4], axis[5]);
  };
  return Compute(domain, std::move(adapter), name, shape);
}

// Builds a tensor whose value is described by the callback `fn`.
//
// The callback is evaluated once on a list of default axis expressions (one
// per entry of `domain`). If the resulting body is already a tensor, that
// tensor is returned as is. Otherwise a ComputeOp and a new tensor are created
// from the simplified domain/shape and any reduce axes found on the body.
//
// @param domain One expression per axis; each is simplified and the result is
//               used as the tensor's domain.
// @param fn     Body callback taking the axis expressions.
// @param name   Tensor name; when empty, a fresh name is requested from
//               Context::Global() with the hint "tensor".
// @param shape  Buffer shape; when empty, the simplified domain is used as the
//               shape instead.
// @return The body's own tensor if the body is a tensor, otherwise the newly
//         constructed tensor.
// A CHECK fails if a reduce axis of the body uses a reserved axis name.
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(const std::vector<Expr> &)> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape) {
  // Evaluate the body once on the default axes to learn its structure.
  auto default_axes = common::GenDefaultAxis(domain.size());
  std::vector<Expr> axis_exprs;
  for (auto &axis_var : default_axes) axis_exprs.push_back(axis_var);
  Expr body = fn(axis_exprs);

  // A Reduce body carries its own reduce axes; collect them.
  std::vector<Var> reduce_axis;
  if (body.defined()) {
    if (auto *reduce_node = body.As<ir::Reduce>()) {
      reduce_axis.assign(reduce_node->reduce_axis.begin(), reduce_node->reduce_axis.end());
    }
  }

  // If the body already is a tensor, hand it back unchanged (the original
  // note attributes this case to CallExtern bodies).
  if (body.as_tensor()) {
    return body.as_tensor_ref();
  }

  // Simplifies a copy of every expression in the list, preserving order.
  const auto simplify_each = [](const std::vector<Expr> &exprs) {
    std::vector<Expr> simplified;
    for (const auto &expr : exprs) {
      Expr copied = expr;
      optim::Simplify(&copied);
      simplified.push_back(copied);
    }
    return simplified;
  };

  const std::vector<Expr> domain_without_reduce_axis = simplify_each(domain);
  const std::vector<Expr> shape_simplified           = simplify_each(shape);

  // The buffer shape falls back to the domain when no shape was supplied.
  std::vector<Expr> real_shape;
  if (shape_simplified.empty()) {
    real_shape = domain_without_reduce_axis;
  } else {
    real_shape = shape_simplified;
  }

  // A void body means no buffer is needed, hence no shape.
  if (body.type() == Void()) {
    real_shape.clear();
  }

  std::string unique_name;
  if (name.empty()) {
    unique_name = Context::Global().NewName("tensor");
  } else {
    unique_name = name;
  }

  // Reduce axes must not reuse a reserved axis name.
  for (auto &ra : reduce_axis) {
    CHECK(!common::IsAxisNameReserved(ra->name)) << "reduce axis [" << ra->name << "]'s name is reserved";
  }

  VLOG(3) << "domain: " << domain_without_reduce_axis;

  auto compute_op = ir::ComputeOp::Make(unique_name, fn, real_shape, domain_without_reduce_axis, reduce_axis);
  return ir::Tensor(unique_name, body.type(), real_shape, domain_without_reduce_axis, compute_op, reduce_axis);
}

std::vector<ir::Tensor> CallLowered(const std::string &target,
                                    const std::vector<Expr> &args,
                                    const std::vector<ReturnType> &return_types) {
  auto call = ir::Call::Make(Void(), target, args, {}, ir::CallType::CINN, ir::FunctionRef(), 0);
  std::vector<ir::Tensor> new_tensors;
  for (int i = 0; i < return_types.size(); i++) {
    auto &return_type = return_types[i];
    auto call_op      = ir::CallOp::Make(target, call);
    auto new_tensor   = ir::Tensor(return_type.name, return_type.type, return_type.dims, {Expr(1)}, call_op);
    // Append write tensors in the tail.
    call.As<ir::Call>()->write_args.push_back(new_tensor);
    new_tensor->set_type(return_type.type);
    new_tensor->WithBuffer();
    new_tensors.push_back(new_tensor);
  }

  return new_tensors;
}

// Builds an Extern-type call expression to the registered extern function
// `target`.
//
// The prototype is looked up in ExternFunctionProtoRegistry; a CHECK fails
// (listing the registry's records) when none is found. When the prototype's
// return type is void, one tensor is created for each of the prototype's
// mutable argument types and the resulting list becomes the call's write_args.
//
// @param target Name of the extern function to call.
// @param args   Arguments passed to the call; also given to the prototype's
//               shape_inference for each output.
// @param attrs  Attributes forwarded to the call expression.
// @return The call expression.
Expr CallExtern(const std::string &target, const std::vector<Expr> &args, const std::map<std::string, attr_t> &attrs) {
  auto *proto = backends::ExternFunctionProtoRegistry::Global().Lookup(target);
  CHECK(proto) << "No extern function prototype " << target << " found\n"
               << "existing records are:\n"
               << backends::ExternFunctionProtoRegistry::Global().debug_string();

  auto call = ir::Call::Make(proto->ret_type, target, args, {}, ir::CallType::Extern, ir::FunctionRef(), 0, attrs);

  // A non-void prototype needs no output tensors.
  if (!proto->ret_type.is_void()) {
    return call;
  }

  // Void return type: the function has (possibly several) tensor outputs.
  std::vector<Expr> mutable_args;
  const int output_count = static_cast<int>(proto->mutable_arg_types.size());
  for (int i = 0; i < output_count; ++i) {
    auto out_shape = proto->shape_inference(args, i);

    // Record which output slot of the call this op stands for.
    auto op            = ir::CallOp::Make(target, call);
    auto *call_op      = op->as<ir::CallOp>();
    call_op->value_slot   = i;
    call_op->is_tuple_get = true;

    auto out_name   = Context::Global().NewName("tuple_" + target + "_out" + std::to_string(i) + "_");
    auto out_tensor = ir::Tensor(out_name, proto->mutable_arg_types[i], out_shape, out_shape, op, {});
    mutable_args.push_back(out_tensor);
  }
  call.As<ir::Call>()->write_args = mutable_args;
  return call;
}

}  // namespace lang
}  // namespace cinn