Support batch and classes for NonMaxSuppression #3999
@@ -3701,16 +3701,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return rewriter.notifyMatchFailure(
binder.op, "expected center_point_box attribute to be 0 or 1");

Value cst0 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value cst2 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(2));
Value cst3 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(3));
Value cst4 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(4));
Value cst0 = rewriter.create<Torch::ConstantIntOp>(loc, 0);
Value cst1 = rewriter.create<Torch::ConstantIntOp>(loc, 1);
Value cst2 = rewriter.create<Torch::ConstantIntOp>(loc, 2);
Value cst3 = rewriter.create<Torch::ConstantIntOp>(loc, 3);
Value cst4 = rewriter.create<Torch::ConstantIntOp>(loc, 4);
Value cst2F = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(2.0));
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
@@ -3813,36 +3808,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
// Create an empty tensor of shape (B*C*N, 3) to store the final result.
// We slice this to required elements at the end

// FIXME:: Currently empty tensors created with dynamic sizes are not
// fully supported. Uncomment the below lines once dynamic sizes for
// empty tensors are supported end to end.

/*
Value numResults = rewriter.create<Torch::AtenMulIntOp>(
loc, numClasses.getType(), numBatches, numClasses);
numResults = rewriter.create<Torch::AtenMulIntOp>(
loc, numClasses.getType(), numResults, maxOutputBoxesPerClass);
auto finalResultType = resultType;
*/

if (!scoreTensorType.toBuiltinTensor().hasStaticShape()) {
llvm_unreachable("Unimplemented: Encountered dynamic shaped tensors "
"while lowering Onnx NonMaxSuppression op to torch");
}
auto numResultElements =
    scoreTensorType.toBuiltinTensor().getNumElements();
auto numResults = rewriter.create<Torch::ConstantIntOp>(
    loc, rewriter.getI64IntegerAttr(numResultElements));

auto intTy = rewriter.getType<Torch::IntType>();
auto intListTy = rewriter.getType<Torch::ListType>(intTy);

Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
    loc, intListTy, SmallVector<Value>{numResults, cst3});
auto finalResultType = rewriter.getType<Torch::ValueTensorType>(
    ArrayRef<int64_t>{numResultElements, 3}, resultType.getDtype());
Value finalResult = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
    loc, finalResultType, resultShapeList, /*dtype=*/cst4,
    loc, resultType, resultShapeList, /*dtype=*/cst4,
    /*layout=*/cstNone,
    /*device=*/cstNone, /*pinMemory=*/cstNone,
    /*memoryFormat=*/cstNone);
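To make the comment at the top of this hunk concrete, here is a minimal sketch in plain C++ (illustrative names only, not part of the patch) of the contract the lowering builds: a pre-allocated buffer with one row per possible selection (B*C*N rows), filled with [batch_index, class_index, box_index] triplets and then sliced down to the rows actually written.

```cpp
#include <algorithm>
#include <array>
#include <cstdint>
#include <vector>

// selected[b][c] holds the box indices kept by the per-class NMS for batch b
// and class c (the torchvision nms analogue, whose result has shape
// [num_selected]); the ONNX output instead has shape [num_selected_total, 3]
// with rows of the form [batch_index, class_index, box_index].
std::vector<std::array<int64_t, 3>> packOnnxNmsResult(
    const std::vector<std::vector<std::vector<int64_t>>> &selected,
    int64_t numBoxes, int64_t maxOutputBoxesPerClass) {
  const int64_t numBatches = selected.size();
  const int64_t numClasses = numBatches ? selected[0].size() : 0;
  // Analogue of the empty (B*C*N, 3) tensor the lowering pre-allocates.
  std::vector<std::array<int64_t, 3>> result(numBatches * numClasses *
                                             numBoxes);
  int64_t used = 0; // analogue of the running finalResIdx / numResultValues
  for (int64_t b = 0; b < numBatches; ++b) {
    for (int64_t c = 0; c < numClasses; ++c) {
      const auto &keep = selected[b][c];
      const int64_t numOut =
          std::min<int64_t>(keep.size(), maxOutputBoxesPerClass);
      for (int64_t i = 0; i < numOut; ++i)
        result[used++] = {b, c, keep[i]};
    }
  }
  // Analogue of slicing the buffer to the required elements at the end.
  result.resize(used);
  return result;
}
```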
@@ -3855,16 +3832,15 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
SmallVector<int64_t>{}, nmsTy.getDtype());

auto nmsBatchLoop = rewriter.create<Torch::PrimLoopOp>(
    loc, TypeRange({finalResultType, intTy, intTy}), numBatches,
    cstTrue,
    loc, TypeRange({resultType, intTy, intTy}), numBatches, cstTrue,
    ValueRange({finalResult, /*Index to finalResult*/ cst0,
                /*Num values in result*/ cst0}));
{
  // Batch loop body
  PatternRewriter::InsertionGuard guard(rewriter);
  Block *batchLoopBody = rewriter.createBlock(
      &nmsBatchLoop.getRegion(), nmsBatchLoop.getRegion().begin(),
      TypeRange({intTy, finalResultType, intTy, intTy}),
      TypeRange({intTy, resultType, intTy, intTy}),
      {loc, loc, loc, loc});

  auto batchIV = batchLoopBody->getArgument(0);
@@ -3877,31 +3853,30 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
  auto batchValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
      loc, emptyTensorTy, batchIV);

  auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
      loc, scoreSlicedType, scores, cst0, batchIV);
  auto scoreSelectType =
      cast<Torch::ValueTensorType>(scoreSelect.getType());
  auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
      scoreSelectType.getSizes().slice(1), scoreSelectType.getDtype());

  auto nmsClassLoop = rewriter.create<Torch::PrimLoopOp>(
      loc, TypeRange({finalResultType, intTy, intTy}), numClasses,
      cstTrue, ValueRange({currRes, finalResIdx, numResultValues}));
      loc, TypeRange({resultType, intTy, intTy}), numClasses, cstTrue,
      ValueRange({currRes, finalResIdx, numResultValues}));

  {
    // Class loop body
    PatternRewriter::InsertionGuard guard(rewriter);
    Block *classLoopBody = rewriter.createBlock(
        &nmsClassLoop.getRegion(), nmsClassLoop.getRegion().begin(),
        TypeRange({intTy, finalResultType, intTy, intTy}),
        TypeRange({intTy, resultType, intTy, intTy}),
        {loc, loc, loc, loc});

    auto classIV = classLoopBody->getArgument(0);
    auto currRes = classLoopBody->getArgument(1);
    auto finalResIdx = classLoopBody->getArgument(2);
    Value numResultValues = classLoopBody->getArgument(3);

    auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
        loc, scoreSlicedType, scores, cst0, batchIV);
    auto scoreSelectType =
        cast<Torch::ValueTensorType>(scoreSelect.getType());
    auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
        scoreSelectType.getSizes().slice(1),
        scoreSelectType.getDtype());

    auto scoreValue = rewriter.create<Torch::AtenSelectIntOp>(
        loc, scoreValueType, scoreSelect, cst0, classIV);
    auto classValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
@@ -3920,20 +3895,30 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
    rewriter.create<Torch::PrimNumToTensorScalarOp>(
        loc, emptyTensorTy, maxOutputBoxesPerClass);
    auto minVal = rewriter.create<Torch::AtenMinimumOp>(
        loc, numOutputBoxes.getType(), numOutputBoxes,
        maxBoxesPerClass);
        loc, emptyTensorTy, numOutputBoxes, maxBoxesPerClass);
    numOutputBoxes =
        rewriter.create<Torch::AtenItemOp>(loc, intTy, minVal);

    // Loop through the nms result
    // The resulting shape of torchvision nms op is [num_selected] while
    // that of onnx is [num_selected, 3] where the selected format is
    // [batch_index, class_index, box_index].
    // Insert the triplet [batch_index, class_index, box_index] into
    // `finalResult` element by element for each box.

    // TODO:: This can be simplified by concatinating the result of nms
    // with that of tensors filled with batch and class indices instead
    // of using the below loop. Currently this approach results in
    // failures while lowering due to dynamic dims

    auto nmsLoop = rewriter.create<Torch::PrimLoopOp>(
        loc, TypeRange({finalResultType, intTy}), numOutputBoxes,
        cstTrue, ValueRange({currRes, finalResIdx}));
        loc, TypeRange({resultType, intTy}), numOutputBoxes, cstTrue,
        ValueRange({currRes, finalResIdx}));
    {
      PatternRewriter::InsertionGuard guard(rewriter);
      Block *loopBody = rewriter.createBlock(
          &nmsLoop.getRegion(), nmsLoop.getRegion().begin(),
          TypeRange({intTy, finalResultType, intTy}), {loc, loc, loc});
          TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
      auto iter = loopBody->getArgument(0);
      auto currRes = loopBody->getArgument(1);
      auto idxCst = loopBody->getArgument(2);

Review thread on the nmsLoop creation above:

Reviewer: I'm finding it difficult to parse this loop. The result of the (per-batch per-channel) torchvision nms op has shape [...]

Author: @zjgarvey @jinchen62 Updated the comments for the nmsLoop part. This loop is used to insert the triplet [batch_index, class_index, selected_box_index] at the required indices, element by element.
"Is the purpose of this loop to insert these elements into the final result? Is it possible to avoid using a loop for this and instead concatenate the nms result with some splat tensors, then insert that into the final result by keeping track of what the cumulative num_selected is?"
-> Yes, I had already tried the approach with splat + concats as part of #3981. The IR using the concat + splat method is here. I made use of loops so that we have a working solution initially, and the logic can be updated once the issues in IREE are fixed. Please let me know your thoughts on this!
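The TODO in the hunk and this review exchange compare two ways of materializing the per-class results. As a hedged, plain-C++ analogue (illustrative helper names, not part of the patch), the current lowering does the first of these, one row per nmsLoop iteration, while the suggested alternative builds the whole block at once from the nms indices plus two splat columns:

```cpp
#include <array>
#include <cstdint>
#include <vector>

// Current strategy, mirroring one nmsLoop iteration: write a single
// [batch_index, class_index, box_index] row into the pre-allocated result
// buffer at the running offset, one triplet per loop trip (the lowering does
// this column by column with select / select_scatter ops).
void scatterTriplet(std::vector<std::array<int64_t, 3>> &buffer, int64_t row,
                    int64_t batchIdx, int64_t classIdx, int64_t boxIdx) {
  buffer[row] = {batchIdx, classIdx, boxIdx};
}

// Alternative sketched in the TODO and the review thread: pair the NMS index
// column with two "splat" columns (the batch and class indices repeated
// num_selected times), i.e. a concat along the last dim, and append the whole
// [num_selected, 3] block at the running offset in one step. Returns the new
// offset, i.e. the cumulative num_selected mentioned in the thread.
int64_t appendBlock(std::vector<std::array<int64_t, 3>> &buffer,
                    int64_t offset, int64_t batchIdx, int64_t classIdx,
                    const std::vector<int64_t> &nmsIndices) {
  for (size_t i = 0; i < nmsIndices.size(); ++i)
    buffer[offset + static_cast<int64_t>(i)] = {batchIdx, classIdx,
                                                nmsIndices[i]};
  return offset + static_cast<int64_t>(nmsIndices.size());
}
```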
@@ -3955,7 +3940,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
      auto scatterBatch = rewriter.create<Torch::AtenSelectScatterOp>(
          loc, outputTensorSliceType, batchDim3D, bCopy, cst0, cst0);
      auto batchResult = rewriter.create<Torch::AtenSelectScatterOp>(
          loc, finalResultType, currRes, scatterBatch, cst0, idxCst);
          loc, resultType, currRes, scatterBatch, cst0, idxCst);

      // Update class dimension
      auto classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
@@ -3970,8 +3955,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
      auto scatterClass = rewriter.create<Torch::AtenSelectScatterOp>(
          loc, outputTensorSliceType, classDim3D, cCopy, cst0, cst1);
      auto classRes = rewriter.create<Torch::AtenSelectScatterOp>(
          loc, finalResultType, batchResult, scatterClass, cst0,
          idxCst);
          loc, resultType, batchResult, scatterClass, cst0, idxCst);

      // Update nms result dimension
      auto resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
@@ -3988,7 +3972,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
      auto scatterRes = rewriter.create<Torch::AtenSelectScatterOp>(
          loc, outputTensorSliceType, resDim3D, rCopy, cst0, cst2);
      Value nmsResult = rewriter.create<Torch::AtenSelectScatterOp>(
          loc, finalResultType, classRes, scatterRes, cst0, idxCst);
          loc, resultType, classRes, scatterRes, cst0, idxCst);

      // Increment the result index
      Value next =
Reviewer (on the AtenMinimumOp operands): I would use a tensor type with shape [1] instead of [], since those arguments come in with shape [1] and you apply the Minimum op to them.
Author: Both of the values passed to the Minimum op are scalars coming from the aten.size.int op, so I had used shape [] for the minimum op.
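To ground that exchange, here is a small eager-mode sketch (libtorch C++; assumes a libtorch install and is not part of the patch) of the two typings being discussed: 0-d ("shape []") operands, which is what scalars such as the results of aten.size.int become after prim.NumToTensor, and a shape-[1] operand, which simply broadcasts.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // 0-d ("shape []") tensors; minimum of two 0-d tensors is again 0-d.
  auto numOutputBoxes = torch::scalar_tensor(7, torch::kLong);   // shape []
  auto maxBoxesPerClass = torch::scalar_tensor(3, torch::kLong); // shape []
  auto minVal = torch::minimum(numOutputBoxes, maxBoxesPerClass);
  std::cout << minVal.dim() << " " << minVal.item<int64_t>() << "\n"; // 0 3

  // If one operand instead carries shape [1], broadcasting makes the result
  // shape [1]; both typings are representable, which is what the review
  // comments above are weighing.
  auto rank1 = torch::full({1}, 3, torch::kLong); // shape [1]
  std::cout << torch::minimum(numOutputBoxes, rank1).sizes() << "\n"; // [1]
}
```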