Program Listing for File lower_test.cc
↰ Return to documentation for file (`/WorkSpace/CINN/cinn/lang/lower_test.cc`)
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cinn/lang/lower.h"
#include <gtest/gtest.h>
#include <set>
#include "cinn/cinn.h"
#include "cinn/lang/buffer.h"
#include "cinn/lang/compute.h"
#include "cinn/lang/placeholder.h"
#include "cinn/utils/string.h"
namespace cinn {
namespace lang {
// Lowers a simple element-wise compute (B = A + 1) over a 100x15 domain and
// checks the exact printed loop nest of the lowered function body.
TEST(lower, basic) {
  auto M = Expr(100);
  auto N = Expr(15);
  Placeholder<float> A("A", {M, N});

  auto B = Compute(
      {M, N}, [=](Var i, Var j) -> Expr { return A(i, j) + 1.f; }, "B");

  auto stages = CreateStages({B});
  auto lower_funcs = Lower("cal_B", stages, {A, B});
  LOG(INFO) << "lowered func: " << lower_funcs;

  // Expected printed body. It is passed through utils::Trim before comparison
  // (presumably to drop the raw string's leading/trailing newlines — confirm).
  // NOTE(review): the comparison is textual, so the indentation inside this
  // raw string must match the IR printer's output exactly; this listing may
  // have lost leading whitespace — verify against the printer.
  const char* expected = R"ROC(
{
for (i, 0, 100)
{
for (j, 0, 15)
{
B[i, j] = (1 + A[i, j])
}
}
}
)ROC";

  // Print the actual body first so a mismatch is easy to diagnose.
  std::cout << "\n" << lower_funcs->body << std::endl;
  EXPECT_EQ(utils::GetStreamCnt(lower_funcs->body), utils::Trim(expected));
}
// Smoke test: lowers a 3-D compute C(i, j, k) = A(i, j) * B(j, k) over a
// 100x15x200 domain and prints the resulting function. No output is asserted.
TEST(lower, more_complex) {
  // Extents of the three loop axes.
  Expr M(100);
  Expr N(15);
  Expr K(200);

  // Inputs: A is M x N, B is N x K.
  Placeholder<float> A("A", {M, N});
  Placeholder<float> B("B", {N, K});

  auto product = [=](Var i, Var j, Var k) -> Expr { return A(i, j) * B(j, k); };
  auto C = Compute({M, N, K}, product, "C");

  auto stage_map = CreateStages({C});
  auto lowered = Lower("cal_C", stage_map, {A, B, C});

  Expr as_expr(lowered->self());
  std::cout << "func:\n" << as_expr << std::endl;
}
// Smoke test: same 3-D product as more_complex, but the leading extent is the
// symbolic Var `B` rather than a constant. Prints the lowered function only.
TEST(lower, dynamic_shape) {
  // `B` plays the role of a batch size; X is (B x N) and W is (N x K).
  Var B("B");
  Expr N(15);
  Expr K(200);

  Placeholder<float> X("X", {Expr(B), Expr(N)});
  Placeholder<float> W("W", {Expr(N), Expr(K)});

  auto product = [=](Var i, Var j, Var k) -> Expr { return X(i, j) * W(j, k); };
  auto C = Compute({B, N, K}, product, "C");

  auto stage_map = CreateStages({C});
  auto lowered = Lower("cal_C", stage_map, {X, W, C});

  Expr as_expr(lowered->self());
  std::cout << "func:\n" << as_expr << std::endl;
}
// Wraps a call to the already-lowered function "lowered_fun0" as a tensor C
// (via CallLowered) and lowers a function "fn" that includes it.
TEST(lower, lowered_call) {
  Var B("B");  // B is like shape here.
  Expr N(15);

  // Input is B * N, B is like batch.
  Placeholder<float> X("X", {Expr(B), Expr(N)});
  Placeholder<float> Y("Y", {Expr(B), Expr(N)});
  auto Z = Compute(
      {B, N}, [&](Var i, Var j) { return X(i, j) + Y(i, j); }, "Z");

  std::vector<ReturnType> return_types({{Float(32), std::vector<Expr>{{B, N}}, "C"}});
  auto tensors = CallLowered("lowered_fun0", {X, Y, Z}, return_types);
  // Guard the tensors[0] access so a missing result fails the test instead of
  // reading out of bounds.
  ASSERT_FALSE(tensors.empty());
  auto C = tensors[0];

  auto stages = CreateStages({X, Y, Z, C});

  // The result of as<>() is dereferenced below; check it first so an
  // unexpected operation type fails cleanly rather than crashing.
  // NOTE(review): assumes as<ir::CallOp>() yields nullptr on a type mismatch.
  auto call_op = C->operation->as<ir::CallOp>();
  ASSERT_TRUE(call_op != nullptr);
  LOG(INFO) << "call_op: " << call_op->call_expr;

  auto lower_func = Lower("fn", stages, {X, Y, Z, C});
  LOG(INFO) << "func:\n" << Expr(lower_func->self());
}
// Checks that lowering with a module builder registers exactly the buffers of
// the intermediate tensors B, C and D (named "_B", "_C", "_D"). Those tensors
// are neither inputs nor the output passed to Lower().
TEST(lower, temp_buffer_collects) {
  Expr M(10);
  Placeholder<float> A("A", {M});

  // Chain A -> B -> C -> D -> output; B, C and D are the temporaries.
  auto B = Compute(
      {M}, [&](Expr i) -> Expr { return A(i); }, "B");  // temp
  auto C = Compute(
      {M}, [&](Expr i) -> Expr { return B(i); }, "C");  // temp
  auto D = Compute(
      {M}, [&](Expr i) -> Expr { return C(i); }, "D");  // temp
  auto output = Compute(
      {M}, [&](Expr i) -> Expr { return D(i); }, "output");

  ir::Module::Builder b("somemodule", common::DefaultHostTarget());
  auto stages = CreateStages({B, C, D, output});
  auto fn = Lower("fn", stages, {A, output}, {}, {}, &b);
  auto module = b.Build();

  ASSERT_EQ(module.buffers().size(), 3UL);

  // Compare name sets, not per-element membership, so a duplicated buffer
  // (e.g. {_B, _B, _C}) is detected instead of silently passing.
  const std::set<std::string> expected_buffer_names({"_B", "_C", "_D"});
  std::set<std::string> actual_buffer_names;
  for (const auto& buffer : module.buffers()) {
    actual_buffer_names.insert(buffer->name);
  }
  EXPECT_EQ(actual_buffer_names, expected_buffer_names);
}
} // namespace lang
} // namespace cinn