Program Listing for File interpreter.cc

File: /WorkSpace/CINN/cinn/frontend/interpreter.cc

// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/frontend/interpreter.h"

#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/hlir/op/use_ops.h"
#include "cinn/hlir/pass/use_pass.h"

namespace cinn::frontend {

// Pimpl state for Interpreter: owns the scope, the frontend program, and the
// compiled runtime program.
struct Interpreter::Impl {
  Impl(const std::vector<std::string>& input_names, const std::vector<hlir::framework::shape_t>& input_shapes)
      : input_names_(input_names), input_shapes_(input_shapes), scope_(std::make_shared<hlir::framework::Scope>()) {}

  void Build(const std::vector<std::string>& input_names,
             const std::vector<hlir::framework::shape_t>& input_shapes,
             const Target& target,
             const std::string& model_name = "");

 private:
  friend class Interpreter;

  std::vector<std::string> input_names_;
  absl::flat_hash_set<std::string> fetch_names_;
  std::vector<hlir::framework::shape_t> input_shapes_;

  std::shared_ptr<hlir::framework::Scope> scope_;
  std::unique_ptr<frontend::Program> program_;
  std::unique_ptr<hlir::framework::GraphCompiler> graph_compiler_;

  absl::flat_hash_map<std::string, Variable> var_map_;
  absl::flat_hash_map<std::string, std::string> var_map_paddle_to_cinn_;
  absl::flat_hash_map<std::string, std::string> var_map_cinn_to_paddle_;

  std::unique_ptr<hlir::framework::Program> runtime_program_;
  std::unique_ptr<hlir::framework::Program> prerun_program_;
};

void Interpreter::LoadPaddleModel(const std::string& model_dir,
                                  const Target& target,
                                  bool params_combined,
                                  const std::string& model_name) {
  // Parse the Paddle model into a frontend program plus the name mappings and
  // fetch list, then compile everything via Impl::Build.
  auto [program, var_map, var_map_paddle_to_cinn, fetch_names] =
      LoadPaddleProgram(model_dir, impl_->scope_.get(), params_combined, target);
  impl_->program_                = std::move(program);
  impl_->var_map_                = std::move(var_map);
  impl_->var_map_paddle_to_cinn_ = std::move(var_map_paddle_to_cinn);
  impl_->fetch_names_            = std::move(fetch_names);

  impl_->Build(impl_->input_names_, impl_->input_shapes_, target, model_name);
}

// Executes the compiled runtime program.
void Interpreter::Run() { impl_->runtime_program_->Execute(); }

hlir::framework::Tensor Interpreter::GetTensor(const std::string& name) {
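  // Try the name directly as a CINN variable first; otherwise translate it
  // through the Paddle-to-CINN name map.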
  if (impl_->scope_->FindVar(name)) return impl_->scope_->GetTensor(name);

  auto it = impl_->var_map_paddle_to_cinn_.find(name);
  if (it == impl_->var_map_paddle_to_cinn_.end()) {
    LOG(FATAL) << "No variable called [" << name
               << "] found in executor\nThe existing vars: " << utils::Join(impl_->scope_->var_names(), ", ");
  }
  return impl_->scope_->GetTensor(it->second);
}

// Lowers the loaded frontend program into an executable runtime program.
void Interpreter::Impl::Build(const std::vector<std::string>& input_names,
                              const std::vector<hlir::framework::shape_t>& input_shapes,
                              const Target& target,
                              const std::string& model_name) {
  CHECK(!input_names.empty());
  CHECK(!var_map_.empty());
  CHECK_EQ(input_names.size(), input_shapes.size());

  // Resolve each input name to its variable and attach the user-given shape.
  std::vector<Variable> input_vars;
  std::transform(input_names.begin(), input_names.end(), std::back_inserter(input_vars), [&](const std::string& x) {
    return var_map_.at(x);
  });

  for (size_t i = 0; i < input_vars.size(); ++i) input_vars[i]->shape = input_shapes[i];

  program_->SetInputs({input_vars});
  program_->Validate();

  VLOG(3) << "Program:\n" << *program_;

  // Convert the program into an HLIR graph and tag it with the model name.
  auto graph                 = std::make_shared<hlir::framework::Graph>(*program_, target);
  graph->attrs["model_name"] = std::make_shared<absl::any>(model_name);

  // Run the graph passes: shape inference first, then (on x86 builds) layout
  // rewriting, followed by constant propagation and operator fusion.
  hlir::framework::ApplyPass(graph.get(), "InferShape");
#ifndef CINN_WITH_CUDA
  if (target.arch == Target::Arch::X86) {
    hlir::framework::ApplyPass(graph.get(), "AlterLayout");
  }
#endif
  hlir::framework::ApplyPass(graph.get(), "ConstPropagate");
  hlir::framework::ApplyPass(graph.get(), "OpFusion");
  scope_ = hlir::framework::BuildScope(target, graph, scope_);

  // Map each fetch name to its variable id so the compiled program keeps the results.
  std::unordered_set<std::string> fetch_var_ids;
  for (auto& name : fetch_names_) {
    CHECK(var_map_.count(name)) << "var_map finds no fetch var " << name;
    fetch_var_ids.insert(var_map_.at(name)->id);
  }

  graph_compiler_ = std::make_unique<hlir::framework::GraphCompiler>(target, scope_, graph);
  hlir::framework::GraphCompiler::CompileOptions options;
  options.with_instantiate_variables = true;
  runtime_program_                   = graph_compiler_->Build(options, std::move(fetch_var_ids)).runtime_program;
  runtime_program_->PreRun();
}

std::shared_ptr<hlir::framework::Scope> Interpreter::scope() {
  CHECK(impl_->scope_);
  return impl_->scope_;
}

Interpreter::Interpreter(const std::vector<std::string>& input_names,
                         const std::vector<hlir::framework::shape_t>& input_shapes)
    : impl_(std::make_unique<Impl>(input_names, input_shapes)) {}

}  // namespace cinn::frontend

// Defined out of line, where Impl is a complete type, so that
// std::unique_ptr<Impl> can generate a correct deleter (pimpl idiom).
cinn::frontend::Interpreter::~Interpreter() {}
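
For reference, a minimal usage sketch of the API defined above. This is not
part of interpreter.cc: the model path, input name "A", input shape, and fetch
name "fetch_var" are placeholders, and common::DefaultHostTarget() is the host
target mentioned in the source.

#include "cinn/frontend/interpreter.h"

int main() {
  using namespace cinn;

  // Declare the model inputs up front; Impl stores the names and shapes
  // until LoadPaddleModel triggers Build.
  frontend::Interpreter interp({"A"}, {{1, 3, 224, 224}});

  // Parse, optimize, and compile the Paddle model for the host target.
  interp.LoadPaddleModel("/path/to/paddle_model", common::DefaultHostTarget(),
                         /*params_combined=*/false);

  // Execute the compiled runtime program, then fetch a result by its
  // Paddle or CINN variable name.
  interp.Run();
  auto out = interp.GetTensor("fetch_var");
  return 0;
}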