Program Listing for File base_builder.cc
↰ Return to documentation for file (/WorkSpace/CINN/cinn/frontend/base_builder.cc
)
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cinn/frontend/base_builder.h"
#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "cinn/common/common.h"
#include "cinn/common/context.h"
#include "cinn/common/type.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
namespace cinn {
namespace frontend {
using common::Context;
using common::Type;
using hlir::framework::AttrMapType;
using hlir::framework::Operator;
using hlir::framework::shape_t;
BaseBuilder::BaseBuilder(const std::string& name) : name_(name) {}
Program BaseBuilder::Build() {
Program program{std::move(instrs_), std::move(inputs_)};
program.Validate();
return program;
}
Placeholder BaseBuilder::CreateInput(const Type& type, const std::vector<int>& shape, const std::string& id_hint) {
if (!id_hint.empty()) {
CheckVarNameValid(id_hint);
}
std::string id = id_hint.empty() ? Context::Global().NewName("placeholder") : id_hint;
inputs_.emplace_back(id);
auto& var = inputs_.back();
var->type = type;
var->shape = shape;
return Placeholder(var);
}
Placeholder BaseBuilder::CreateInput(const Variable& var) {
CHECK(!var->shape.empty()) << "The input's shape is not set yet";
CHECK(!var->type.is_unk()) << "The input's type is not set yet";
inputs_.push_back(var);
return Placeholder(var);
}
void BaseBuilder::InferShape(Instruction instr) const {
using shape_func_t = std::function<std::vector<shape_t>(const std::vector<shape_t>&, const AttrMapType&)>;
using type_func_t = std::function<std::vector<Type>(const std::vector<Type>&, const AttrMapType&)>;
const auto& op_infershape = Operator::GetAttrs<shape_func_t>("infershape");
const auto& op_inferdtype = Operator::GetAttrs<type_func_t>("inferdtype");
size_t size = instr->inputs.size();
std::vector<shape_t> in_shapes(size);
std::vector<Type> in_types(size);
std::transform(
instr->inputs.begin(), instr->inputs.end(), in_shapes.begin(), [](const Variable& var) { return var->shape; });
std::transform(
instr->inputs.begin(), instr->inputs.end(), in_types.begin(), [](const Variable& var) { return var->type; });
auto key = Operator::Get(instr->op_type);
auto out_shapes = op_infershape[key](in_shapes, instr->attrs);
auto out_types = op_inferdtype[key](in_types, instr->attrs);
auto& outs = instr->outputs;
for (size_t i = 0; i < outs.size(); i++) {
outs[i]->shape = out_shapes[i];
outs[i]->type = out_types[i];
}
}
} // namespace frontend
} // namespace cinn