Program Listing for File lower_impl.h
↰ Return to documentation for file (/WorkSpace/CINN/cinn/lang/lower_impl.h)
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/container/flat_hash_map.h>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <stack>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "cinn/common/graph_utils.h"
#include "cinn/ir/buffer.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/optim/buffer_assign.h"
#include "cinn/optim/compute_inline_expand.h"
#include "cinn/optim/fold_cinn_call_arguments.h"
#include "cinn/optim/optimize.h"
#include "cinn/optim/remove_nested_block.h"
#include "cinn/optim/replace_call_with_expr.h"
#include "cinn/optim/tensor_write_tell.h"
#include "cinn/optim/transform_gpu_forloop.h"
#include "cinn/optim/transform_polyfor_to_for.h"
#include "cinn/poly/ast_gen.h"
namespace cinn {
namespace poly {
// Forward declaration: this header only refers to poly::Stage through a pointer
// (the `stage` parameter of TensorContainsGPUInfo), so the full definition is not needed here.
class Stage;
} // namespace poly
namespace lang {
namespace detail {
// Checks `expr` for leftover ISL calls.
// NOTE(review): behavior inferred from the name only — the definition is not in this header;
// presumably it aborts (CHECK-style) if any ISL-generated call is still present. Confirm.
void CheckNoIslCallRemains(const Expr* expr);
// Lowers one poly::ScheduleGroup and returns the resulting Expr.
//
// @param group              the schedule group to lower.
// @param tuple_to_expr      read-only map from a name (presumably the ISL statement/tuple name —
//                           TODO confirm) to the Expr it stands for.
// @param global_tensor_map  in/out: non-const pointer to a name -> Tensor map the callee may update.
// @param resized_buffer     in/out: set of buffer names passed by non-const reference; presumably
//                           records buffers whose size was adjusted — verify against the definition.
// @param stage_map          stages of the tensors involved (passed by value).
// @param cuda_axis_info     optional; defaults to nullptr. Presumably filled with CUDA launch-axis
//                           information when non-null — confirm.
// @return the lowered Expr.
Expr LowerGroup(const poly::ScheduleGroup& group,
const std::map<std::string, Expr>& tuple_to_expr,
std::map<std::string, Tensor>* global_tensor_map,
std::unordered_set<std::string>& resized_buffer,
StageMap stage_map,
ir::CudaAxisInfo* cuda_axis_info = nullptr);
/**
 * A common::GraphNode that wraps a single ir::Tensor.
 *
 * id() and type_info() are defined out of line; __type_info__ is the static type tag
 * backing type_info().
 */
struct CompuGraphNode : public common::GraphNode {
  // `tensor` is a sink parameter: it is taken by value and moved into the member,
  // so callers passing an rvalue pay for no extra copy (the original copied it again).
  explicit CompuGraphNode(ir::Tensor tensor) : tensor(std::move(tensor)) {}

  //! The tensor represented by this node.
  ir::Tensor tensor;

  //! Identifier of this node within its graph.
  std::string id() const override;
  //! Runtime type tag of this node type.
  const char* type_info() const override;
  static const char* __type_info__;
};
// Builds a computation graph over `tensors`, using `stages` for scheduling information.
// The returned graph is owned by the caller (std::unique_ptr).
// @param hide_inline  defaults to false. NOTE(review): presumably, when true, tensors marked as
//                     compute-inline are left out of the graph — confirm against the definition.
std::unique_ptr<common::Graph> CreateCompGraph(const std::vector<ir::Tensor>& tensors,
StageMap stages,
bool hide_inline = false);
/**
 * Lowers a set of staged tensors into ir::LoweredFunc objects.
 *
 * Usage: construct with the function name, stage map and arguments, then call operator().
 *
 * All constructor arguments are copied into the object. Earlier versions kept
 * `fn_name`, `tensor_args` and `scalar_args` as const-reference members, which dangled
 * whenever a temporary was passed. Examples are a string literal for `fn_name` or a braced
 * list for `tensor_args`, followed by calling operator() in a later statement.
 */
class LowerImpl {
 public:
  /**
   * @param fn_name           name of the function to generate.
   * @param stages            stage map of the tensors involved.
   * @param tensor_args       tensor arguments of the generated function.
   * @param scalar_args       scalar arguments of the generated function.
   * @param temp_tensor_args  temporary tensors; empty by default.
   * @param target            target to lower for; defaults to common::DefaultHostTarget().
   */
  LowerImpl(const std::string& fn_name,
            StageMap stages,
            const std::vector<Tensor>& tensor_args,
            const std::vector<Var>& scalar_args,
            const std::vector<Tensor>& temp_tensor_args = {},
            const Target& target = common::DefaultHostTarget());

