// input split is omitted, default to split by 1. const auto* const attr_proto = ctx.getAttribute("keepdims"); if (attr_proto) { keepdims = attr_proto->i(); } } else { splitSize = [&]() -> int64_t { // Need input split shape info and initializer data to infer split sizes. if (!hasInputShape(ctx, 1)) { return -1; } const TensorProto* splitInitializer = ctx.getInputData(1); if (nullptr == splitInitializer || !splitInitializer->has_data_type()) { return -1; } std::vector splitSizes; if (splitInitializer->data_type() == TensorProto::INT64) { const auto data = ParseData(splitInitializer); splitSizes.insert(splitSizes.end(), data.begin(), data.end()); } else if (splitInitializer->data_type() == TensorProto::INT32) { const auto data = ParseData(splitInitializer); splitSizes.insert(splitSizes.end(), data.begin(), data.end()); } else { // unaccepted data type fail_shape_inference("Only supports `int32_t` or `int64_t` inputs for split"); } if (splitSizes.empty()) { fail_shape_inference("Input 'split' can not be empty."); } const auto& splitDim = inputShape.dim(axis); if (!splitDim.has_dim_value()) { // Unable to verify nor infer exact split dimension size. return -1; } int64_t splitDimValue = splitDim.dim_value(); const auto& splitShape = getInputShape(ctx, 1); if (splitShape.dim_size() == 0) { // split is scalar if (splitDimValue % splitSizes[0] == 0) { // all output chunks have the same shape, assign that to output sequence shape. return splitSizes[0]; } return -1; } else { // split is 1-D tensor int64_t splitSizesSum = std::accumulate(splitSizes.begin(), splitSizes.end(), static_cast(0)); if (splitDimValue != splitSizesSum) { fail_shape_inference( "Sum of split values not equal to 'input' dim size on 'axis'. 'axis' dim size=", splitDimValue, " sum of split values=", splitSizesSum); } if (std::adjacent_find(splitSizes.begin(), splitSizes.end(), std::not_equal_to()) == splitSizes.end()) { // all split sizes are the same. return splitSizes[0]; }