Program Listing for File lower_impl.cc

Return to documentation for file (/WorkSpace/CINN/cinn/lang/lower_impl.cc)

// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/lang/lower_impl.h"

#include <algorithm>
#include <queue>
#include <string>
#include <unordered_set>

#include "cinn/common/context.h"
#include "cinn/common/ir_util.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/ir/tensor.h"
#include "cinn/poly/stage.h"

namespace cinn {
namespace lang {
namespace detail {

// Diagnostic helper: looks for ir::Call nodes in `expr` that still report is_isl_call() and logs a warning if
// any are found. The expression itself is not modified and no failure is raised.
void CheckNoIslCallRemains(Expr* expr) {
  const auto is_isl_call_node = [](const Expr* x) {
    auto* call = x->As<ir::Call>();
    return call && call->is_isl_call();
  };

  auto remaining = ir::CollectIRNodes(*expr, is_isl_call_node);
  if (remaining.empty()) return;

#ifdef CINN_DEBUG
  // Debug builds additionally list every leftover call.
  for (auto& call : remaining) {
    LOG(ERROR) << "ISL call: " << call;
  }
#endif
  LOG(WARNING) << "Some ISL call nodes remained, get " << remaining.size() << " isl_calls, the first one is "
               << *remaining.begin();
}

// Binds buffers for tensors that asked to share another tensor's buffer.
//
// For every stage whose tensor has no buffer yet and whose `meta.tensors_to_share_buffer_with` is non-empty, the
// tensor is bound to the buffer of each listed tensor that already has one (if several qualify, every one is
// bound in turn). Names that do not correspond to any tensor in `stages` are skipped with a warning instead of
// being dereferenced.
void BindBuffer(StageMap& stages) {
  // Index the tensors of all stages by name.
  absl::flat_hash_map<std::string, ir::_Tensor_*> tensor_map;
  for (auto& stage : stages) {
    tensor_map[stage.second->tensor()->name] = stage.second->tensor();
  }

  for (auto& stage : stages) {
    ir::_Tensor_* tensor   = stage.second->tensor();
    const auto& share_with = stage.second->meta.tensors_to_share_buffer_with;
    if (tensor->buffer.defined() || share_with.empty()) continue;

    for (const auto& owner_name : share_with) {
      // Use find() rather than operator[]: the latter would insert a null pointer for an unknown name.
      auto it = tensor_map.find(owner_name);
      if (it == tensor_map.end()) {
        LOG(WARNING) << "Tensor " << tensor->name << " wants to share buffer with unknown tensor " << owner_name;
        continue;
      }
      ir::_Tensor_* owner = it->second;
      if (!owner->buffer.defined()) continue;

      // The owner's buffer shape is saved and restored around Bind() so that binding leaves it unchanged
      // (presumably Bind() may rewrite the buffer shape -- confirm against ir::_Tensor_::Bind).
      auto edited_shape = owner->buffer->shape;
      tensor->Bind(owner->buffer);
      owner->buffer->shape = edited_shape;
      VLOG(3) << "Tensor " << tensor->name << " bind buffer to " << owner->name << " , " << owner->buffer->name;
    }
  }
}

// Lowers one schedule group into a single CINN expression.
//
// Steps, in order:
//   1. bind shared buffers for the stages in `stage_map`;
//   2. collect the stages of the group's nodes that have an expression (returns an undefined Expr if none);
//   3. build an ISL AST for those stages and convert it to a CINN Expr;
//   4. replace each ISL call statement with the corresponding expression from `tuple_to_expr`;
//   5. register the tensors of `stage_map` in `global_tensor_map` (existing entries are kept);
//   6. apply the vectorize / unroll / parallel marking mutators;
//   7. (only when built with CINN_WITH_CUDA) transform GPU for-loops and extend `cuda_axis_info`.
//
// @param group             the schedule group to lower.
// @param tuple_to_expr     statement name -> original CINN statement.
// @param global_tensor_map in/out map from stage id to tensor; new entries are added here.
// @param resized_buffer    forwarded to optim::TransformGpuForloops (CUDA builds only).
// @param stage_map         the stages; taken by value.
// @param cuda_axis_info    output, extended with the axis info gathered from the stages (CUDA builds only).
// @return the lowered expression, or an undefined Expr when no stage in the group has an expression.
Expr LowerGroup(const poly::ScheduleGroup& group,
                const std::map<std::string, Expr>& tuple_to_expr,
                std::map<std::string, ir::Tensor>* global_tensor_map,
                std::unordered_set<std::string>& resized_buffer,
                StageMap stage_map,
                ir::CudaAxisInfo* cuda_axis_info) {
  BindBuffer(stage_map);
  // Only stages that carry an expression take part in code generation.
  std::vector<poly::Stage*> stages;
  for (auto& node : group.nodes) {
    VLOG(1) << "In LowerGroup, node id is: " << node->id();
    if (node->stage->has_expression()) {
      stages.push_back(node->stage);
      VLOG(1) << "stage expr " << node->stage->expr();
    } else {
      VLOG(1) << "stage expression is null: " << node->stage->domain();
    }
  }

  if (stages.empty()) return Expr();

  // get isl generated expression
  isl::set context(Context::isl_ctx(), "{:}");
  poly::AstGen gen(context, stages, group);
  isl::ast_node ast = gen.Build();
  ir::Expr e;
  poly::IslAstNodeToCinnExpr(ast, &e);
  // now we get a workable expression, but the statements are something like `B(((16 * po0) + po1), po2)`, we need to
  // transform this to some real-world statement in CINN.

  VLOG(1) << "ast to expr: \n" << e << std::endl;

  // replace isl call to the corresponding CINN statement, we need to replace the axis at the same time.
  for (auto& statement : tuple_to_expr) {
    VLOG(2) << "LowerGroup working on statement: " << statement.first;
    // Statements that are not part of this group's AST are skipped.
    if (!gen.ContainsStatement(statement.first)) continue;
    // the axis_expr_map contains the axis from the original (like `i`) to the transformed (like `i+3`).
    auto axis_expr_map = gen.axis2expr(statement.first);
    for (auto& item : axis_expr_map) {
      VLOG(4) << "statement ast map axis [" << item.first << "] to "
              << "[" << item.second << "]";
    }

    // the original CINN statements.
    Expr statement_candi_expr = tuple_to_expr.at(statement.first);

    VLOG(3) << "replacing " << statement.first << " to " << statement_candi_expr;
    optim::ReplaceIslCallWithExpr(&e, statement.first, statement_candi_expr, axis_expr_map);
  }
  CheckNoIslCallRemains(&e);

  // Update global_tensor_map
  // NOTE: the loop variable `e` below shadows the lowered expression `e` inside this loop only.
  for (auto& e : stage_map) {
    if (!global_tensor_map->count(e.second->id())) {
      (*global_tensor_map)[e.second->id()] = ir::Tensor(e.second->tensor());
    }
  }

  // mark vectorize.
  {
    // stage id -> vectorize info, for the stages whose vectorize info is valid.
    std::map<std::string, ir::VectorizeInfo> vectorizes;
    for (auto& node : group.nodes) {
      if (node->stage->vectorize_info().valid()) {
        vectorizes[node->stage->id()] = node->stage->vectorize_info();
      }
    }
    MarkVectorizeMutator mutator(vectorizes);
    mutator(&e);
  }

  // mark unroll.
  {
    // stage id -> unroll info, for the stages that have any.
    std::map<std::string, std::set<int>> unrolls;
    for (auto& node : group.nodes) {
      if (!node->stage->unroll_info().empty()) {
        unrolls[node->stage->id()] = node->stage->unroll_info();
      }
    }
    MarkUnrollMutator mutator(unrolls);
    mutator(&e);
  }

  // mark parallel.
  {
    // stage id -> parallel info, for the stages that have any.
    std::map<std::string, std::set<int>> parallels;
    for (auto& node : group.nodes) {
      if (!node->stage->parallel_info().empty()) {
        parallels[node->stage->id()] = node->stage->parallel_info();
      }
    }
    MarkParallelMutator mutator(parallels);
    mutator(&e);
  }

  // mark gpu threads
#ifdef CINN_WITH_CUDA
  {
    optim::forloop_infos_t forloop_infos;
    // Stage ids in the order they are first encountered below; reversed before use.
    std::vector<std::string> traverse_order;
    // Ids already present in `traverse_order`, used to avoid duplicates.
    std::set<std::string> temp_set;
    for (auto* stage : stages) {
      // transform the level-indexed for-loop infos into iterator-name-indexed infos.
      auto iters = common::GatherItersToTensorProducer(stage->id(), &e);
      std::map<std::string, poly::StageForloopInfo> for_infos;
      for (auto& item : stage->forloop_infos()) {
        // Negative levels are ignored.
        if (item.first < 0) continue;
        CHECK_LT(item.first, iters.size());
        for_infos[iters[item.first]] = item.second;
      }

      forloop_infos[stage->id()] = for_infos;
    }

    for (auto* stage : stages) {
      CHECK_EQ((*global_tensor_map).count(stage->id()), 1) << "Global_Tensor_Map doesn't contain " << stage->id();
      CHECK_EQ(forloop_infos.count(stage->id()), 1) << "forloop_infos doesn't contain " << stage->id();
      // Control dependencies that have for-loop infos are recorded before the stage itself.
      if (stage->ctrl_depends().size() > 0) {
        for (auto& i : stage->ctrl_depends()) {
          CHECK_EQ((*global_tensor_map).count(i->name), 1) << "Global_Tensor_Map doesn't contain " << i->name;
          if (forloop_infos.count(i->name) == 0) continue;
          if (temp_set.count(i->name) == 0) {
            traverse_order.push_back(i->name);
            temp_set.insert(i->name);
          }
        }
      }

      if (temp_set.count(stage->id()) == 0) {
        traverse_order.push_back(stage->id());
        temp_set.insert(stage->id());
      }
    }
    std::reverse(traverse_order.begin(), traverse_order.end());

    optim::TransformGpuForloops(forloop_infos, traverse_order, global_tensor_map, resized_buffer, &e);
    auto axis_info = optim::GatherAxisInfoFromStages(stages);
    if (axis_info.valid()) cuda_axis_info->ExtendWith(axis_info);
  }
#endif  // CINN_WITH_CUDA
  return e;
}

// Tells whether `stage` carries at least one for-loop info whose device is GPU.
//
// @param t     the tensor the stage belongs to (not inspected here).
// @param stage the stage to inspect; may be null.
// @return false for a null or inlined stage, otherwise true iff any entry of stage->forloop_infos() has
//         device == ir::DeviceAPI::GPU.
bool TensorContainsGPUInfo(ir::Tensor t, poly::Stage* stage) {
  // The null check has to come first: the original code called stage->inlined() before testing `stage`.
  if (!stage || stage->inlined()) return false;

  for (auto& info : stage->forloop_infos()) {
    if (info.second.device == ir::DeviceAPI::GPU) {
      return true;
    }
  }
  return false;
}

// Type tag string shared by all CompuGraphNode instances.
const char* CompuGraphNode::__type_info__ = "ComputeGraphNode";

// Returns the type tag of this node kind.
const char* CompuGraphNode::type_info() const {
  return __type_info__;
}

// The node id is the name of the wrapped tensor; the tensor must be defined.
std::string CompuGraphNode::id() const {
  CHECK(tensor.defined());
  return tensor->name;
}

// Recursively adds tensor `t` and every tensor it loads from to `graph`, linking each loaded tensor's node to
// the node of the tensor that reads it. Inline tensors are kept as ordinary nodes.
//
// @param graph   the graph being built; nodes are keyed by tensor name.
// @param t       the tensor to add.
// @param stages  the stage map; only forwarded to the recursive calls.
// @param visited tensors that were already processed; `t` is added to it.
void CreateCompGraphWithInlineTensors(common::Graph* graph,
                                      const ir::Tensor& t,
                                      StageMap stages,
                                      std::set<ir::Tensor>* visited) {
  if (visited->count(t) != 0) return;

  common::GraphNode* consumer_node = graph->RetrieveNode(t->name);
  if (consumer_node == nullptr) {
    consumer_node = graph->RegisterNode(t->name, new CompuGraphNode(t));
  }
  visited->insert(t);

  // Dependencies are the tensors referenced by Load nodes in the body of `t`.
  // NOTE there may be some other cases.
  const auto loaded = ir::CollectLoadTensors(t->body(), [](const Expr* x) { return x->as_tensor(); });
  for (const auto& loaded_expr : loaded) {
    auto producer = loaded_expr.as_tensor_ref();

    auto* producer_node = graph->RetrieveNode(producer->name);
    if (producer_node == nullptr) {
      producer_node = graph->RegisterNode(producer->name, new CompuGraphNode(producer));
    }
    producer_node->LinkTo(consumer_node);

    if (visited->count(producer) == 0) {
      CreateCompGraphWithInlineTensors(graph, producer, stages, visited);
    }
  }
}

// Builds the computation graph for `tensors` and then removes every node whose stage is inlined, reconnecting
// the inputs of each removed node directly to its outputs.
//
// @param tensors the tensors the graph is built from.
// @param stages  the stage map used to look up whether a tensor is inlined.
// @return the graph without inline-tensor nodes.
std::unique_ptr<common::Graph> CreateCompGraphWithInlineTensorHidden(const std::vector<ir::Tensor>& tensors,
                                                                     StageMap stages) {
  // create a graph with inline tensor first.
  auto graph = std::make_unique<common::Graph>();
  std::set<ir::Tensor> visited;
  for (auto& t : tensors) {
    CreateCompGraphWithInlineTensors(graph.get(), t, stages, &visited);
  }

  // greedily remove the inline tensors, each round merging the inputs of an inline tensor into its sink nodes,
  // until no inline node is left.
  std::set<common::GraphNode*> inline_nodes;
  while (true) {
    inline_nodes = graph->CollectNodes([&](const common::GraphNode* x) {
      auto* comp_node = x->safe_as<CompuGraphNode>();
      // Every node in this graph is expected to be a CompuGraphNode; fail loudly instead of dereferencing null.
      CHECK(comp_node);
      return stages[comp_node->tensor]->inlined();
    });
    if (inline_nodes.empty()) break;

    /*
     * A -> inlined -> B
     * C /
     * =>
     * A -> B
     * C /
     */
    for (auto* inline_node : inline_nodes) {
      // Copy the link sets first: unlinking below changes the node's own link containers.
      auto inline_inlinks  = inline_node->inlinks();
      auto inline_outlinks = inline_node->outlinks();

      // unlink the inline node from its inputs and outputs
      for (auto& link : inline_inlinks) {
        link->source()->UnLinkTo(link->sink());
      }
      for (auto& link : inline_outlinks) {
        link->source()->UnLinkTo(link->sink());
      }

      // link inline node's input nodes to its output nodes.
      for (auto out_edge : inline_outlinks) {
        auto* out = out_edge->sink();
        for (auto in_edge : inline_inlinks) {
          auto* source = in_edge->source();
          source->LinkTo(out);
        }
      }

      graph->DropNode(inline_node);
    }
  }

  return graph;
}

// Adds an edge for every control dependency recorded in the stages: for each graph node, every tensor in its
// stage's ctrl_depends() that also has a node in `graph` is linked to that node. Dependencies without a node
// are ignored.
void CompuGraphAddCtrlDepLinks(common::Graph* graph, StageMap stages) {
  for (auto& graph_node : graph->nodes()) {
    auto* consumer = graph_node->safe_as<CompuGraphNode>();
    CHECK(consumer);

    for (auto& ctrl_dep : stages[consumer->tensor]->ctrl_depends()) {
      auto* producer = graph->RetrieveNode(ctrl_dep->name);
      if (producer == nullptr) continue;

      VLOG(3) << "Add control link: " << ctrl_dep << " -> " << consumer->id();
      producer->LinkTo(consumer);
    }
  }
}

// Creates the computation graph for `tensors` and adds the control-dependency links.
//
// @param tensors     the tensors the graph is built from.
// @param stages      the stage map.
// @param hide_inline when true, nodes of inlined tensors are removed from the graph; otherwise they are kept.
// @return the resulting graph.
std::unique_ptr<common::Graph> CreateCompGraph(const std::vector<ir::Tensor>& tensors,
                                               StageMap stages,
                                               bool hide_inline) {
  std::unique_ptr<common::Graph> graph;

  if (hide_inline) {
    graph = CreateCompGraphWithInlineTensorHidden(tensors, stages);
  } else {
    graph = std::make_unique<common::Graph>();
    std::set<ir::Tensor> visited;
    for (auto& tensor : tensors) {
      CreateCompGraphWithInlineTensors(graph.get(), tensor, stages, &visited);
    }
  }

  // Both variants get the control-dependency edges.
  CompuGraphAddCtrlDepLinks(graph.get(), stages);
  return graph;
}

void LowerImpl::CheckArgsUnique() {
  std::unordered_set<std::string> arg_names;
  for (auto& tensor : tensor_args_) {
    CHECK(!stages_[tensor]->inlined()) << "Inline tensor cannot be argument of function";
    CHECK(!arg_names.count(tensor->name))
        << "The argument of the function, tensor [" << tensor->name << "] duplicates in function " << fn_name_;
    arg_names.insert(tensor->name);
    if (!tensor->buffer.defined()) {
      LOG(ERROR) << "tensor [" << tensor->name << "] buffer is null";
      continue;
    }
    arg_names.insert(tensor->buffer->name);
  }

  for (auto& scalar : scalar_args_) {
    CHECK(!arg_names.count(scalar->name)) << "The argument of the function, scalar [" << scalar->name << "] duplicates";
    arg_names.insert(scalar->name);
  }
}

// Builds the argument list of the lowered function: all scalar arguments first (as inputs), followed by the
// buffers of the tensor arguments. A buffer is an output when `fn_body` writes to the tensor
// (per optim::TensorWriteTeller), otherwise an input.
//
// @param fn_body the function body used to decide which tensors are written.
// @return the arguments, in declaration order; tensor arguments without a buffer are skipped.
std::vector<ir::Argument> LowerImpl::GenerateFunctionArgumentList(Expr fn_body) {
  CheckArgsUnique();

  std::vector<ir::Argument> args;
  optim::TensorWriteTeller teller;
  teller.Collect(&fn_body);

  // Names of the arguments collected so far (scalar names and buffer names).
  std::set<std::string> arg_names;

  for (auto& scalar : scalar_args_) {
    CHECK(!arg_names.count(scalar->name));
    auto* scalar_node = scalar.As<ir::_Var_>();
    CHECK(scalar_node->type().valid());
    arg_names.insert(scalar->name);

    args.emplace_back(scalar, ir::Argument::IO::kInput);
  }

  for (auto& tensor : tensor_args_) {
    auto* tensor_node = tensor.As<ir::_Tensor_>();

    // A tensor without a buffer contributes no argument. This check must precede any access to the buffer
    // name: the original code logged `tensor->buffer->name` before checking that the buffer is defined.
    if (!tensor_node->buffer.defined()) {
      VLOG(1) << "tensor argument " << tensor->name << " has no buffer, skipped";
      continue;
    }

    const auto& buffer_name = tensor_node->buffer->name;
    const bool is_output    = teller.IsWrite(tensor->name);
    VLOG(1) << "tensor argument " << tensor->name << " buffer " << buffer_name;

    // The buffer (or a scalar of the same name) was collected before: an earlier input entry is dropped and
    // re-added at the back with the IO kind of this tensor; an earlier output entry is kept as is.
    if (arg_names.count(buffer_name)) {
      auto it =
          std::find_if(args.begin(), args.end(), [&](const ir::Argument& x) { return x.name() == buffer_name; });
      CHECK(it != args.end());
      if (it->is_input()) {
        args.erase(it);
      } else if (it->is_output()) {
        continue;
      }
    }

    arg_names.insert(buffer_name);

    auto io = is_output ? ir::Argument::IO::kOutput : ir::Argument::IO::kInput;
    VLOG(3) << "Collect " << (is_output ? "W" : "R") << " argument " << buffer_name;
    args.emplace_back(tensor_node->buffer, io);
  }

  return args;
}
// Generates the argument list for one split kernel.
//
// Scalars come first (inputs), then the buffers of the tensor arguments that are actually used by
// `func_iterator` (directly, or through a share-buffer relation) and are not temporaries. Input buffers are
// placed before output buffers in the returned list.
//
// @param func_iterator the body of the kernel.
// @param temp_tensors  tensors treated as temporaries of this kernel; they never become arguments.
// @return the arguments: scalars and input buffers, followed by output buffers.
std::vector<ir::Argument> LowerImpl::GenFuncArgForSplitKernel(Expr func_iterator,
                                                              std::vector<ir::Tensor> temp_tensors) {
  CheckArgsUnique();

  std::vector<ir::Argument> in_args;
  std::vector<ir::Argument> out_args;
  optim::TensorWriteTeller teller;
  teller.Collect(&func_iterator);
  std::set<std::string> arg_names;
  std::set<std::string> all_tensor_names;

  for (auto& scalar : scalar_args_) {
    CHECK(!arg_names.count(scalar->name));
    auto* scalar_node = scalar.As<ir::_Var_>();
    CHECK(scalar_node->type().valid());
    arg_names.insert(scalar->name);

    in_args.emplace_back(scalar, ir::Argument::IO::kInput);
  }

  // Non-inlined tensors referenced by the kernel body.
  auto all_tensors = ir::CollectIRNodes(
      func_iterator, [&](const Expr* x) { return x->as_tensor() && !stages_[x->as_tensor()]->inlined(); });

  // Variables are collected for logging only.
  auto all_vars = ir::CollectIRNodes(func_iterator, [&](const Expr* x) { return x->as_var(); });

  for (auto& tensor_expr : all_tensors) {
    auto* tensor = tensor_expr.as_tensor();
    all_tensor_names.insert(tensor->name);
    VLOG(3) << "In all_tensors, it has : " << tensor->name;
    // Tensors sharing a buffer with a used tensor count as used as well.
    for (auto& shared_name : stages_[tensor]->meta.tensors_to_share_buffer_with) {
      all_tensor_names.insert(shared_name);
      VLOG(3) << "And its share_buffer_tensor is : " << shared_name;
    }
  }
  for (auto& var_expr : all_vars) {
    auto* var = var_expr.as_var();
    VLOG(3) << "In all_vars, it has : " << var->name;
  }

  for (auto& scalar : scalar_args_) {
    VLOG(3) << "In scalar_args_, var has : " << scalar->name;
  }

  std::set<std::string> temp_tensor_names;

  for (auto& temp : temp_tensors) {
    VLOG(3) << "In temp_tensors, it has : " << temp->name;
    temp_tensor_names.insert(temp->name);
  }

  for (auto& tensor : tensor_args_) {
    VLOG(3) << "In tensor_args_, it has : " << tensor->name;
    if (temp_tensor_names.count(tensor->name) > 0) continue;
    if (all_tensor_names.count(tensor->name) == 0) continue;

    // Check the buffer before touching its name: the original code logged `tensor->buffer->name` first.
    if (!tensor->buffer.defined()) {
      VLOG(3) << "tensor->buffer is not defined";
      continue;
    }

    const auto& buffer_name = tensor->buffer->name;
    const bool is_output    = teller.IsWrite(tensor->name);
    VLOG(3) << "tensor argument " << tensor->name << " buffer " << buffer_name;

    // The name was collected before: an entry among the inputs is dropped and re-added below with the IO kind
    // of this tensor; otherwise (already an output) the tensor is skipped.
    if (arg_names.count(buffer_name)) {
      auto it = std::find_if(
          in_args.begin(), in_args.end(), [&](const ir::Argument& x) { return x.name() == buffer_name; });
      if (it != in_args.end()) {
        in_args.erase(it);
      } else {
        continue;
      }
    }

    arg_names.insert(buffer_name);

    if (is_output) {
      out_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
    } else {
      in_args.emplace_back(tensor->buffer, ir::Argument::IO::kInput);
    }
  }

  std::vector<ir::Argument> args(in_args.begin(), in_args.end());
  args.insert(std::end(args), out_args.begin(), out_args.end());
  return args;
}

std::vector<Tensor> LowerImpl::CollectTemporaryTensors() {
  // a temporary should be in the comp_graph but not contained in the tensor_args.
  absl::flat_hash_map<std::string, Tensor> tensor_arg_map = GenTensorArgMap();
  absl::flat_hash_map<std::string, Tensor> temp_tensor_map;

  for (auto* node : compu_graph_->nodes()) {
    auto* cnode = node->safe_as<CompuGraphNode>();
    CHECK(cnode);
    if (!tensor_arg_map.count(cnode->tensor->name)) {
      temp_tensor_map[cnode->tensor->name] = cnode->tensor;
    }
  }

  std::vector<Tensor> temp_tensors;
  std::transform(temp_tensor_map.begin(),
                 temp_tensor_map.end(),
                 std::back_inserter(temp_tensors),
                 [&](const decltype(temp_tensor_map)::value_type& x) { return x.second; });
  return temp_tensors;
}

// Indexes the tensor arguments by name. If two arguments share a name, the later one wins.
absl::flat_hash_map<std::string, Tensor> LowerImpl::GenTensorArgMap() {
  absl::flat_hash_map<std::string, Tensor> args_by_name;
  for (const auto& arg : tensor_args_) {
    args_by_name[arg->name] = arg;
  }
  return args_by_name;
}

// Indexes every tensor returned by CollectAllTensors() by name. If two tensors share a name, the later one wins.
absl::flat_hash_map<std::string, Tensor> LowerImpl::GenAllTensorMap() {
  const std::vector<Tensor> all_tensors = CollectAllTensors();

  absl::flat_hash_map<std::string, Tensor> tensors_by_name;
  for (const auto& tensor : all_tensors) {
    tensors_by_name[tensor->name] = tensor;
  }
  return tensors_by_name;
}

// Runs the whole lowering and returns the generated functions.
//
// The non-inlined stages are scheduled, GenerateFunctionBody() produces one body per function, and for each
// body buffers are assigned, temporary buffers and arguments are collected, the ir::_LoweredFunc_ is created,
// inline tensors are expanded and the function is optimized. On the default NVGPU target every body becomes its
// own function (`fn_name_`, `fn_name__1`, ...); on other targets the function is named `fn_name_`.
std::vector<ir::LoweredFunc> LowerImpl::operator()() {
  std::vector<poly::Stage*> stages;
  std::map<std::string, ir::Tensor> all_tensor_map;
  for (auto& t : CollectAllTensors()) {
    all_tensor_map[t->name] = t;
    if (!stages_[t]->inlined()) stages.push_back(stages_[t]);
  }

  auto deps     = CollectExtraDependencies();
  auto schedule = poly::CreateSchedule(
      stages, poly::ScheduleKind::Poly, std::vector<std::pair<std::string, std::string>>(deps.begin(), deps.end()));
  auto func_body = GenerateFunctionBody(schedule.get());

  std::vector<ir::LoweredFunc> result;
  // Index of the function being generated; unsigned so it compares cleanly with container sizes.
  size_t num_func = 0;
  for (auto& func_iterator : func_body) {
    std::set<std::string> temp_tensor_names;
    for (auto& t : temp_tensor_args_) temp_tensor_names.insert(t->name);

    auto tensor_map =
        optim::InitialAssignBuffer(&func_iterator, stages_, all_tensor_map, comp_graph(), temp_tensor_names);
    // copy the tensor(with buffer assigned) back to func's args.
    {
      for (auto& arg : tensor_args_) {
        if (arg->is_placeholder_node()) continue;
        if (arg->buffer.defined()) continue;
        if (arg->body().As<ir::Call>() && arg->body().type().is_void()) continue;  // extern call
        if (tensor_map.find(arg->name) == tensor_map.end()) {
          // Fixed: the original message had no space between the tensor name and "in tensor_map".
          LOG(INFO) << "Didn't find arg tensor " << arg->name << " in tensor_map.\n"
                    << "The function is " << fn_name_ << "\nAnd all the arg tensors are:\n";
          for (auto& i : tensor_args_) {
            LOG(INFO) << i->name;
          }
          LOG(FATAL) << "Fatal Error!";
        }
        Reference(&arg)->buffer = tensor_map.at(arg->name)->buffer;
      }
    }

    // Tensors stored to in this body whose buffer is not heap memory.
    auto store_exprs = ir::CollectIRNodes(func_iterator, [](const Expr* x) { return x->As<ir::Store>(); });
    std::vector<ir::Tensor> new_temp_tensors;
    for (auto& expr : store_exprs) {
      auto* store_node = expr.As<ir::Store>();
      CHECK(store_node);
      auto* tensor = store_node->tensor.As<ir::_Tensor_>();
      CHECK(tensor);
      VLOG(3) << "In store_exprs, its name is : " << tensor->name;
      CHECK(tensor->buffer.defined());
      if (tensor->buffer->memory_type != ir::MemoryType::Heap) {
        new_temp_tensors.push_back(store_node->tensor.as_tensor_ref());
      }
    }

    const bool is_nvgpu = target_ == common::DefaultNVGPUTarget();

    std::vector<ir::Buffer> temp_buffers;
    std::unordered_set<std::string> buffer_name_set;
    // TODO(Superjomn) write buffer later.

    // Appends the (defined) buffer of every tensor in `tensors` that tensor_map knows, once per buffer name.
    const auto collect_temp_buffers = [&](const auto& tensors) {
      for (auto& t : tensors) {
        auto it = tensor_map.find(t->name);
        if (it == tensor_map.end()) continue;
        auto& tt = it->second;
        if (tt->buffer.defined() && buffer_name_set.insert(tt->buffer->name).second) {
          temp_buffers.push_back(tt->buffer);
        }
      }
    };
    if (is_nvgpu) {
      collect_temp_buffers(new_temp_tensors);
    } else {
      // Graph-level temporaries are only needed on this path.
      collect_temp_buffers(CollectTemporaryTensors());
    }

    ir::LoweredFunc func;
    if (is_nvgpu) {
      auto func_args2         = GenFuncArgForSplitKernel(func_iterator, new_temp_tensors);
      std::string new_fn_name = fn_name_;
      if (num_func > 0) {
        new_fn_name += "_" + std::to_string(num_func);
      }
      VLOG(3) << "Making func :" << new_fn_name;
      for (auto& i : func_args2) {
        VLOG(3) << "func_args2 is : " << i.name();
      }
      for (auto& i : temp_buffers) {
        VLOG(3) << "temp_buffers is : " << i->name;
      }
      func = ir::_LoweredFunc_::Make(new_fn_name, func_args2, func_iterator, temp_buffers);
    } else {
      auto func_args = GenerateFunctionArgumentList(func_iterator);
      func           = ir::_LoweredFunc_::Make(fn_name_, func_args, func_iterator, temp_buffers);
    }

    // some necessary modification.
    optim::ComputeInlineExpand(&func->body, stages_, &all_tensor_map);

    auto res = optim::Optimize(func, target_, FLAGS_cinn_runtime_display_debug_info);

    // Attach the CUDA axis info recorded for this function index, if there is a valid one.
    if (cuda_axis_info_.size() > num_func && cuda_axis_info_[num_func].valid()) {
      auto* res_func           = res.as_lowered_func();
      res_func->cuda_axis_info = cuda_axis_info_[num_func];
    }
    result.push_back(ir::LoweredFunc(res.get()));
    num_func++;
  }
  return result;
}

// Returns the tensors of all nodes of the computation graph, in the graph's topological order.
std::vector<Tensor> LowerImpl::CollectAllTensors() {
  std::vector<Tensor> tensors;
  auto topo_order = compu_graph_->topological_order();  // NOLINT
  // Only the node list (first element) is needed; the unused edge list local was removed.
  auto& nodes = std::get<0>(topo_order);
  for (auto* node : nodes) {
    auto* cnode = node->safe_as<CompuGraphNode>();
    CHECK(cnode);
    tensors.push_back(cnode->tensor);
  }
  return tensors;
}

// Collects the control dependencies of all tensors in the computation graph as (dependency name, tensor name)
// pairs.
std::set<std::pair<std::string, std::string>> LowerImpl::CollectExtraDependencies() const {
  std::set<std::pair<std::string, std::string>> extra_deps;

  for (auto* graph_node : compu_graph_->nodes()) {
    auto* compute_node = graph_node->safe_as<CompuGraphNode>();
    CHECK(compute_node);

    const auto& dependent_name = compute_node->tensor->name;
    for (auto& ctrl_dep : stages_[compute_node->tensor]->ctrl_depends()) {
      extra_deps.emplace(ctrl_dep->name, dependent_name);
    }
  }

  return extra_deps;
}

// Generates the function bodies from the schedule.
//
// Every schedule group is lowered with LowerGroup(). On the default NVGPU target each lowered group becomes its
// own body (one kernel per group); on every other target all lowered groups are wrapped into a single block.
// For each group that produced an expression, its CUDA axis info is appended to `cuda_axis_info_`.
//
// @param schedule the schedule to lower; must contain at least one group.
// @return the bodies: one per lowered group on NVGPU, exactly one otherwise.
std::vector<Expr> LowerImpl::GenerateFunctionBody(const poly::Schedule* schedule) {
  // generate the expressions for each group.
  std::vector<Expr> exprs;
  std::vector<Expr> result;
  auto tensor_map = GenAllTensorMap();
  // Statement name -> store-expanded body; accumulated across groups.
  std::map<std::string, Expr> tuple_to_expr;
  CHECK(!schedule->groups.empty()) << "no group is generated";

  std::map<std::string, ir::Tensor> global_tensor_map;
  std::unordered_set<std::string> resized_buffer;

  const bool is_nvgpu = target_ == common::DefaultNVGPUTarget();

  for (auto& group : schedule->groups) {
    CHECK_GT(group.nodes.size(), 0) << "group is empty";
    for (auto& node : group.nodes) {
      auto it = tensor_map.find(node->id());
      if (it == tensor_map.end()) {
        VLOG(2) << "tensor_map doesn't count " << node->id();
        continue;
      }
      auto& tensor = it->second;
      if (!tensor->has_expression()) continue;
      tuple_to_expr[tensor->name] = tensor->tensor_store_expanded_body();
    }

    ir::CudaAxisInfo temp_cuda_axis_info;
    Expr group_expr =
        LowerGroup(group, tuple_to_expr, &global_tensor_map, resized_buffer, stages_, &temp_cuda_axis_info);
    // Groups without any expression produce nothing.
    if (!group_expr.defined()) continue;

    cuda_axis_info_.emplace_back(std::move(temp_cuda_axis_info));
    if (is_nvgpu) {
      // One body per group.
      result.push_back(ir::Block::Make(std::vector<Expr>{group_expr}));
    } else {
      exprs.push_back(group_expr);
    }
  }

  // All non-NVGPU targets share the single-body path. The original code emitted the body only when the target
  // was exactly DefaultHostTarget(), silently dropping the lowered groups for any other non-NVGPU target, while
  // operator() already handles every non-NVGPU target as a single function.
  if (!is_nvgpu) {
    result.push_back(ir::Block::Make(exprs));
  }

  return result;
}

// Stores the lowering inputs and builds the computation graph.
//
// The graph is first built with inline tensors kept (only logged), then rebuilt with inline tensors hidden; the
// second graph is the one left in `compu_graph_`.
//
// @param fn_name          name of the function to generate.
// @param stages           the stage map of the tensors.
// @param tensor_args      tensor arguments of the function.
// @param scalar_args      scalar arguments of the function.
// @param temp_tensor_args additional tensors that take part in the computation graph.
// @param target           the target to lower for.
LowerImpl::LowerImpl(const std::string& fn_name,
                     StageMap stages,
                     const std::vector<Tensor>& tensor_args,
                     const std::vector<Var>& scalar_args,
                     const std::vector<Tensor>& temp_tensor_args,
                     const Target& target)
    : fn_name_(fn_name),
      stages_(stages),
      tensor_args_(tensor_args),
      scalar_args_(scalar_args),
      temp_tensor_args_(temp_tensor_args),
      target_(target) {
  // Both graph builds start from the same tensors: the tensor arguments followed by the temporary tensor
  // arguments.
  std::vector<ir::Tensor> graph_tensors(tensor_args.begin(), tensor_args.end());
  graph_tensors.insert(graph_tensors.end(), temp_tensor_args.begin(), temp_tensor_args.end());

  // Initialize the graph, inline tensors included.
  compu_graph_ = CreateCompGraph(graph_tensors, stages, false /*inline_hide*/);
  VLOG(1) << "compu_graph:\n" << compu_graph_->Visualize();

  // Todo: Here insert auto syncthreads() @haoze

  // Update schedule: rebuild the graph with the inline tensors hidden.
  compu_graph_ = CreateCompGraph(graph_tensors, stages, true /*inline_hide*/);
  VLOG(1) << "Computation Graph:\n" << compu_graph_->Visualize();
}

}  // namespace detail
}  // namespace lang
}  // namespace cinn