  //! Performs the lowering and returns the generated function(s).
  std::vector<ir::LoweredFunc> operator()();

  //! Non-owning pointer to the computation graph held by this object (null if none was created).
  const common::Graph* comp_graph() const { return compu_graph_.get(); }

  //! Builds the argument list of a generated function from its body.
  std::vector<ir::Argument> GenerateFunctionArgumentList(Expr fn_body);
  //! Builds the argument list for one kernel of a split kernel, given its temporary tensors.
  std::vector<ir::Argument> GenFuncArgForSplitKernel(Expr func_iterator, std::vector<ir::Tensor> temp_tensors);
  //! Builds the function body (one Expr per generated function) from `schedule`.
  std::vector<Expr> GenerateFunctionBody(const poly::Schedule* schedule);

 private:
  std::vector<Tensor> CollectTemporaryTensors();
  void CheckArgsUnique();
  inline absl::flat_hash_map<std::string, Tensor> GenTensorArgMap();
  inline absl::flat_hash_map<std::string, Tensor> GenAllTensorMap();
  std::vector<Tensor> CollectAllTensors();
  std::set<std::pair<std::string, std::string>> CollectExtraDependencies() const;

 private:
  // Owned copies (not references): the object must stay valid after the caller's
  // arguments go out of scope. Declaration order is unchanged, so any constructor
  // initializer list keeps the same initialization order.
  std::string fn_name_;
  std::vector<Tensor> tensor_args_;
  std::vector<Var> scalar_args_;
  std::vector<Tensor> temp_tensor_args_;
  Target target_;
  StageMap stages_;
  std::unique_ptr<common::Graph> compu_graph_;
  std::vector<ir::CudaAxisInfo> cuda_axis_info_;
};
// Returns whether tensor `t`, scheduled by `stage`, carries GPU-related information.
// NOTE(review): the exact criterion lives in the definition, not in this header — presumably
// whether any of its axes are bound to GPU blocks/threads. Confirm.
bool TensorContainsGPUInfo(ir::Tensor t, poly::Stage* stage);
/**
 * Marks ir::PolyFor loops for vectorization.
 *
 * While walking the IR, the mutator keeps the stack of ir::PolyFor nodes enclosing the current
 * position; index 0 is the outermost. It checks each ir::Store whose target tensor name is a key
 * of `vectorizes`. For such a Store, the loop at index `info.level` of that stack receives the
 * VectorizeInfo through set_vectorize_info(). The children of a Store are not visited.
 *
 * NOTE: `vectorizes` is held by reference, so the map must outlive this mutator.
 */
struct MarkVectorizeMutator : public ir::IRMutator<Expr*> {
  const std::map<std::string, ir::VectorizeInfo>& vectorizes;

  explicit MarkVectorizeMutator(const std::map<std::string /*tensor name*/, ir::VectorizeInfo>& vectorizes)
      : vectorizes(vectorizes) {}

  void operator()(Expr* expr) { ir::IRMutator<Expr*>::Visit(expr, expr); }

  // NOTE This mutator takes PolyFor as input, not For.
  void Visit(const ir::PolyFor* op, Expr* expr) override {
    auto* node = expr->As<ir::PolyFor>();
    forloop_stack.push_back(node);
    ir::IRMutator<ir::Expr*>::Visit(op, expr);
    forloop_stack.pop_back();
  }

  // each statement in ISL is bound to a Store node.
  void Visit(const ir::Store* op, Expr* expr) override {
    auto* tensor_n = op->tensor.As<ir::_Tensor_>();
    CHECK(tensor_n);
    auto it = vectorizes.find(tensor_n->name);
    if (it == vectorizes.end()) return;

    const ir::VectorizeInfo& info = it->second;
    // Validate before touching the loop. The original applied the info first and then checked
    // valid(), and its bounds check relied on a negative level wrapping to a huge unsigned value.
    CHECK(info.valid());
    // The size_t cast follows valid(), which is expected to imply level >= 0. Even if it does
    // not, a negative level wraps to a huge value and fails this check.
    CHECK_LT(static_cast<size_t>(info.level), forloop_stack.size());
    forloop_stack[info.level]->set_vectorize_info(info);
  }

