Program Listing for File placeholder.h
↰ Return to documentation for file (/WorkSpace/CINN/cinn/lang/placeholder.h)
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "cinn/common/common.h"
#include "cinn/ir/buffer.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/ir/operation.h"
#include "cinn/ir/tensor.h"
#include "cinn/runtime/intrinsic.h"
namespace cinn {
namespace lang {
using ir::Expr;
/**
 * Wrapper that owns an ir::Tensor built from a name and a shape.
 *
 * The tensor is created by Init(), which passes type_of<T>() as its type.
 * Besides indexing helpers, the class offers implicit conversions to
 * ir::Tensor and ir::Expr so that a Placeholder can be handed to APIs
 * expecting either of them.
 *
 * @tparam T C++ type that selects the tensor type through type_of<T>().
 */
template <typename T>
class Placeholder {
 public:
  /// Creates the placeholder from integer extents; each one is wrapped in an Expr.
  Placeholder(const std::string &name, const std::vector<int> &shape);
  /// Creates the placeholder from extents that are already expressions.
  Placeholder(const std::string &name, const std::vector<Expr> &shape);

  /// Index the wrapped tensor with one to four indices, or with an index list.
  // @{
  Expr operator()(Expr a) const { return Call({a}); }
  Expr operator()(Expr a, Expr b) const { return Call({a, b}); }
  Expr operator()(Expr a, Expr b, Expr c) const { return Call({a, b, c}); }
  Expr operator()(Expr a, Expr b, Expr c, Expr d) const { return Call({a, b, c, d}); }
  Expr operator()(const std::vector<Expr> &indices) const;
  // @}

  /// Type reported by the wrapped tensor.
  Type type() const { return tensor_->type(); }

  /// Implicit conversion to the wrapped tensor (returned by value).
  /// Declared const so that const placeholders, e.g. ones captured by value in
  /// a non-mutable lambda, can be converted as well.
  operator ir::Tensor() const { return tensor_; }

  /// Implicit conversion to an expression built from the wrapped tensor.
  /// Declared const for the same reason as the ir::Tensor conversion.
  operator ir::Expr() const {
    // Work on a non-const copy so that Expr(...) sees the same kind of argument
    // (a non-const ir::Tensor lvalue) as it did before this operator was const.
    ir::Tensor handle = tensor_;
    return Expr(handle);
  }

  /// Gives access to the wrapped tensor by reference; `placeholder->member`
  /// then continues through the tensor's own operator->.
  ir::Tensor &operator->() { return tensor_; }
  const ir::Tensor &operator->() const { return tensor_; }

  /// Returns the wrapped tensor by value.
  ir::Tensor tensor() const { return tensor_; }

 private:
  /// Shared implementation of the fixed-arity operator() overloads.
  Expr Call(const std::vector<Expr> &indices) const;

  /// Builds tensor_ from the name and shape; called by both constructors.
  void Init(const std::string &name, const std::vector<Expr> &shape);

  ir::Tensor tensor_;
};
/// Indexes the wrapped tensor with an explicit list of indices.
/// Forwards to Call(), which performs the actual tensor access.
template <typename T>
Expr Placeholder<T>::operator()(const std::vector<Expr> &indices) const {
  return Call(indices);
}
/// Common backend of the operator() overloads: applies the wrapped tensor's
/// call operator to `indices` and returns the resulting expression.
template <typename T>
Expr Placeholder<T>::Call(const std::vector<Expr> &indices) const {
  const ir::Tensor &target = tensor_;
  return target(indices);
}
/// Constructs a placeholder from integer extents.
///
/// @param name  name forwarded to Init().
/// @param shape extents; each value is wrapped in an Expr, order preserved.
template <typename T>
Placeholder<T>::Placeholder(const std::string &name, const std::vector<int> &shape) {
  std::vector<Expr> extents;
  // The final size is known up front, so allocate once instead of growing.
  extents.reserve(shape.size());
  for (int extent : shape) extents.push_back(Expr(extent));
  Init(name, extents);
}
/// Constructs a placeholder from extents that are already expressions.
///
/// @param name  name forwarded to Init().
/// @param shape extents forwarded to Init() unchanged.
template <typename T>
Placeholder<T>::Placeholder(const std::string &name, const std::vector<Expr> &shape) {
  this->Init(name, shape);
}
// Declaration only; no definition is visible in this listing and nothing here calls it.
// NOTE(review): presumably creates a placeholder tensor from a shape, a type
// and a name (a non-template counterpart of Placeholder<T>) - confirm against
// the definition.
ir::Tensor CreatePlaceHolder(const std::vector<Expr> &shape, Type type, const std::string &name);
/// Creates the wrapped tensor and binds a buffer to it.
///
/// A PlaceholderOp is made from the name, shape and type_of<T>(); the tensor is
/// then constructed with `shape` passed for both of its shape-like arguments,
/// and finally a Buffer of the tensor's type is bound to it.
///
/// @param name  name given to the operation and the tensor.
/// @param shape extents of the tensor.
template <typename T>
void Placeholder<T>::Init(const std::string &name, const std::vector<Expr> &shape) {
  // NOTE(review): buffer_ptr is never read afterwards. It is kept because
  // NewName() presumably advances naming state in the global context, so
  // dropping the call could change names generated later - confirm before
  // removing. The other locals that used to be built here (strides, offset,
  // per-dimension axis variables) were never used and have been removed.
  ir::Var buffer_ptr(Context::Global().NewName("buffer"));
  buffer_ptr->set_type(type_of<T>());

  auto op = ir::PlaceholderOp::Make(name, shape, type_of<T>());
  tensor_ = ir::Tensor(name, type_of<T>(), shape, shape, op, {});

  Buffer buffer(tensor_->type());
  tensor_->Bind(buffer);
}
} // namespace lang
} // namespace cinn