Program Listing for File compute.h

Return to documentation for file (/WorkSpace/CINN/cinn/lang/compute.h)

// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <absl/types/variant.h>

#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "cinn/ir/ir.h"
#include "cinn/ir/ir_operators.h"
#include "cinn/lang/placeholder.h"
#include "cinn/poly/schedule.h"

namespace cinn {
namespace lang {

/// Generic `Compute` body: receives every index expression of the iteration
/// domain packed into a single vector and returns the element expression.
using compute_handler_t = std::function<Expr(const std::vector<Expr> &)>;

/// Value type of the attribute map accepted by `CallExtern` (`attrs`).
///
/// NOTE(review): under pre-P0608 variant converting-constructor rules, a
/// string literal such as `attr_t{"abc"}` selects the `bool` alternative
/// (the built-in pointer->bool conversion beats the user-defined conversion
/// to `std::string`). Wrap literals in `std::string(...)` to be safe —
/// confirm against the absl / standard library actually in use.
using attr_t = absl::variant<int, float, bool, std::string>;

/// @name Compute
/// Build an `ir::Tensor` over the iteration domain `domain`, with the element
/// expression produced by `fn`.
///
/// The overloads differ only in how `fn` receives the index expressions:
/// as zero to six separate `Expr` arguments, or packed into a vector
/// (`compute_handler_t`). Presumably the arity of `fn` must match
/// `domain.size()` — TODO confirm against the implementation.
///
/// @param domain Extents of the iteration domain. An earlier note said "the
///               shape are constant integers"; presumably these extents are
///               expected to be constant integers — confirm.
/// @param fn     Callback returning the element expression for given indices.
/// @param name   Tensor name; defaults to the empty string.
/// @param shape  Optional shape; defaults to an empty vector.
/// @return The constructed tensor.
///@{
ir::Tensor Compute(const std::vector<Expr> &domain, std::function<Expr()> fn,
                   const std::string &name = "", const std::vector<Expr> &shape = {});
ir::Tensor Compute(const std::vector<Expr> &domain, std::function<Expr(Expr)> fn,
                   const std::string &name = "", const std::vector<Expr> &shape = {});
ir::Tensor Compute(const std::vector<Expr> &domain, std::function<Expr(Expr, Expr)> fn,
                   const std::string &name = "", const std::vector<Expr> &shape = {});
ir::Tensor Compute(const std::vector<Expr> &domain, std::function<Expr(Expr, Expr, Expr)> fn,
                   const std::string &name = "", const std::vector<Expr> &shape = {});
ir::Tensor Compute(const std::vector<Expr> &domain, std::function<Expr(Expr, Expr, Expr, Expr)> fn,
                   const std::string &name = "", const std::vector<Expr> &shape = {});
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(Expr, Expr, Expr, Expr, Expr)> fn,
                   const std::string &name = "", const std::vector<Expr> &shape = {});
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(Expr, Expr, Expr, Expr, Expr, Expr)> fn,
                   const std::string &name = "", const std::vector<Expr> &shape = {});
ir::Tensor Compute(const std::vector<Expr> &domain, compute_handler_t fn,
                   const std::string &name = "", const std::vector<Expr> &shape = {});
///@}

/// Describes one result of `CallLowered`. Presumably there is one entry per
/// returned tensor — TODO confirm.
struct ReturnType {
  Type type;               // type of the result
  std::vector<Expr> dims;  // presumably the result's dimensions — confirm
  std::string name;        // name of the result
};

/// Presumably emits a call to the already-lowered function `target` with
/// `args`. Returns a vector of tensors; presumably one per entry of
/// `return_types` — TODO confirm.
std::vector<ir::Tensor> CallLowered(
    const std::string &target, const std::vector<Expr> &args,
    const std::vector<ReturnType> &return_types);

/// Presumably builds an `Expr` calling the external function `target` with
/// `args`. `attrs` holds optional named attributes and is empty by default.
Expr CallExtern(
    const std::string &target, const std::vector<Expr> &args,
    const std::map<std::string, attr_t> &attrs = {});

}  // namespace lang
}  // namespace cinn