  //! PolyFor nodes enclosing the current visit position, outermost first.
  std::vector<ir::PolyFor*> forloop_stack;
};
/**
 * Marks ir::PolyFor loops as unrolled.
 *
 * While walking the IR, the mutator keeps the stack of ir::PolyFor nodes enclosing the current
 * position; index 0 is the outermost. It checks each ir::Store whose target tensor name is a key
 * of `unrolls`. For such a Store, every level in the tensor's set selects a loop in that stack,
 * and set_unrolled() is called on it. The children of a Store are not visited.
 *
 * `unrolls` is copied into the mutator at construction.
 */
struct MarkUnrollMutator : public ir::IRMutator<Expr*> {
  std::map<std::string, std::set<int> /*level*/> unrolls;

  explicit MarkUnrollMutator(const std::map<std::string, std::set<int>>& unrolls) : unrolls(unrolls) {}

  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

  void Visit(const ir::PolyFor* op, Expr* expr) override {
    auto* node = expr->As<ir::PolyFor>();
    stack.push_back(node);
    ir::IRMutator<>::Visit(op, expr);
    stack.pop_back();
  }

  // each statement in ISL is bound to a Store node.
  void Visit(const ir::Store* op, Expr* expr) override {
    auto* tensor_n = op->tensor.As<ir::_Tensor_>();
    CHECK(tensor_n);
    auto it = unrolls.find(tensor_n->name);
    if (it == unrolls.end()) return;

    for (int level : it->second) {
      VLOG(1) << "Mark " << level << " Unrolled";
      // Rule out negatives explicitly, then compare in size_t. The original compared int against
      // size_t directly, which only caught a negative level through unsigned wrap-around.
      CHECK_GE(level, 0);
      CHECK_LT(static_cast<size_t>(level), stack.size());
      stack[level]->set_unrolled();
    }
  }

  //! PolyFor nodes enclosing the current visit position, outermost first.
  std::vector<ir::PolyFor*> stack;
};
/**
 * Marks ir::PolyFor loops as parallel.
 *
 * While walking the IR, the mutator keeps the stack of ir::PolyFor nodes enclosing the current
 * position; index 0 is the outermost. It checks each ir::Store whose target tensor name is a key
 * of `parallels`. For such a Store, every level in the tensor's set selects a loop in that stack,
 * and set_parallel() is called on it. The children of a Store are not visited.
 *
 * `parallels` is copied into the mutator at construction.
 */
struct MarkParallelMutator : public ir::IRMutator<Expr*> {
  std::map<std::string, std::set<int> /*level*/> parallels;

  explicit MarkParallelMutator(const std::map<std::string, std::set<int>>& parallels) : parallels(parallels) {}

  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

  void Visit(const ir::PolyFor* op, Expr* expr) override {
    auto* node = expr->As<ir::PolyFor>();
    stack.push_back(node);
    ir::IRMutator<>::Visit(op, expr);
    stack.pop_back();
  }

  // each statement in ISL is bound to a Store node.
  void Visit(const ir::Store* op, Expr* expr) override {
    auto* tensor_n = op->tensor.As<ir::_Tensor_>();
    CHECK(tensor_n);
    auto it = parallels.find(tensor_n->name);
    if (it == parallels.end()) return;

    for (int level : it->second) {
      VLOG(1) << "Mark " << level << " Paralled";
      // Rule out negatives explicitly, then compare in size_t. The original compared int against
      // size_t directly, which only caught a negative level through unsigned wrap-around.
      CHECK_GE(level, 0);
      CHECK_LT(static_cast<size_t>(level), stack.size());
      stack[level]->set_parallel();
    }
  }

  //! PolyFor nodes enclosing the current visit position, outermost first.
  std::vector<ir::PolyFor*> stack;
};
} // namespace detail
} // namespace lang
} // namespace cinn