Program Listing for File lower_impl.h

Return to documentation for file (/WorkSpace/CINN/cinn/lang/lower_impl.h)

// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <absl/container/flat_hash_map.h>

#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <stack>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "cinn/common/graph_utils.h"
#include "cinn/ir/buffer.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/optim/buffer_assign.h"
#include "cinn/optim/compute_inline_expand.h"
#include "cinn/optim/fold_cinn_call_arguments.h"
#include "cinn/optim/optimize.h"
#include "cinn/optim/remove_nested_block.h"
#include "cinn/optim/replace_call_with_expr.h"
#include "cinn/optim/tensor_write_tell.h"
#include "cinn/optim/transform_gpu_forloop.h"
#include "cinn/optim/transform_polyfor_to_for.h"
#include "cinn/poly/ast_gen.h"

namespace cinn {

// Forward declaration of poly::Stage; it is only used by pointer in this header
// (see TensorContainsGPUInfo).
namespace poly {
class Stage;
}  // namespace poly

namespace lang {
namespace detail {

// Inspects `expr` (read-only). NOTE(review): by its name this presumably fails a
// check if calls emitted by ISL code generation are still present in `expr` —
// confirm against the definition.
void CheckNoIslCallRemains(const Expr* expr);

// Lowers one poly::ScheduleGroup to a single Expr.
//   group             - the schedule group to lower.
//   tuple_to_expr     - read-only map from string keys to Exprs (presumably ISL
//                       tuple/statement names to statement bodies; TODO confirm).
//   global_tensor_map - string -> Tensor map passed by pointer, so the callee may
//                       update it.
//   resized_buffer    - set of names passed by non-const reference, so the callee
//                       may insert into it.
//   stage_map         - stages associated with the tensors being lowered.
//   cuda_axis_info    - optional out-parameter; defaults to nullptr (not requested).
Expr LowerGroup(const poly::ScheduleGroup& group,
                const std::map<std::string, Expr>& tuple_to_expr,
                std::map<std::string, Tensor>* global_tensor_map,
                std::unordered_set<std::string>& resized_buffer,
                StageMap stage_map,
                ir::CudaAxisInfo* cuda_axis_info = nullptr);

/// A common::GraphNode that wraps a single ir::Tensor.
/// NOTE(review): presumably the node type used by CreateCompGraph — confirm.
struct CompuGraphNode : public common::GraphNode {
  // `tensor` is a sink parameter: it is taken by value and moved into the member,
  // which avoids a second copy of the handle.
  explicit CompuGraphNode(ir::Tensor tensor) : tensor(std::move(tensor)) {}

  // The tensor this node represents.
  ir::Tensor tensor;

  // GraphNode overrides; defined out of line.
  std::string id() const override;
  const char* type_info() const override;
  static const char* __type_info__;
};

// Creates a common::Graph for `tensors`, using `stages`. Ownership of the returned
// graph passes to the caller (std::unique_ptr).
// `hide_inline` defaults to false; NOTE(review): by its name it presumably
// omits inlined tensors from the graph — confirm against the definition.
std::unique_ptr<common::Graph> CreateCompGraph(const std::vector<ir::Tensor>& tensors,
                                               StageMap stages,
                                               bool hide_inline = false);

/// Lowers staged tensors into ir::LoweredFunc objects.
///
/// Construct it with the function name, the stages and the argument lists, then
/// invoke operator() to get the lowered functions.
class LowerImpl {
 public:
  /// @param fn_name           name used for the generated function(s).
  /// @param stages            stage map for the tensors being lowered.
  /// @param tensor_args       tensor arguments of the generated function.
  /// @param scalar_args       scalar (Var) arguments of the generated function.
  /// @param temp_tensor_args  temporary tensors; empty by default.
  /// @param target            lowering target; defaults to common::DefaultHostTarget().
  ///
  /// The name and all argument lists are stored by value, so passing temporaries
  /// (e.g. a string literal for `fn_name`) is safe.
  LowerImpl(const std::string& fn_name,
            StageMap stages,
            const std::vector<Tensor>& tensor_args,
            const std::vector<Var>& scalar_args,
            const std::vector<Tensor>& temp_tensor_args = {},
            const Target& target                        = common::DefaultHostTarget());

