Program Listing for File decomposer_registry.h

Return to documentation for file (/WorkSpace/CINN/cinn/frontend/decomposer_registry.h)

// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <string>
#include <unordered_map>

#include "cinn/common/target.h"
#include "cinn/frontend/cinn_builder.h"
#include "cinn/frontend/syntax.h"

namespace cinn {
namespace frontend {

class Decomposer;

class DecomposerContext {
 public:
  explicit DecomposerContext(CinnBuilder* builder, absl::flat_hash_map<std::string, Variable>* var_map)
      : builder_(builder), var_map_(var_map) {}

  CinnBuilder* builder() const { return builder_; };

  // Map the new var to the original var.
  void MapOutToOrigin(const Variable& new_var, const Variable& ori_var) const {
    if (new_var->shape != ori_var->shape) {
      LOG(FATAL) << "The output shape shoule be equal to the original. But received : " << new_var->id << ".shape=["
                 << utils::Join(new_var->shape, ", ") << "] and the original var " << ori_var->id << ".shape=["
                 << utils::Join(ori_var->shape, ", ") << "].";
    }
    (*var_map_)[new_var->id] = ori_var;
  }

 private:
  CinnBuilder* builder_{nullptr};
  absl::flat_hash_map<std::string, Variable>* var_map_{nullptr};
};

class InstrDecomposerRegistry : public Registry<Decomposer> {
 public:
  static InstrDecomposerRegistry* Global() {
    static InstrDecomposerRegistry x;
    return &x;
  }

  inline const Decomposer* Get(const std::string& op_name, const common::Target& target) {
    const Decomposer* decomposer = Find(op_name, target);
    CHECK(decomposer) << "Decomposer for [" << op_name << ", " << target << "] is not registered";
    return decomposer;
  }

  inline const Decomposer* Find(const std::string& name, const common::Target& target) {
    return Registry<Decomposer>::Find(name + "_" + target.arch_str());
  }

 private:
  InstrDecomposerRegistry() = default;
  CINN_DISALLOW_COPY_AND_ASSIGN(InstrDecomposerRegistry);
};

class Decomposer {
 public:
  using DecomposerKernel = std::function<void(const Instruction& instr, const DecomposerContext&)>;

  Decomposer& SetBody(const DecomposerKernel& kernel) {
    kernel_ = kernel;
    return *this;
  }

  void Run(const Instruction& instr, const DecomposerContext& context) const { kernel_(instr, context); }

  std::string name;

 private:
  DecomposerKernel kernel_;
};

#define CINN_DECOMPOSER_REGISTER_CORE(name, target, kernel)        \
  ::cinn::frontend::InstrDecomposerRegistry::Global()              \
      ->__REGISTER__(std::string(#name) + "_" + target.arch_str()) \
      .SetBody(kernel)

#define CINN_DECOMPOSER_REGISTER_ALL(name, kernel)                                                 \
  static std::vector<::cinn::common::Target> all_targets = {::cinn::common::DefaultHostTarget(),   \
                                                            ::cinn::common::DefaultNVGPUTarget()}; \
  for (auto& target : all_targets) {                                                               \
    ::cinn::frontend::InstrDecomposerRegistry::Global()                                            \
        ->__REGISTER__(std::string(#name) + "_" + target.arch_str())                               \
        .SetBody(kernel);                                                                          \
  }

#define GET_MACRO(_0, _1, _2, FUNC, ...) FUNC
#define CINN_DECOMPOSER_REGISTER(...) \
  GET_MACRO(__VA_ARGS__, CINN_DECOMPOSER_REGISTER_CORE, CINN_DECOMPOSER_REGISTER_ALL)(__VA_ARGS__)

}  // namespace frontend
}  // namespace cinn