Program Listing for File decomposer.cc

Return to documentation for file (/WorkSpace/CINN/cinn/frontend/pass/decomposer.cc)

// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/frontend/decomposer_registry.h"
#include "cinn/frontend/program_pass.h"

namespace cinn {
namespace frontend {
namespace pass {

/**
 * Program pass that rewrites instructions into primitive instructions.
 *
 * For every instruction of the input program, the pass asks
 * InstrDecomposerRegistry for a decomposer registered for the instruction's
 * op type and the given target. If one exists, the decomposer is run and is
 * expected to emit replacement instructions into a fresh CinnBuilder.
 * Otherwise the instruction is appended unchanged. The program is then
 * replaced by the newly built one.
 */
class DecomposerPass : public ProgramPass {
 public:
  using ProgramPass::ProgramPass;

  /**
   * Rebuilds `prog` in place with decomposable instructions expanded.
   *
   * @param prog   Program to transform; overwritten with the rebuilt program.
   * @param target Target used to look up the registered decomposer per op.
   */
  void ApplyImpl(Program* prog, const common::Target& target) const override {
    // step 1: set the inputs of the origin program to the new program
    CinnBuilder builder("decomposer_builder");
    for (auto& var : prog->GetInputs()) {
      builder.CreateInput(var);
    }

    // step 2: use primitive instructions to build the new program.
    // `var_map` is handed to the decomposers through DecomposerContext.
    // Presumably decomposers record (new variable id -> replacement variable)
    // entries there. TODO confirm against DecomposerContext.
    absl::flat_hash_map<std::string, Variable> var_map;
    DecomposerContext context(&builder, &var_map);
    for (size_t i = 0; i < prog->size(); i++) {
      auto instr      = (*prog)[i];
      auto decomposer = InstrDecomposerRegistry::Global()->Find(instr->op_type, target);
      if (decomposer) {
        decomposer->Run(instr, context);
      } else {
        builder.AppendInstruction(instr);
      }
    }
    *prog = builder.Build();

    // step 3: set the origin output to the output of decomposed operator.
    // Every input/output variable of the rebuilt program whose id is a key
    // of `var_map` is replaced by the mapped Variable.
    auto remap = [&var_map](auto& vars) {
      for (auto& var : vars) {
        auto it = var_map.find(var->id);
        if (it != var_map.end()) {
          var = it->second;
        }
      }
    };
    for (size_t i = 0; i < prog->size(); i++) {
      remap((*prog)[i]->outputs);
      remap((*prog)[i]->inputs);
    }
  }
};

}  // namespace pass
}  // namespace frontend
}  // namespace cinn

// Registers DecomposerPass as a program pass under the name `Decomposer`.
// NOTE(review): CINN_REGISTER_HELPER / CINN_REGISTER_PROGRAM_PASS are defined
// elsewhere. Presumably this makes the pass selectable by name (e.g. through
// ProgramPass::Apply). Confirm against the macro definitions.
CINN_REGISTER_HELPER(Decomposer) {
  CINN_REGISTER_PROGRAM_PASS(Decomposer, ::cinn::frontend::pass::DecomposerPass);

  // The helper body must return a bool; `true` indicates registration ran.
  return true;
}