Program Listing for File lower.cc

Return to documentation for file (/WorkSpace/CINN/cinn/lang/lower.cc)

// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/lang/lower.h"

#include <iostream>
#include <map>
#include <set>
#include <stack>
#include <unordered_set>
#include <utility>

#include "cinn/ir/buffer.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/lang/lower_impl.h"
#include "cinn/optim/optimize.h"

namespace cinn {
namespace lang {

using ir::Tensor;
using poly::Stage;

/**
 * Collect the buffers backing "temporary" tensors referenced in `body`.
 *
 * A tensor found in `body` counts as temporary when all of the following hold:
 *  - it has a defined buffer;
 *  - its stage in `stage_map` is not inlined;
 *  - neither its buffer name nor its tensor name matches one of `tensor_args`.
 *
 * @param tensor_args  Tensors passed as function arguments (excluded from the result).
 * @param stage_map    Stage lookup used to skip inlined tensors.
 * @param body         IR to scan.
 * @return One buffer per distinct buffer name, in the order the tensors are visited.
 */
std::vector<ir::Buffer> GetTempBuffers(const std::vector<Tensor>& tensor_args,
                                       const poly::StageMap& stage_map,
                                       Expr body) {
  // Names belonging to the function arguments; anything matching these is not a temporary.
  std::unordered_set<std::string> arg_tensor_names;
  std::unordered_set<std::string> arg_buffer_names;
  for (const auto& arg : tensor_args) {
    arg_tensor_names.insert(arg->name);
    if (arg->buffer.defined()) arg_buffer_names.insert(arg->buffer->name);
  }

  auto is_temp_tensor = [&](const Expr* node) {
    auto* t = node->as_tensor();
    if (!t) return false;
    if (!t->buffer.defined()) return false;
    if (stage_map[t]->inlined()) return false;
    if (arg_buffer_names.count(t->buffer->name)) return false;
    return !arg_tensor_names.count(t->name);
  };

  std::vector<ir::Buffer> temp_buffers;
  // Several collected tensors may refer to the same buffer; keep only the first of each name.
  std::unordered_set<std::string> seen_buffer_names;
  for (auto& node : ir::CollectIRNodes(body, is_temp_tensor)) {
    const auto& buffer = node.as_tensor()->buffer;
    if (seen_buffer_names.insert(buffer->name).second) temp_buffers.push_back(buffer);
  }
  return temp_buffers;
}

/**
 * Gather candidate temporary tensors from `stages`.
 *
 * For every stage, both the stage's own tensor and all tensors in its
 * `ctrl_depends()` set are collected. Any tensor that appears in
 * `tensor_args` is then removed, since function arguments are not temporaries.
 *
 * @param stages       Stage map to scan.
 * @param tensor_args  Function-argument tensors to exclude.
 * @return Ordered set of the remaining tensors.
 */
std::set<ir::Tensor> CollectTempTensorsFromCtrlDepends(StageMap stages, const std::vector<Tensor>& tensor_args) {
  std::set<ir::Tensor> candidates;
  for (auto& entry : stages) {
    auto& stage = entry.second;
    candidates.emplace(ir::Tensor(stage->tensor()));
    const auto& deps = stage->ctrl_depends();
    candidates.insert(deps.begin(), deps.end());
  }

  // std::set::erase(key) is a no-op for absent keys, so no membership test is needed.
  for (auto& arg : tensor_args) candidates.erase(arg);
  return candidates;
}

/**
 * Ensure reduction initialization for `tensor` and for the reduce tensors in its body.
 *
 * First, `tensor` itself is initialized if it is a reduce tensor not yet
 * initialized in `stages`. Afterwards, every reduce tensor referenced in
 * `tensor->body()` that is still uninitialized gets `InitReduction` called.
 *
 * @param stages  Stage map consulted and updated by the initialization.
 * @param tensor  Root tensor to process.
 * @param target  Target forwarded to `InitReduction`.
 */
void InitReduceTensor(StageMap stages, const Tensor& tensor, const Target& target) {
  const bool self_needs_init = tensor->is_reduce_tensor() && !tensor->IsReduceInited(stages);
  if (self_needs_init) tensor->InitReduction(stages, target);

  // Collected after the root tensor has been handled, matching the required ordering.
  auto is_uninited_reduce = [&](const Expr* node) {
    if (!node || !node->defined()) return false;
    auto* t = node->as_tensor();
    return t && t->is_reduce_tensor() && !t->IsReduceInited(stages);
  };

  for (auto& reduce_expr : ir::CollectIRNodes(tensor->body(), is_uninited_reduce)) {
    auto* reduce_tensor = reduce_expr.as_tensor();
    VLOG(3) << "Init reduce tensor: " << reduce_tensor->name;
    reduce_tensor->InitReduction(stages, target);
  }
}

/**
 * Lower the computation described by `stages` into a single LoweredFunc.
 *
 * Reduce tensors among `tensor_args` and `temp_tensors` are initialized first.
 * Tensors gathered from control dependencies are then merged with
 * `temp_tensors`, and the result is lowered through `detail::LowerImpl`.
 * For every lowered function:
 *  - its temporary buffers are recorded in `temp_bufs`;
 *  - it is registered with `b` (together with those buffers) when `b` is non-null;
 *  - its `device_api` is set to GPU when any argument or caller-supplied temp tensor carries GPU info.
 *
 * @param name          Function name.
 * @param stages        Stage map of the computation.
 * @param tensor_args   Tensor arguments of the function.
 * @param scalar_args   Scalar arguments of the function.
 * @param temp_tensors  Extra temporary tensors.
 * @param b             Optional module builder; may be null.
 * @param target        Compilation target.
 * @return The first lowered function. A CHECK fails if lowering produced none.
 */
ir::LoweredFunc Lower(const std::string& name,
                      StageMap stages,
                      const std::vector<Tensor>& tensor_args,
                      const std::vector<Var>& scalar_args,
                      const std::vector<Tensor>& temp_tensors,
                      Module::Builder* b,
                      const Target& target) {
  // Init the reduce tensors first before any process.
  for (auto& t : tensor_args) InitReduceTensor(stages, t, target);
  for (auto& t : temp_tensors) InitReduceTensor(stages, t, target);

  // Merge the ctrl_deps with the given temp_tensors and get a new temp_tensors
  auto ctrl_deps = CollectTempTensorsFromCtrlDepends(stages, tensor_args);
  ctrl_deps.insert(temp_tensors.begin(), temp_tensors.end());

  auto lower_impl_instance = detail::LowerImpl(
      name, stages, tensor_args, scalar_args, std::vector<Tensor>(ctrl_deps.begin(), ctrl_deps.end()), target);

  auto result = lower_impl_instance();
  // This function returns a single LoweredFunc; indexing an empty result would be undefined behavior.
  CHECK(!result.empty()) << "Lowering function [" << name << "] produced no LoweredFunc";

  // The GPU check depends only on the argument/temp tensors and their stages, not on the
  // individual lowered function, so it is evaluated once. As before, only the caller-supplied
  // temp_tensors are inspected (not the tensors merged in from ctrl depends).
  auto any_on_gpu = [&](const std::vector<Tensor>& tensors) {
    for (auto& t : tensors) {
      if (detail::TensorContainsGPUInfo(t, stages[t])) return true;
    }
    return false;
  };
  const bool contains_gpu = any_on_gpu(tensor_args) || any_on_gpu(temp_tensors);

  for (auto& res : result) {
    auto temp_buffers = GetTempBuffers(tensor_args, stages, res->body);
    if (b) {
      for (auto& temp_buffer : temp_buffers) {
        b->AddBuffer(temp_buffer);
      }
    }

    // set function device_api
    if (contains_gpu) {
      res->device_api = ir::DeviceAPI::GPU;
    }

    if (b) {
      b->AddFunction(res);
    }

    res->temp_bufs = temp_buffers;
  }
  return result.front();
}

/**
 * Lower the computation described by `stages` and return every produced LoweredFunc.
 *
 * Reduce tensors among `tensor_args` and `temp_tensors` are initialized first.
 * Tensors gathered from control dependencies are then merged with
 * `temp_tensors`, and the result is lowered through `detail::LowerImpl`.
 * For every lowered function:
 *  - its temporary buffers are recorded in `temp_bufs`;
 *  - it is registered with `b` (together with those buffers) when `b` is non-null;
 *  - its `device_api` is set to GPU when any argument or caller-supplied temp tensor carries GPU info.
 *
 * @param name          Function name.
 * @param stages        Stage map of the computation.
 * @param tensor_args   Tensor arguments of the function.
 * @param scalar_args   Scalar arguments of the function.
 * @param temp_tensors  Extra temporary tensors.
 * @param b             Optional module builder; may be null.
 * @param target        Compilation target.
 * @return All lowered functions, in the order produced by LowerImpl.
 */
std::vector<ir::LoweredFunc> LowerVec(const std::string& name,
                                      StageMap stages,
                                      const std::vector<Tensor>& tensor_args,
                                      const std::vector<Var>& scalar_args,
                                      const std::vector<Tensor>& temp_tensors,
                                      Module::Builder* b,
                                      const Target& target) {
  // Init the reduce tensors first before any process.
  for (auto& t : tensor_args) InitReduceTensor(stages, t, target);
  for (auto& t : temp_tensors) InitReduceTensor(stages, t, target);

  // Merge the ctrl_deps with the given temp_tensors and get a new temp_tensors
  auto ctrl_deps = CollectTempTensorsFromCtrlDepends(stages, tensor_args);
  ctrl_deps.insert(temp_tensors.begin(), temp_tensors.end());

  auto lower_impl_instance = detail::LowerImpl(
      name, stages, tensor_args, scalar_args, std::vector<Tensor>(ctrl_deps.begin(), ctrl_deps.end()), target);
  // return vector of ir::LoweredFunc.
  auto result = lower_impl_instance();

  // The GPU check depends only on the argument/temp tensors and their stages, not on the
  // individual lowered function, so it is evaluated once. As before, only the caller-supplied
  // temp_tensors are inspected (not the tensors merged in from ctrl depends).
  auto any_on_gpu = [&](const std::vector<Tensor>& tensors) {
    for (auto& t : tensors) {
      if (detail::TensorContainsGPUInfo(t, stages[t])) return true;
    }
    return false;
  };
  const bool contains_gpu = any_on_gpu(tensor_args) || any_on_gpu(temp_tensors);

  std::vector<ir::LoweredFunc> return_value;
  for (auto& res : result) {
    auto temp_buffers = GetTempBuffers(tensor_args, stages, res->body);
    if (b) {
      for (auto& temp_buffer : temp_buffers) {
        b->AddBuffer(temp_buffer);
      }
    }

    // set function device_api
    if (contains_gpu) {
      res->device_api = ir::DeviceAPI::GPU;
    }

    if (b) {
      b->AddFunction(res);
    }

    res->temp_bufs = temp_buffers;

    return_value.push_back(res);
  }
  return return_value;
}

}  // namespace lang
}  // namespace cinn