Program Listing for File packed_func.h

Return to documentation for file (/WorkSpace/CINN/cinn/lang/packed_func.h)

// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "cinn/common/cinn_value.h"
#include "cinn/ir/ir.h"

namespace cinn {
namespace lang {
using common::CINNValue;

// Arguments passed into a PackedFunc and the value it hands back are both
// represented by CINNValue.
using RetValue = CINNValue;
using ArgValue = CINNValue;

class Args {
 public:
  Args() = default;
  Args(cinn_value_t* values, int* type_codes, int len);

  void Append(const ArgValue& arg) { values_.push_back(arg); }

  size_t size() const { return values_.size(); }

  bool empty() const { return values_.empty(); }

  ArgValue& operator[](int i) { return values_[i]; }
  const ArgValue& operator[](int i) const { return values_[i]; }

  common::CINNValuePack ToValuePack() const { return common::CINNValuePack(values_); }

 private:
  std::vector<ArgValue> values_;
};

namespace detail {

// Recursive visitor step: calls `f(I, head)` for the current argument, then
// continues with the remaining ones at index I + 1. The `stop` flag becomes
// true once nothing is left, which selects the empty terminal specialization.
template <bool stop, std::size_t I, typename F>
struct for_each_dispatcher {
  template <typename Head, typename... Tail>
  static void Run(const F& f, Head&& head, Tail&&... tail) {
    f(I, std::forward<Head>(head));

    constexpr bool kNoneLeft = sizeof...(Tail) == 0;
    using Next = for_each_dispatcher<kNoneLeft, I + 1, F>;
    Next::Run(f, std::forward<Tail>(tail)...);
  }
};

// Terminal step: every argument has been visited, so there is nothing to do.
template <std::size_t I, typename F>
struct for_each_dispatcher<true, I, F> {
  static void Run(const F& /*f*/) {}
};

// Calls `f(i, arg_i)` for each argument in order, with i = 0, 1, 2, ...
// Each argument is perfectly forwarded. With no arguments, `f` is never called.
template <typename F, typename... Ts>
inline void for_each(const F& f, Ts&&... args) {
  constexpr bool kEmpty = sizeof...(Ts) == 0;
  for_each_dispatcher<kEmpty, 0, F>::Run(f, std::forward<Ts>(args)...);
}

// Visitor used with detail::for_each: converts each visited argument to an
// ArgValue and appends it to the target Args list.
struct FuncArgsSetter {
  FuncArgsSetter(Args* args) : args_(args) {}  // NOLINT

  // The position index from for_each is ignored: arguments are visited in
  // order, so appending preserves their positions.
  template <typename T>
  void operator()(std::size_t /*I*/, T v) const {
    args_->Append(ArgValue(v));
  }

 private:
  // Non-owning; the pointed-to Args must outlive every call to operator().
  // No `mutable` is needed: a const member function may still modify the
  // pointee, only the pointer itself is const there.
  Args* args_{};
};

}  // namespace detail

// Type-erased callable with a uniform calling convention. The call arguments
// are packed into an Args list and passed to the stored body together with a
// pointer to a RetValue. That RetValue is then returned to the caller.
class PackedFunc {
 public:
  // Signature of the wrapped implementation. It receives the packed arguments
  // and a pointer to the RetValue that operator() returns afterwards.
  using body_t = std::function<void(Args args, RetValue*)>;

  PackedFunc() = default;
  // Creates a PackedFunc that carries only a name and has no body.
  explicit PackedFunc(const std::string& name) : name_(name) {}
  explicit PackedFunc(body_t body) : body_(std::move(body)) {}

  // Converts each argument to an ArgValue, packs them in order, and invokes the
  // body. Because body_ is a std::function, calling a PackedFunc without a body
  // throws std::bad_function_call.
  template <typename... Args_>
  inline RetValue operator()(Args_&&... args) const {
    Args _args;
    detail::FuncArgsSetter setter(&_args);
    detail::for_each(setter, std::forward<Args_>(args)...);

    RetValue ret_value;
    // body_t takes Args by value and `_args` is dead after this call, so move it
    // instead of copying every argument.
    body_(std::move(_args), &ret_value);
    return ret_value;
  }

  // Returns a copy of the stored body (empty if none was provided).
  inline body_t body() const { return body_; }

 private:
  std::string name_;
  body_t body_;
};

}  // namespace lang
}  // namespace cinn