Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix slice upstream - Incompatible dimensions #16818

Merged
merged 4 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ NodeArg* CreateInitializerFromVector(Graph& graph,
total_count *= dim;
}

ORT_ENFORCE(total_count == static_cast<int64_t>(values.size()));
ORT_ENFORCE(total_count == static_cast<int64_t>(values.size()),
"The total count of dims does not match the size of values. ",
"total_count: ", total_count, " values.size(): ", values.size());

const_tensor.set_raw_data(values.data(), values.size() * sizeof(int64_t));
return &graph_utils::AddInitializer(graph, const_tensor);
Expand Down
83 changes: 67 additions & 16 deletions onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,20 +138,65 @@ SliceInfo UpStreamGatherGraphTransformer::PropagateSlicingForInput(
std::to_string(!info.is_scalar_slice));

InlinedVector<NodeArg*> input_args;
input_args.reserve(slice_node.InputDefs().size());
input_args.resize(slice_node.InputDefs().size());

int axis_input_index = -1; // -1 means axis is passed in attribute.
if (std::holds_alternative<int>(info.axis_attr_name_or_input_index)) {
axis_input_index = std::get<int>(info.axis_attr_name_or_input_index);
}

auto create_axes_input = [&info, new_axis, &graph]() -> NodeArg* {
InlinedVector<int64_t> dims;
if (info.rank_of_axis_value == 1) {
dims.push_back(1);
}
return CreateInitializerFromVector(graph, dims, {new_axis}, graph.GenerateNodeArgName("axes"));
};

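// The new slice op's data input is current_node's current_node_input_index-th input.
// When the slicing axis changes (new_axis differs from the original axis), the slice op's axes input is
// replaced with a newly created initializer holding new_axis; all other inputs are reused as-is.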
input_args.push_back(current_node.MutableInputDefs()[current_node_input_index]);
for (size_t i = 1; i < slice_node.InputDefs().size(); ++i) {
input_args.push_back(slice_node.MutableInputDefs()[i]);
int i = 0;
for (; i < static_cast<int>(slice_node.InputDefs().size()); ++i) {
if (i == info.GetDataInputIndex()) {
input_args[i] = current_node.MutableInputDefs()[current_node_input_index];
} else if (axis_input_index != -1 && i == axis_input_index) {
if (info.non_negative_axis == new_axis) {
input_args[i] = slice_node.MutableInputDefs()[i];
} else {
input_args[i] = create_axes_input();
}
} else {
input_args[i] = slice_node.MutableInputDefs()[i];
}
}

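// The axes input may be absent from the slice node's inputs. If the axis changed, pad any missing
// inputs before it with empty NodeArgs, then append a newly created axes input.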
if (axis_input_index != -1 && info.non_negative_axis != new_axis) {
for (; i <= axis_input_index; ++i) {
if (i == axis_input_index) {
input_args.push_back(create_axes_input());
} else {
NodeArg& empty_input = graph.GetOrCreateNodeArg("", nullptr);
input_args.push_back(&empty_input);
}
}
}

// Update the axis attribute if new_axis is not the same as the original slicing axis (which happens when data
// layout got changed by Transpose or Reshape ops)
onnxruntime::NodeAttributes attributes = slice_node.GetAttributes();
if (info.non_negative_axis != new_axis) {
attributes[info.axis_attr_name] =
ONNX_NAMESPACE::MakeAttribute(info.axis_attr_name, static_cast<int64_t>(new_axis));

if (axis_input_index == -1 && info.non_negative_axis != new_axis) {
std::string attr_name = std::get<std::string>(info.axis_attr_name_or_input_index);
if (info.rank_of_axis_value == 0) {
attributes[attr_name] =
ONNX_NAMESPACE::MakeAttribute(attr_name, static_cast<int64_t>(new_axis));
} else if (info.rank_of_axis_value == 1) {
attributes[attr_name] =
ONNX_NAMESPACE::MakeAttribute(attr_name, std::vector<int64_t>{static_cast<int64_t>(new_axis)});
} else {
ORT_THROW("Unexpected rank of axis attribute value: " + std::to_string(info.rank_of_axis_value));
}
}

InlinedVector<NodeArg*> output_args;
Expand Down Expand Up @@ -183,7 +228,8 @@ SliceInfo UpStreamGatherGraphTransformer::PropagateSlicingForInput(
auto new_slice_out_arg = new_slice_node->MutableOutputDefs()[new_slice_output_index_to_connect];
UpdateSliceOutputShape(*new_slice_out_arg, new_axis, info.output_dim_on_axis);

auto new_slice_info = SliceInfo(graph, new_slice_node, info.is_scalar_slice, info.axis_attr_name, new_axis);
auto new_slice_info = SliceInfo(graph, new_slice_node, info.is_scalar_slice, info.axis_attr_name_or_input_index,
new_axis, info.rank_of_axis_value);
new_slice_info.entry_node_name = info.entry_node_name;
new_slice_info.entry_slice_arg_name = info.entry_slice_arg_name;
return new_slice_info;
Expand Down Expand Up @@ -263,7 +309,8 @@ std::optional<SliceInfo> IsSupportedGatherND(Graph& graph, Node& node,
return std::nullopt;
}

return SliceInfo(graph, &node, false, "batch_dims", static_cast<int>(batch_dims), true);
return SliceInfo(graph, &node, false, "batch_dims", static_cast<int>(batch_dims),
0 /* rank of axis attribute value */, true);
}

std::optional<SliceInfo> IsSupportedGather(Graph& graph, Node& node,
Expand Down Expand Up @@ -304,7 +351,7 @@ std::optional<SliceInfo> IsSupportedGather(Graph& graph, Node& node,
}
}

return SliceInfo(graph, &node, dim_size == 0, "axis", axis, true);
return SliceInfo(graph, &node, dim_size == 0, "axis", axis, 0 /* rank of axis attribute value */, true);
}

std::optional<SliceInfo> IsSupportedShrunkenGather(Graph& graph, Node& node,
Expand Down Expand Up @@ -342,7 +389,7 @@ std::optional<SliceInfo> IsSupportedShrunkenGather(Graph& graph, Node& node,
return std::nullopt;
}

return SliceInfo(graph, &node, false /*is_slice_scalar*/, "axis", axis, true);
return SliceInfo(graph, &node, false /*is_slice_scalar*/, "axis", axis, 0 /* rank of axis attribute value */, true);
}

/**
Expand All @@ -366,42 +413,46 @@ std::optional<SliceInfo> IsSupportedSlice(Graph& graph, Node& node,
const NodeArg* axes_input = node.InputDefs().size() > 3 ? node.InputDefs()[3] : nullptr;

if (data_input->Shape() == nullptr || starts_input->Shape() == nullptr || ends_input->Shape() == nullptr ||
(axes_input && axes_input->Shape() == nullptr)) {
(axes_input && axes_input->Exists() && axes_input->Shape() == nullptr)) {
LOG_DEBUG_INFO(logger, "Skip Slice node " + node.Name() + " due to undefined shape.");
return std::nullopt;
}

// Make sure starts/ends/axes/steps are all 1D tensors, since we only support single-dimension slicing.
if (starts_input->Shape()->dim_size() != 1 || ends_input->Shape()->dim_size() != 1 ||
(axes_input && axes_input->Shape()->dim_size() != 1)) {
(axes_input && axes_input->Exists() && axes_input->Shape()->dim_size() != 1)) {
LOG_DEBUG_INFO(logger, "Skip Slice node " + node.Name() + " due to unsupported dim size: " +
std::to_string(starts_input->Shape()->dim_size()) + ", " +
std::to_string(ends_input->Shape()->dim_size()) + ", " +
std::to_string(axes_input ? axes_input->Shape()->dim_size() : 0));
std::to_string(axes_input && axes_input->Exists() ? axes_input->Shape()->dim_size() : 0));
return std::nullopt;
}

// Try to parse the 'axes' value.
int axis = 0;
if (axes_input) {
if (axes_input && axes_input->Exists()) {
InlinedVector<int64_t> axes_values;
if (!graph_utils::IsConstantInitializer(graph, axes_input->Name()) ||
!optimizer_utils::AppendTensorFromInitializer(graph, *axes_input, axes_values, true) ||
axes_values.size() != 1) {
LOG_DEBUG_INFO(logger, "Skip Slice node " + node.Name() + " due to unsupported axes value.");
return std::nullopt;
}
axis = static_cast<int>(axes_values[0]);
} else {
// If 'axes' is not specified, then it is [0, .., r-1], so we force data rank to be 1.
if (data_input->Shape()->dim_size() != 1) {
LOG_DEBUG_INFO(logger, "Skip Slice node " + node.Name() + " due to unsupported data rank: " +
std::to_string(data_input->Shape()->dim_size()));
return std::nullopt;
}
}

if (axis < 0)
axis += data_input->Shape()->dim_size();

return SliceInfo(graph, &node, false /*is_slice_scalar*/, "axis", axis, true);
return SliceInfo(graph, &node, false /*is_slice_scalar*/, 3 /* axis input index */, axis,
1 /* rank of axes value */, true);
}

} // namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,21 @@ struct SliceInfo : public UpstreamOperatorInfoBase {
public:
SliceInfo(const Graph& graph, Node* slice_node,
bool is_slice_scalar,
const std::string& slice_axis_attr_name,
std::variant<std::string, int> axis_name_or_index,
int slice_axis,
int rank_of_axis,
bool is_entry_node_ptr = false)
: UpstreamOperatorInfoBase(slice_node, is_entry_node_ptr), is_scalar_slice(is_slice_scalar) {
axis_attr_name = slice_axis_attr_name;
axis_attr_name_or_input_index = axis_name_or_index;
rank_of_axis_value = rank_of_axis;

if (std::holds_alternative<int>(axis_name_or_index)) {
int axis_input_index = std::get<int>(axis_name_or_index);
ORT_ENFORCE(axis_input_index >= 0, "Axis input index is invalid");
}

ORT_ENFORCE(rank_of_axis_value == 0 || rank_of_axis_value == 1, "Rank of axis value is invalid: " +
std::to_string(rank_of_axis_value));

const NodeArg* input = node_ptr->InputDefs()[kSliceDataInputIndex_];
const NodeArg* output = node_ptr->OutputDefs()[kSliceOutputIndex_];
Expand Down Expand Up @@ -65,8 +75,16 @@ struct SliceInfo : public UpstreamOperatorInfoBase {
}

bool is_scalar_slice; // whether the slice is a scalar, if it is after Gather, the rank will be reduced by 1.
std::string axis_attr_name;

// Where the axis value comes from: a std::string holds the name of the attribute that carries the axis;
// an int holds the index of the input that carries the axis.
std::variant<std::string, int> axis_attr_name_or_input_index;

int non_negative_axis; // The axis to slice on

// The rank of the axis value (whether it is given as an attribute or as an input). For example, Gather's
// axis attribute is a scalar, so the rank is 0; Slice's axes input is a 1D tensor, so the rank is 1.
int rank_of_axis_value;

std::string entry_slice_arg_name;

int input_rank; // rank of the Gather data input tensor
Expand Down
Loading