Program Listing for File broadcast.cc
↰ Return to documentation for file (/WorkSpace/CINN/cinn/frontend/decomposer/broadcast.cc)
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cinn/frontend/decomposer_registry.h"
#include "cinn/frontend/syntax.h"
namespace cinn {
namespace frontend {
namespace decomposer {
void GetReduceDimsForX(const std::vector<int>& dx_shape,
const std::vector<int>& dout_shape,
std::vector<int>* reduce_dims) {
// e.g., dx_shape = [4, 1, 3], dout_shape = [4, 2, 3], reduce_dims=[1]
for (size_t i = 0; i < dout_shape.size(); ++i) {
if (dx_shape[i] == 1 && dout_shape[i] != 1) {
reduce_dims->push_back(i);
}
}
VLOG(3) << "The reduce_dims for X: " << utils::Join(*reduce_dims, ",");
}
void GetReduceDimsForY(const std::vector<int>& dy_shape,
const std::vector<int>& dout_shape,
int axis,
std::vector<int>* reduce_dims) {
// e.g., dy_shape = [3, 1, 4], dout_shape = [2, 3, 4, 4, 5], axis = 1
// reduce_dims=[0, 2, 4]
for (size_t i = 0; i < dout_shape.size(); ++i) {
if (i < axis || i >= axis + dy_shape.size()) {
reduce_dims->push_back(i);
} else {
if (dy_shape[i - axis] == 1 && dout_shape[i] != 1) {
reduce_dims->push_back(i);
}
}
}
VLOG(3) << "The reduce_dims for Y: " << utils::Join(*reduce_dims, ",");
}
void elementwise_add(const Instruction& instr, const DecomposerContext& context) {
CHECK_EQ(instr->inputs.size(), 2UL) << " 2 input tensors for " << instr->op_type;
CHECK_EQ(instr->outputs.size(), 1UL) << "1 output tensor for " << instr->op_type;
auto x = instr->inputs[0];
auto y = instr->inputs[1];
auto output = instr->outputs[0];
int axis = instr.GetAttrs<int>("axis");
axis = axis >= 0 ? axis : x->shape.size() - y->shape.size();
auto* builder = context.builder();
Variable out;
Variable bcast_x = x;
Variable bcast_y = y;
// e.g., x.shape = [4, 1, 3], y.shape = [2, 3], aixs = 1 out.shape = [4, 2, 3]
// bcast_axes_x = [0, 1, 2], bcast_axes_y = [1, 2]
if (x->shape != output->shape) {
std::vector<int> bcast_axes_x(x->shape.size());
std::iota(bcast_axes_x.begin(), bcast_axes_x.end(), 0);
bcast_x = builder->BroadcastTo(x, output->shape, bcast_axes_x);
}
// if y.shape=[1], y does not need to be broadcast
if (y->shape != output->shape && y->shape != std::vector<int>(1, 1)) {
std::vector<int> bcast_axes_y(y->shape.size());
std::iota(bcast_axes_y.begin(), bcast_axes_y.end(), axis);
bcast_y = builder->BroadcastTo(y, output->shape, bcast_axes_y);
}
out = builder->Add(bcast_x, bcast_y);
// map the the output of decomposed operator to the original.
context.MapOutToOrigin(out, output);
}
void elementwise_add_grad(const Instruction& instr, const DecomposerContext& context) {
CHECK_EQ(instr->inputs.size(), 3UL) << " 3 input tensors for " << instr->op_type;
CHECK_EQ(instr->outputs.size(), 2UL) << "2 output tensors for " << instr->op_type;
auto dout = instr->inputs[0];
auto dx = instr->outputs[0];
auto dy = instr->outputs[1];
int axis = instr.GetAttrs<int>("axis");
axis = axis >= 0 ? axis : dx->shape.size() - dy->shape.size();
auto* builder = context.builder();
Variable dx_t;
if (dx->shape == dout->shape) {
dx_t = builder->Identity(dout);
context.MapOutToOrigin(dx, dout);
} else {
std::vector<int> x_reduce_dims;
GetReduceDimsForX(dx->shape, dout->shape, &x_reduce_dims);
// The rank of dx is same as dout, so set keep_dim = true
dx_t = builder->Reduce(dout, ReduceKind::kSum, x_reduce_dims, true);
}
Variable dy_t;
if (dy->shape == dout->shape) {
dy_t = builder->Identity(dout);
context.MapOutToOrigin(dy, dout);
} else {
std::vector<int> y_reduce_dims;
GetReduceDimsForY(dy->shape, dout->shape, axis, &y_reduce_dims);
// The rank of dy is less or equal to dout, after reduce_sum, there
// may be some extra "1" in the front or back of dy_res's shape. So
// the dt_res needs to be reshaped.
auto dy_res = builder->Reduce(dout, ReduceKind::kSum, y_reduce_dims, true);
dy_t = builder->Reshape(dy_res, dy->shape);
}
// map the the output of decomposed operator to the original.
context.MapOutToOrigin(dx_t, dx);
context.MapOutToOrigin(dy_t, dy);
}
} // namespace decomposer
} // namespace frontend
} // namespace cinn
// Registration helper named `broadcast_decomposers` for the forward broadcast
// decomposer defined in this file.
// NOTE(review): CINN_REGISTER_HELPER and CINN_DECOMPOSER_REGISTER are defined
// outside this file (presumably in cinn/frontend/decomposer_registry.h);
// confirm the exact registration mechanism there.
CINN_REGISTER_HELPER(broadcast_decomposers) {
// Presumably binds the name `elementwise_add` to
// cinn::frontend::decomposer::elementwise_add.
CINN_DECOMPOSER_REGISTER(elementwise_add, cinn::frontend::decomposer::elementwise_add);
// The helper body reports `true` to the macro-generated caller.
return true;
}
// Registration helper named `broadcast_grad_decomposers` for the backward
// (gradient) broadcast decomposer defined in this file.
// NOTE(review): CINN_REGISTER_HELPER and CINN_DECOMPOSER_REGISTER are defined
// outside this file (presumably in cinn/frontend/decomposer_registry.h);
// confirm the exact registration mechanism there.
CINN_REGISTER_HELPER(broadcast_grad_decomposers) {
// Presumably binds the name `elementwise_add_grad` to
// cinn::frontend::decomposer::elementwise_add_grad.
CINN_DECOMPOSER_REGISTER(elementwise_add_grad, cinn::frontend::decomposer::elementwise_add_grad);
// The helper body reports `true` to the macro-generated caller.
return true;
}