  /// Runs the lowering and returns the resulting functions.
  std::vector<ir::LoweredFunc> operator()();

  /// Non-owning view of the computation graph; the graph stays owned by this object.
  const common::Graph* comp_graph() const { return compu_graph_.get(); }

  /// Builds the argument list of a function whose body is `fn_body`.
  std::vector<ir::Argument> GenerateFunctionArgumentList(Expr fn_body);

  /// Builds the argument list for a split kernel from `func_iterator` and `temp_tensors`.
  std::vector<ir::Argument> GenFuncArgForSplitKernel(Expr func_iterator, std::vector<ir::Tensor> temp_tensors);

  /// Generates the function body expressions for `schedule`.
  std::vector<Expr> GenerateFunctionBody(const poly::Schedule* schedule);

 private:
  std::vector<Tensor> CollectTemporaryTensors();

  void CheckArgsUnique();

  inline absl::flat_hash_map<std::string, Tensor> GenTensorArgMap();

  inline absl::flat_hash_map<std::string, Tensor> GenAllTensorMap();

  std::vector<Tensor> CollectAllTensors();

  std::set<std::pair<std::string, std::string>> CollectExtraDependencies() const;

 private:
  // These members are held by value, consistent with temp_tensor_args_, target_
  // and stages_. As `const&` members bound to constructor parameters they would
  // dangle whenever the object outlived temporaries passed to the constructor.
  std::string fn_name_;
  std::vector<Tensor> tensor_args_;
  std::vector<Var> scalar_args_;
  std::vector<Tensor> temp_tensor_args_;
  Target target_;

  StageMap stages_;

  // Owned computation graph; exposed read-only through comp_graph().
  std::unique_ptr<common::Graph> compu_graph_;

  std::vector<ir::CudaAxisInfo> cuda_axis_info_;
};

// Returns a yes/no answer about tensor `t` together with its `stage`.
// NOTE(review): by its name it presumably reports whether GPU-related schedule
// information is attached — the exact criteria live in the definition; confirm.
bool TensorContainsGPUInfo(ir::Tensor t, poly::Stage* stage);

/// Copies per-tensor vectorization info onto the enclosing ir::PolyFor loops.
///
/// During the walk, `forloop_stack` holds the PolyFor nodes that enclose the
/// current position, outermost first. When a Store to a tensor listed in
/// `vectorizes` is reached, the loop at index `VectorizeInfo::level` of that
/// stack receives the info.
struct MarkVectorizeMutator : public ir::IRMutator<Expr*> {
  // Tensor name -> vectorization info. Held by value, like the maps of
  // MarkUnrollMutator / MarkParallelMutator, so the mutator never refers to a map
  // that has already been destroyed.
  std::map<std::string, ir::VectorizeInfo> vectorizes;

  explicit MarkVectorizeMutator(const std::map<std::string /*tensor name*/, ir::VectorizeInfo>& vectorizes)
      : vectorizes(vectorizes) {}

  // Entry point: walks `expr` in place.
  void operator()(Expr* expr) { ir::IRMutator<Expr*>::Visit(expr, expr); }

  // NOTE This mutator takes PolyFor as input, not For.
  void Visit(const ir::PolyFor* op, Expr* expr) override {
    auto* node = expr->As<ir::PolyFor>();
    forloop_stack.push_back(node);
    ir::IRMutator<ir::Expr*>::Visit(op, expr);
    forloop_stack.pop_back();
  }

