Support batch and classes for NonMaxSuppression #3999

Open
wants to merge 2 commits into base: main
Changes from 1 commit
Address review comments and simplify lit tests
praveen-g-ctt committed Feb 19, 2025
commit 420bbca68ec4307c0c2c0375ee4be6a4d4d93192
88 changes: 36 additions & 52 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -3701,16 +3701,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return rewriter.notifyMatchFailure(
binder.op, "expected center_point_box attribute to be 0 or 1");

Value cst0 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value cst2 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(2));
Value cst3 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(3));
Value cst4 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(4));
Value cst0 = rewriter.create<Torch::ConstantIntOp>(loc, 0);
Value cst1 = rewriter.create<Torch::ConstantIntOp>(loc, 1);
Value cst2 = rewriter.create<Torch::ConstantIntOp>(loc, 2);
Value cst3 = rewriter.create<Torch::ConstantIntOp>(loc, 3);
Value cst4 = rewriter.create<Torch::ConstantIntOp>(loc, 4);
Value cst2F = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(2.0));
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
@@ -3813,36 +3808,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
// Create an empty tensor of shape (B*C*N, 3) to store the final result.
// We slice this to required elements at the end

// FIXME:: Currently empty tensors created with dynamic sizes are not
// fully supported. Uncomment the below lines once dynamic sizes for
// empty tensors are supported end to end.

/*
Value numResults = rewriter.create<Torch::AtenMulIntOp>(
loc, numClasses.getType(), numBatches, numClasses);
numResults = rewriter.create<Torch::AtenMulIntOp>(
loc, numClasses.getType(), numResults, maxOutputBoxesPerClass);
auto finalResultType = resultType;
*/

if (!scoreTensorType.toBuiltinTensor().hasStaticShape()) {
llvm_unreachable("Unimplemented: Encountered dynamic shaped tensors "
"while lowering Onnx NonMaxSuppression op to torch");
}
auto numResultElements =
scoreTensorType.toBuiltinTensor().getNumElements();
auto numResults = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(numResultElements));

auto intTy = rewriter.getType<Torch::IntType>();
auto intListTy = rewriter.getType<Torch::ListType>(intTy);

Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
loc, intListTy, SmallVector<Value>{numResults, cst3});
auto finalResultType = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>{numResultElements, 3}, resultType.getDtype());
Value finalResult = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
loc, finalResultType, resultShapeList, /*dtype=*/cst4,
loc, resultType, resultShapeList, /*dtype=*/cst4,
/*layout=*/cstNone,
/*device=*/cstNone, /*pinMemory=*/cstNone,
/*memoryFormat=*/cstNone);
@@ -3855,16 +3832,15 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
SmallVector<int64_t>{}, nmsTy.getDtype());

auto nmsBatchLoop = rewriter.create<Torch::PrimLoopOp>(
loc, TypeRange({finalResultType, intTy, intTy}), numBatches,
cstTrue,
loc, TypeRange({resultType, intTy, intTy}), numBatches, cstTrue,
ValueRange({finalResult, /*Index to finalResult*/ cst0,
/*Num values in result*/ cst0}));
{
// Batch loop body
PatternRewriter::InsertionGuard guard(rewriter);
Block *batchLoopBody = rewriter.createBlock(
&nmsBatchLoop.getRegion(), nmsBatchLoop.getRegion().begin(),
TypeRange({intTy, finalResultType, intTy, intTy}),
TypeRange({intTy, resultType, intTy, intTy}),
{loc, loc, loc, loc});

auto batchIV = batchLoopBody->getArgument(0);
@@ -3877,31 +3853,30 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
auto batchValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
loc, emptyTensorTy, batchIV);

auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
loc, scoreSlicedType, scores, cst0, batchIV);
auto scoreSelectType =
cast<Torch::ValueTensorType>(scoreSelect.getType());
auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
scoreSelectType.getSizes().slice(1), scoreSelectType.getDtype());

auto nmsClassLoop = rewriter.create<Torch::PrimLoopOp>(
loc, TypeRange({finalResultType, intTy, intTy}), numClasses,
cstTrue, ValueRange({currRes, finalResIdx, numResultValues}));
loc, TypeRange({resultType, intTy, intTy}), numClasses, cstTrue,
ValueRange({currRes, finalResIdx, numResultValues}));

{
// Class loop body
PatternRewriter::InsertionGuard guard(rewriter);
Block *classLoopBody = rewriter.createBlock(
&nmsClassLoop.getRegion(), nmsClassLoop.getRegion().begin(),
TypeRange({intTy, finalResultType, intTy, intTy}),
TypeRange({intTy, resultType, intTy, intTy}),
{loc, loc, loc, loc});

auto classIV = classLoopBody->getArgument(0);
auto currRes = classLoopBody->getArgument(1);
auto finalResIdx = classLoopBody->getArgument(2);
Value numResultValues = classLoopBody->getArgument(3);

auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
loc, scoreSlicedType, scores, cst0, batchIV);
auto scoreSelectType =
cast<Torch::ValueTensorType>(scoreSelect.getType());
auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
scoreSelectType.getSizes().slice(1),
scoreSelectType.getDtype());

auto scoreValue = rewriter.create<Torch::AtenSelectIntOp>(
loc, scoreValueType, scoreSelect, cst0, classIV);
auto classValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
@@ -3920,20 +3895,30 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
rewriter.create<Torch::PrimNumToTensorScalarOp>(
loc, emptyTensorTy, maxOutputBoxesPerClass);
Collaborator: I would use a tensor type with shape [1] instead of [], since those arguments come in with shape [1] and you do a Minimum op with them.

Contributor (author): Both values passed to the Minimum op are scalars produced by aten.size.int, so I used shape [] for the minimum op.
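
For illustration only, a minimal sketch (not code from this PR) of how the two type choices being discussed could be built with the same rewriter idioms used in the patch; the si64 dtype and the variable names are assumptions:

// Hypothetical sketch: rank-0 ([]) vs rank-1 ([1]) value tensor types for
// the Minimum operands. The scalars come from aten.size.int and are wrapped
// with Torch::PrimNumToTensorScalarOp, which is why the PR uses the rank-0 form.
auto si64Ty = rewriter.getIntegerType(64, /*isSigned=*/true);
auto rank0Ty = rewriter.getType<Torch::ValueTensorType>(
    ArrayRef<int64_t>{}, si64Ty);  // shape [], holds a single scalar
auto rank1Ty = rewriter.getType<Torch::ValueTensorType>(
    ArrayRef<int64_t>{1}, si64Ty); // shape [1], the reviewer's suggestion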

auto minVal = rewriter.create<Torch::AtenMinimumOp>(
loc, numOutputBoxes.getType(), numOutputBoxes,
maxBoxesPerClass);
loc, emptyTensorTy, numOutputBoxes, maxBoxesPerClass);
numOutputBoxes =
rewriter.create<Torch::AtenItemOp>(loc, intTy, minVal);

// Loop through the nms result
// The resulting shape of torchvision nms op is [num_selected] while
// that of onnx is [num_selected, 3] where the selected format is
// [batch_index, class_index, box_index].
// Insert the triplet [batch_index, class_index, box_index] into
// `finalResult` element by element for each box.

// TODO: This can be simplified by concatenating the result of nms
// with that of tensors filled with batch and class indices instead
// of using the below loop. Currently this approach results in
// failures while lowering due to dynamic dims

auto nmsLoop = rewriter.create<Torch::PrimLoopOp>(
Collaborator: I'm finding it difficult to parse this loop.

The result of the (per-batch, per-class) torchvision nms op has shape <num_selected>, and we need it to be <num_selected x 3>, where each triplet is [batch_index, class_index, selected_box_index]. Is the purpose of this loop to insert these elements into the final result? Is it possible to avoid using a loop for this and instead concatenate the nms result with some splat tensors, then insert that into the final result by keeping track of the cumulative num_selected?

praveen-g-ctt (Contributor, author), Feb 19, 2025:

@zjgarvey @jinchen62 I updated the comments for the nmsLoop part. This loop inserts the triplet [batch_index, class_index, selected_box_index] into the result at the required indices, element by element.

"Is the purpose of this loop to insert these elements into the final result? Is it possible to avoid using a loop for this and instead concatenate the nms result with some splat tensors, then insert that into the final result by keeping track of what the cumulative num_selected is?"

Yes, I had already tried the splat + concat approach as part of #3981, but I was running into runtime issues (segfaults / invalid memory accesses) because IREE does not yet handle the dynamic dims.

The IR using the concat + splat method is here.

I made use of loops so that we have a working solution initially and can update the logic once the issues in IREE are fixed. Please let me know your thoughts on this!
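
For reference, a rough sketch of what the splat + concat alternative discussed above might look like. This is illustrative only, not code from this PR or from #3981; the op choices, the result types colTy and tripletTy, and the tvNmsResult name are assumptions:

// Hypothetical sketch of the splat + concat alternative (per batch/class
// iteration): build the [num_selected, 3] triplets without a per-box loop.
Value nmsCol = rewriter.create<Torch::AtenUnsqueezeOp>(
    loc, colTy, tvNmsResult, /*dim=*/cst1);      // [num_selected] -> [num_selected, 1]
Value batchCol = rewriter.create<Torch::AtenFullLikeOp>(
    loc, colTy, nmsCol, /*fill_value=*/batchIV, /*dtype=*/cstNone,
    /*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone,
    /*memory_format=*/cstNone);                  // column splatted with batch_index
Value classCol = rewriter.create<Torch::AtenFullLikeOp>(
    loc, colTy, nmsCol, /*fill_value=*/classIV, /*dtype=*/cstNone,
    /*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone,
    /*memory_format=*/cstNone);                  // column splatted with class_index
Value cols = rewriter.create<Torch::PrimListConstructOp>(
    loc, rewriter.getType<Torch::ListType>(colTy),
    SmallVector<Value>{batchCol, classCol, nmsCol});
Value triplets = rewriter.create<Torch::AtenCatOp>(
    loc, tripletTy, cols, /*dim=*/cst1);         // [num_selected, 3]
// The triplets would then be copied into finalResult at the running offset,
// which is where the dynamic num_selected dim currently trips up IREE.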

loc, TypeRange({finalResultType, intTy}), numOutputBoxes,
cstTrue, ValueRange({currRes, finalResIdx}));
loc, TypeRange({resultType, intTy}), numOutputBoxes, cstTrue,
ValueRange({currRes, finalResIdx}));
{
PatternRewriter::InsertionGuard guard(rewriter);
Block *loopBody = rewriter.createBlock(
&nmsLoop.getRegion(), nmsLoop.getRegion().begin(),
TypeRange({intTy, finalResultType, intTy}), {loc, loc, loc});
TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
auto iter = loopBody->getArgument(0);
auto currRes = loopBody->getArgument(1);
auto idxCst = loopBody->getArgument(2);
@@ -3955,7 +3940,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
auto scatterBatch = rewriter.create<Torch::AtenSelectScatterOp>(
loc, outputTensorSliceType, batchDim3D, bCopy, cst0, cst0);
auto batchResult = rewriter.create<Torch::AtenSelectScatterOp>(
loc, finalResultType, currRes, scatterBatch, cst0, idxCst);
loc, resultType, currRes, scatterBatch, cst0, idxCst);

// Update class dimension
auto classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
@@ -3970,8 +3955,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
auto scatterClass = rewriter.create<Torch::AtenSelectScatterOp>(
loc, outputTensorSliceType, classDim3D, cCopy, cst0, cst1);
auto classRes = rewriter.create<Torch::AtenSelectScatterOp>(
loc, finalResultType, batchResult, scatterClass, cst0,
idxCst);
loc, resultType, batchResult, scatterClass, cst0, idxCst);

// Update nms result dimension
auto resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
@@ -3988,7 +3972,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
auto scatterRes = rewriter.create<Torch::AtenSelectScatterOp>(
loc, outputTensorSliceType, resDim3D, rCopy, cst0, cst2);
Value nmsResult = rewriter.create<Torch::AtenSelectScatterOp>(
loc, finalResultType, classRes, scatterRes, cst0, idxCst);
loc, resultType, classRes, scatterRes, cst0, idxCst);

// Increment the result index
Value next =