Program Listing for File computation.h
↰ Return to documentation for file (/WorkSpace/CINN/cinn/frontend/computation.h)
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "cinn/frontend/base_builder.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/tensor.h"
namespace cinn {
namespace frontend {
struct ComputationContext;
class CinnComputation {
public:
struct CompileOptions : public hlir::framework::GraphCompiler::CompileOptions {
bool use_decomposer = false;
bool do_prerun = true;
bool use_default_passes = true;
std::vector<std::string> passes;
};
inline static CompileOptions DefaultCompileOptions() {
CompileOptions options;
options.with_instantiate_variables = true;
options.use_decomposer = false;
options.passes = {};
options.do_prerun = true;
options.use_default_passes = true;
return options;
}
static std::shared_ptr<CinnComputation> BuildAndCompile(const Target &target,
BaseBuilder &builder,
const CompileOptions &options = DefaultCompileOptions(),
const std::vector<Variable> &outputs = {},
void *stream = nullptr);
static std::shared_ptr<CinnComputation> Compile(const Target &target,
Program &program,
const CompileOptions &options = DefaultCompileOptions(),
const std::vector<Variable> &outputs = {},
void *stream = nullptr);
static std::shared_ptr<CinnComputation> CompilePaddleModel(const Target &target,
const std::string &model_path,
const std::vector<std::string> &input_names,
const std::vector<hlir::framework::shape_t> &input_shapes,
bool params_combined,
const CompileOptions &options = DefaultCompileOptions(),
void *stream = nullptr);
std::vector<std::string> GetAllTensorNames();
hlir::framework::Tensor GetTensor(const std::string &name);
std::vector<hlir::framework::Tensor> GetInputTensors();
std::vector<hlir::framework::Tensor> GetOutputTensors();
void SetTensorData(hlir::framework::Tensor &t, void *data, size_t size);
void SetTensorData(const std::string &tname, void *data, size_t size);
void GetTensorData(hlir::framework::Tensor &t, void *data, size_t size);
void GetTensorData(const std::string &tname, void *data, size_t size);
void Execute(const std::map<std::string, cinn_pod_value_t> *name2podargs = nullptr);
private:
std::shared_ptr<ComputationContext> context_;
};
} // namespace frontend
} // namespace cinn