Program Listing for File program_pass.h

Return to documentation for file (/WorkSpace/CINN/cinn/frontend/program_pass.h)

// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "cinn/frontend/op_mapper_registry.h"
#include "cinn/frontend/syntax.h"
#include "cinn/utils/registry.h"

namespace cinn {
namespace frontend {

/// Base class for a named transformation over a frontend `Program`.
///
/// Concrete passes derive from this class and override `ApplyImpl`. They are
/// registered by name with CINN_REGISTER_PROGRAM_PASS, which heap-allocates the
/// derived object and hands it to ProgramPassRegistry as a `ProgramPass*`.
class ProgramPass {
 public:
  /// @param name  Name under which this pass is registered and looked up.
  ProgramPass(const std::string& name) : name_(name) {}

  /// Derived passes are created with `new PassClass{...}` and stored through
  /// `ProgramPass*`, so destruction through the base pointer must dispatch to
  /// the derived destructor. Without this, deleting such a pointer is undefined
  /// behavior.
  virtual ~ProgramPass() = default;

  /// Applies the passes named in `passes` to `prog` for `target`.
  /// Defined out of line; presumably runs them in the given order
  /// (TODO confirm in program_pass.cc).
  static void Apply(Program* prog, const common::Target& target, const std::vector<std::string>& passes);

  /// Transformation hook overridden by concrete passes. The base
  /// implementation does nothing.
  virtual void ApplyImpl(Program* prog, const common::Target& target) const {}

  /// @return the name this pass was constructed with.
  const std::string& name() const { return name_; }

 private:
  std::string name_;
};

/// Process-wide table mapping pass names to ProgramPass instances.
///
/// Entries are added through CINN_REGISTER_PROGRAM_PASS and retrieved with `Get`.
class ProgramPassRegistry : public Registry<ProgramPass> {
 public:
  /// @return the single registry instance (a function-local static created on first use).
  static ProgramPassRegistry* Global() {
    static ProgramPassRegistry x;
    return &x;
  }

  /// Looks up the pass registered under `name`.
  /// Fails `CHECK(pass)` with a message naming `name` if no such pass is registered.
  inline const ProgramPass* Get(const std::string& name) {
    const ProgramPass* pass = Registry<ProgramPass>::Find(name);
    CHECK(pass) << "Pass [" << name << "] is not registered";
    return pass;
  }

  /// Registers `pass` under `name` unless the name is already taken.
  ///
  /// @return the pass stored under `name` after the call: `pass` itself when it
  ///         was newly registered, otherwise the previously registered pass.
  /// NOTE: if `name` is already registered, `pass` is not stored anywhere, and the
  /// caller keeps ownership of it. With CINN_REGISTER_PROGRAM_PASS (which passes a
  /// fresh `new` object) that object is never freed.
  inline ProgramPass* __REGISTER__(const std::string& name, ProgramPass* pass) {
    std::lock_guard<std::mutex> guard(registering_mutex);
    // A single lookup replaces the previous count() + operator[] pair.
    auto it = fmap_.find(name);
    if (it != fmap_.end()) {
      return it->second;
    }

    fmap_.emplace(name, pass);
    const_list_.push_back(pass);
    entry_list_.push_back(pass);
    return pass;
  }

  /// Returns the pass registered under `name`, registering `pass` first if the
  /// name is new. Same contract (and ownership note) as `__REGISTER__`.
  ///
  /// Delegates instead of pre-checking `fmap_`. `__REGISTER__` already performs
  /// the existence check under `registering_mutex`, whereas an unlocked read of
  /// `fmap_` here would race with a concurrent registration.
  inline ProgramPass* __REGISTER_OR_GET__(const std::string& name, ProgramPass* pass) {
    return __REGISTER__(name, pass);
  }

 private:
  ProgramPassRegistry() = default;
  CINN_DISALLOW_COPY_AND_ASSIGN(ProgramPassRegistry);
};

// Registers a pass of type `PassClass` under the name `#PassType` in
// ProgramPassRegistry::Global().
//
// Expands to the definition of a `static ::cinn::frontend::ProgramPass*`
// variable named `__make_<PassType>__`. The registration happens as a side
// effect of that variable's initializer; when the macro is used at namespace
// scope, this occurs during the translation unit's static initialization.
//
// `PassClass` must be brace-constructible from the pass name (a string literal).
// NOTE: `new PassClass{#PassType}` is evaluated even when `#PassType` is already
// registered. In that case __REGISTER_OR_GET__ returns the existing pass, and
// the newly allocated object is not stored by the registry.
#define CINN_REGISTER_PROGRAM_PASS(PassType, PassClass)         \
  static ::cinn::frontend::ProgramPass* __make_##PassType##__ = \
      ::cinn::frontend::ProgramPassRegistry::Global()->__REGISTER_OR_GET__(#PassType, new PassClass{#PassType})

// Callable type stored by ProgramPassFunctionRegistry entries.
// It takes the program to operate on and a set of variable-id strings
// (presumably the fetch targets the pass must preserve — TODO confirm against callers).
using ProgramPassFunction = std::function<void(Program*, const std::unordered_set<std::string>&)>;

// Declared here, defined out of line. Presumably looks up the program-pass
// function named `pass` and invokes it on `program` with `fetch_ids`
// (TODO confirm in program_pass.cc, including what happens for an unknown name).
void ApplyPass(Program* program, const std::unordered_set<std::string>& fetch_ids, const std::string& pass);

// Registry entry type for program-pass functions.
//
// FunctionRegEntryBase is instantiated with this class itself and with the
// ProgramPassFunction signature. CINN_REGISTER_PROGRAM_PASS_FUNCTION registers
// entries of this type.
class ProgramPassFunctionRegistry
    : public FunctionRegEntryBase<ProgramPassFunctionRegistry, ProgramPassFunction> {
 private:
  // Copy construction and assignment are disabled via the project macro.
  CINN_DISALLOW_COPY_AND_ASSIGN(ProgramPassFunctionRegistry);

 public:
  ProgramPassFunctionRegistry() = default;
};

// Registers a ProgramPassFunctionRegistry entry under `name`.
// Forwards to CINN_REGISTRY_REGISTER with the entry type
// ::cinn::frontend::ProgramPassFunctionRegistry, the identifier fragment
// ProgramPassFunctionRegistry, and `name`. CINN_REGISTRY_REGISTER is not
// defined in this header (presumably cinn/utils/registry.h — TODO confirm how
// the entry's function body is attached).
#define CINN_REGISTER_PROGRAM_PASS_FUNCTION(name) \
  CINN_REGISTRY_REGISTER(::cinn::frontend::ProgramPassFunctionRegistry, ProgramPassFunctionRegistry, name)

}  // namespace frontend
}  // namespace cinn