  // each statement in ISL is bound to a Store node.
  void Visit(const ir::Store* op, Expr* expr) override {
    auto* tensor_n = op->tensor.As<ir::_Tensor_>();
    CHECK(tensor_n);
    auto it = vectorizes.find(tensor_n->name);
    if (it != vectorizes.end()) {
      // Validate the info before writing it onto the IR, so an invalid entry is
      // rejected without mutating any loop.
      CHECK(it->second.valid());
      CHECK_LT(it->second.level, forloop_stack.size());
      forloop_stack[it->second.level]->set_vectorize_info(it->second);
    }
  }

  // PolyFor nodes enclosing the node currently being visited.
  std::vector<ir::PolyFor*> forloop_stack;
};

/// Marks ir::PolyFor loops as unrolled, on a per-tensor basis.
///
/// `unrolls` maps a tensor name to loop levels, which are indices into the stack
/// of PolyFor nodes enclosing a Store to that tensor. Each listed level is
/// marked unrolled.
struct MarkUnrollMutator : public ir::IRMutator<Expr*> {
  // Tensor name -> levels to unroll; a copy of the constructor argument.
  std::map<std::string, std::set<int> /*level*/> unrolls;

  explicit MarkUnrollMutator(const std::map<std::string, std::set<int>>& unrolls) : unrolls(unrolls) {}

  // Entry point: walks `expr` in place.
  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

  // Keeps `stack` equal to the chain of PolyFor nodes around the visited node.
  void Visit(const ir::PolyFor* op, Expr* expr) override {
    stack.push_back(expr->As<ir::PolyFor>());
    ir::IRMutator<>::Visit(op, expr);
    stack.pop_back();
  }

  // Every ISL statement ends up as a Store; its target tensor selects the loops to mark.
  void Visit(const ir::Store* op, Expr* expr) override {
    auto* tensor_node = op->tensor.As<ir::_Tensor_>();
    CHECK(tensor_node);
    const auto found = unrolls.find(tensor_node->name);
    if (found == unrolls.end()) return;
    for (const int lvl : found->second) {
      VLOG(1) << "Mark " << lvl << " Unrolled";
      CHECK_LT(lvl, stack.size());
      stack[lvl]->set_unrolled();
    }
  }

  // PolyFor nodes enclosing the node currently being visited, outermost first.
  std::vector<ir::PolyFor*> stack;
};

/// Marks ir::PolyFor loops as parallel, on a per-tensor basis.
///
/// `parallels` maps a tensor name to loop levels, which are indices into the
/// stack of PolyFor nodes enclosing a Store to that tensor. Each listed level is
/// marked parallel.
struct MarkParallelMutator : public ir::IRMutator<Expr*> {
  // Tensor name -> levels to parallelize; a copy of the constructor argument.
  std::map<std::string, std::set<int> /*level*/> parallels;

  explicit MarkParallelMutator(const std::map<std::string, std::set<int>>& parallels) : parallels(parallels) {}

  // Entry point: walks `expr` in place.
  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

  // Keeps `stack` equal to the chain of PolyFor nodes around the visited node.
  void Visit(const ir::PolyFor* op, Expr* expr) override {
    stack.push_back(expr->As<ir::PolyFor>());
    ir::IRMutator<>::Visit(op, expr);
    stack.pop_back();
  }

  // Every ISL statement ends up as a Store; its target tensor selects the loops to mark.
  void Visit(const ir::Store* op, Expr* expr) override {
    auto* tensor_node = op->tensor.As<ir::_Tensor_>();
    CHECK(tensor_node);
    const auto found = parallels.find(tensor_node->name);
    if (found == parallels.end()) return;
    for (const int lvl : found->second) {
      VLOG(1) << "Mark " << lvl << " Paralled";
      CHECK_LT(lvl, stack.size());
      stack[lvl]->set_parallel();
    }
  }

  // PolyFor nodes enclosing the node currently being visited, outermost first.
  std::vector<ir::PolyFor*> stack;
};

}  // namespace detail
}  // namespace lang
}  // namespace cinn