Program Listing for File computation.h

Return to documentation for file (/WorkSpace/CINN/cinn/frontend/computation.h)

// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "cinn/frontend/base_builder.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/tensor.h"

namespace cinn {
namespace frontend {

struct ComputationContext;

class CinnComputation {
 public:
  struct CompileOptions : public hlir::framework::GraphCompiler::CompileOptions {
    bool use_decomposer     = false;
    bool do_prerun          = true;
    bool use_default_passes = true;
    std::vector<std::string> passes;
  };

  inline static CompileOptions DefaultCompileOptions() {
    CompileOptions options;
    options.with_instantiate_variables = true;
    options.use_decomposer             = false;
    options.passes                     = {};
    options.do_prerun                  = true;
    options.use_default_passes         = true;
    return options;
  }

  static std::shared_ptr<CinnComputation> BuildAndCompile(const Target &target,
                                                          BaseBuilder &builder,
                                                          const CompileOptions &options = DefaultCompileOptions(),
                                                          const std::vector<Variable> &outputs = {},
                                                          void *stream                         = nullptr);
  static std::shared_ptr<CinnComputation> Compile(const Target &target,
                                                  Program &program,
                                                  const CompileOptions &options        = DefaultCompileOptions(),
                                                  const std::vector<Variable> &outputs = {},
                                                  void *stream                         = nullptr);
  static std::shared_ptr<CinnComputation> CompilePaddleModel(const Target &target,
                                                             const std::string &model_path,
                                                             const std::vector<std::string> &input_names,
                                                             const std::vector<hlir::framework::shape_t> &input_shapes,
                                                             bool params_combined,
                                                             const CompileOptions &options = DefaultCompileOptions(),
                                                             void *stream                  = nullptr);

  std::vector<std::string> GetAllTensorNames();

  hlir::framework::Tensor GetTensor(const std::string &name);

  std::vector<hlir::framework::Tensor> GetInputTensors();

  std::vector<hlir::framework::Tensor> GetOutputTensors();

  void SetTensorData(hlir::framework::Tensor &t, void *data, size_t size);

  void SetTensorData(const std::string &tname, void *data, size_t size);

  void GetTensorData(hlir::framework::Tensor &t, void *data, size_t size);
  void GetTensorData(const std::string &tname, void *data, size_t size);

  void Execute(const std::map<std::string, cinn_pod_value_t> *name2podargs = nullptr);

 private:
  std::shared_ptr<ComputationContext> context_;
};

}  // namespace frontend
}  // namespace cinn