-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Object Detection using TorchSharp #6605
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 4 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
a6dec6a
only base model files committed
michaelgsharp 07c9f40
builds working, finishing tests
michaelgsharp d1f564f
minor image errors
michaelgsharp 44703ca
image updates
michaelgsharp 9f38813
updates from PR comments, minor bug fixes
michaelgsharp 4060072
minor changes from PR
michaelgsharp d33701d
minor changes from PR and build fixes
michaelgsharp d66ace1
changed testing epochs to 1 so tests won't time out
michaelgsharp 785859a
minor changes for PR
michaelgsharp 0dd722c
added predicted box column
michaelgsharp d2a2fd9
Update src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionMetric…
michaelgsharp 7ed3bd4
fix for metrics
michaelgsharp 7b5299d
minor test fixes
michaelgsharp 792407d
Merge branch 'obj-detection' of https://github.com/michaelgsharp/mach…
michaelgsharp File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using TorchSharp; | ||
using static TorchSharp.torch; | ||
using static TorchSharp.torch.nn; | ||
|
||
namespace Microsoft.ML.TorchSharp.AutoFormerV2 | ||
{ | ||
/// <summary> | ||
/// Anchor boxes are a set of predefined bounding boxes of a certain height and width, whose location and size can be adjusted by the regression head of model. | ||
/// </summary> | ||
public class Anchors : Module<Tensor, Tensor> | ||
michaelgsharp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")] | ||
private readonly int[] pyramidLevels; | ||
|
||
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")] | ||
private readonly int[] strides; | ||
|
||
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")] | ||
private readonly int[] sizes; | ||
|
||
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")] | ||
private readonly double[] ratios; | ||
|
||
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")] | ||
private readonly double[] scales; | ||
|
||
/// <summary> | ||
/// Initializes a new instance of the <see cref="Anchors"/> class. | ||
/// </summary> | ||
/// <param name="pyramidLevels">Pyramid levels.</param> | ||
/// <param name="strides">Strides between adjacent bboxes.</param> | ||
/// <param name="sizes">Different sizes for bboxes.</param> | ||
/// <param name="ratios">Different ratios for height/width.</param> | ||
/// <param name="scales">Scale size of bboxes.</param> | ||
public Anchors(int[] pyramidLevels = null, int[] strides = null, int[] sizes = null, double[] ratios = null, double[] scales = null) | ||
: base(nameof(Anchors)) | ||
{ | ||
this.pyramidLevels = pyramidLevels != null ? pyramidLevels : new int[] { 3, 4, 5, 6, 7 }; | ||
this.strides = strides != null ? strides : this.pyramidLevels.Select(x => (int)Math.Pow(2, x)).ToArray(); | ||
this.sizes = sizes != null ? sizes : this.pyramidLevels.Select(x => (int)Math.Pow(2, x + 2)).ToArray(); | ||
this.ratios = ratios != null ? ratios : new double[] { 0.5, 1, 2 }; | ||
this.scales = scales != null ? scales : new double[] { Math.Pow(2, 0), Math.Pow(2, 1.0 / 3.0), Math.Pow(2, 2.0 / 3.0) }; | ||
} | ||
|
||
/// <summary> | ||
/// Generate anchors for an image. | ||
/// </summary> | ||
/// <param name="image">Image in Tensor format.</param> | ||
/// <returns>All anchors.</returns> | ||
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")] | ||
public override Tensor forward(Tensor image) | ||
JakeRadMSFT marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
using (var scope = torch.NewDisposeScope()) | ||
{ | ||
var imageShape = torch.tensor(image.shape.AsSpan().Slice(2).ToArray()); | ||
|
||
// compute anchors over all pyramid levels | ||
var allAnchors = torch.zeros(new long[] { 0, 4 }, dtype: torch.float32); | ||
|
||
for (int idx = 0; idx < this.pyramidLevels.Length; ++idx) | ||
{ | ||
var x = this.pyramidLevels[idx]; | ||
var shape = ((imageShape + Math.Pow(2, x) - 1) / Math.Pow(2, x)).to_type(torch.int32); | ||
var anchors = GenerateAnchors( | ||
baseSize: this.sizes[idx], | ||
ratios: this.ratios, | ||
scales: this.scales); | ||
var shiftedAnchors = Shift(shape, this.strides[idx], anchors); | ||
allAnchors = torch.cat(new List<Tensor>() { allAnchors, shiftedAnchors }, dim: 0); | ||
} | ||
|
||
var output = allAnchors.unsqueeze(dim: 0); | ||
output = output.to(image.device); | ||
|
||
return output.MoveToOuterDisposeScope(); | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Generate a set of anchors given size, ratios and scales. | ||
/// </summary> | ||
/// <param name="baseSize">Base size for width and height.</param> | ||
/// <param name="ratios">Ratios for height/width.</param> | ||
/// <param name="scales">Scales to resize base size.</param> | ||
/// <returns>A set of anchors.</returns> | ||
private static Tensor GenerateAnchors(int baseSize = 16, double[] ratios = null, double[] scales = null) | ||
{ | ||
using (var anchorsScope = torch.NewDisposeScope()) | ||
{ | ||
ratios ??= new double[] { 0.5, 1, 2 }; | ||
scales ??= new double[] { Math.Pow(2, 0), Math.Pow(2, 1.0 / 3.0), Math.Pow(2, 2.0 / 3.0) }; | ||
|
||
var numAnchors = ratios.Length * scales.Length; | ||
|
||
// initialize output anchors | ||
var anchors = torch.zeros(new long[] { numAnchors, 4 }, dtype: torch.float32); | ||
|
||
// scale base_size | ||
anchors[.., 2..] = baseSize * torch.tile(scales, new long[] { 2, ratios.Length }).transpose(1, 0); | ||
|
||
// compute areas of anchors | ||
var areas = torch.mul(anchors[.., 2], anchors[.., 3]); | ||
|
||
// correct for ratios | ||
anchors[.., 2] = torch.sqrt(areas / torch.repeat_interleave(ratios, new long[] { scales.Length })); | ||
anchors[.., 3] = torch.mul(anchors[.., 2], torch.repeat_interleave(ratios, new long[] { scales.Length })); | ||
|
||
// transform from (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2) | ||
anchors[.., torch.TensorIndex.Tensor(torch.tensor(new long[] { 0, 2 }, dtype: torch.int64))] -= torch.tile(anchors[.., 2] * 0.5, new long[] { 2, 1 }).T; | ||
anchors[.., torch.TensorIndex.Tensor(torch.tensor(new long[] { 1, 3 }, dtype: torch.int64))] -= torch.tile(anchors[.., 3] * 0.5, new long[] { 2, 1 }).T; | ||
|
||
return anchors.MoveToOuterDisposeScope(); | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Duplicate and distribute anchors to different positions give border of positions and stride between positions. | ||
/// </summary> | ||
/// <param name="shape">Border to distribute anchors.</param> | ||
/// <param name="stride">Stride between adjacent anchors.</param> | ||
/// <param name="anchors">Anchors to distribute.</param> | ||
/// <returns>The shifted anchors.</returns> | ||
private static Tensor Shift(Tensor shape, int stride, Tensor anchors) | ||
{ | ||
using (var anchorsScope = torch.NewDisposeScope()) | ||
{ | ||
Tensor shiftX = (torch.arange(start: 0, stop: (int)shape[1]) + 0.5) * stride; | ||
Tensor shiftY = (torch.arange(start: 0, stop: (int)shape[0]) + 0.5) * stride; | ||
|
||
var shiftXExpand = torch.repeat_interleave(shiftX.reshape(new long[] { shiftX.shape[0], 1 }), shiftY.shape[0], dim: 1); | ||
shiftXExpand = shiftXExpand.transpose(0, 1).reshape(-1); | ||
var shiftYExpand = torch.repeat_interleave(shiftY, shiftX.shape[0]); | ||
|
||
List<Tensor> tensors = new List<Tensor> { shiftXExpand, shiftYExpand, shiftXExpand, shiftYExpand }; | ||
var shifts = torch.vstack(tensors).transpose(0, 1); | ||
|
||
var a = anchors.shape[0]; | ||
var k = shifts.shape[0]; | ||
var allAnchors = anchors.reshape(new long[] { 1, a, 4 }) + shifts.reshape(new long[] { 1, k, 4 }).transpose(0, 1); | ||
allAnchors = allAnchors.reshape(new long[] { k * a, 4 }); | ||
|
||
return allAnchors.MoveToOuterDisposeScope(); | ||
} | ||
} | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using TorchSharp; | ||
using TorchSharp.Modules; | ||
using static TorchSharp.torch; | ||
using static TorchSharp.torch.nn; | ||
|
||
namespace Microsoft.ML.TorchSharp.AutoFormerV2 | ||
{ | ||
/// <summary> | ||
/// The Attention layer. | ||
/// </summary> | ||
public class Attention : Module<Tensor, Tensor, Tensor> | ||
{ | ||
#pragma warning disable MSML_PrivateFieldName // Need to match TorchSharp model names. | ||
private readonly int numHeads; | ||
private readonly double scale; | ||
private readonly int keyChannels; | ||
private readonly int nHkD; | ||
private readonly int d; | ||
private readonly int dh; | ||
private readonly double attnRatio; | ||
|
||
private readonly LayerNorm norm; | ||
private readonly Linear qkv; | ||
private readonly Linear proj; | ||
private readonly Parameter attention_biases; | ||
private readonly TensorIndex attention_bias_idxs; | ||
private readonly Softmax softmax; | ||
#pragma warning restore MSML_PrivateFieldName | ||
|
||
|
||
/// <summary> | ||
/// Initializes a new instance of the <see cref="Attention"/> class. | ||
/// </summary> | ||
/// <param name="inChannels">The input channels.</param> | ||
/// <param name="keyChannels">The key channels.</param> | ||
/// <param name="numHeads">The number of blocks.</param> | ||
/// <param name="attnRatio">The ratio of attention.</param> | ||
/// <param name="windowResolution">The resolution of window.</param> | ||
public Attention(int inChannels, int keyChannels, int numHeads = 8, int attnRatio = 4, List<int> windowResolution = null) | ||
: base(nameof(Attention)) | ||
{ | ||
windowResolution ??= new List<int>() { 14, 14 }; | ||
this.numHeads = numHeads; | ||
this.scale = System.Math.Pow(keyChannels, -0.5); | ||
this.keyChannels = keyChannels; | ||
this.nHkD = numHeads * keyChannels; | ||
this.d = attnRatio * keyChannels; | ||
this.dh = this.d * numHeads; | ||
this.attnRatio = attnRatio; | ||
int h = this.dh + (this.nHkD * 2); | ||
|
||
this.norm = nn.LayerNorm(new long[] { inChannels }); | ||
this.qkv = nn.Linear(inChannels, h); | ||
this.proj = nn.Linear(this.dh, inChannels); | ||
|
||
var points = new List<List<int>>(); | ||
for (int i = 0; i < windowResolution[0]; i++) | ||
{ | ||
for (int j = 0; j < windowResolution[1]; j++) | ||
{ | ||
points.Add(new List<int>() { i, j }); | ||
} | ||
} | ||
|
||
int n = points.Count; | ||
var attentionOffsets = new Dictionary<Tuple<int, int>, int>(); | ||
var idxs = new List<int>(); | ||
var idxsTensor = torch.zeros(new long[] { n, n }, dtype: torch.int64); | ||
for (int i = 0; i < n; i++) | ||
{ | ||
for (int j = 0; j < n; j++) | ||
{ | ||
var offset = new Tuple<int, int>(Math.Abs(points[i][0] - points[j][0]), Math.Abs(points[i][1] - points[j][1])); | ||
if (!attentionOffsets.ContainsKey(offset)) | ||
{ | ||
attentionOffsets.Add(offset, attentionOffsets.Count); | ||
} | ||
|
||
idxs.Add(attentionOffsets[offset]); | ||
idxsTensor[i][j] = attentionOffsets[offset]; | ||
} | ||
} | ||
|
||
this.attention_biases = nn.Parameter(torch.zeros(numHeads, attentionOffsets.Count)); | ||
this.attention_bias_idxs = TensorIndex.Tensor(idxsTensor); | ||
this.softmax = nn.Softmax(dim: -1); | ||
} | ||
|
||
/// <inheritdoc/> | ||
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")] | ||
public override Tensor forward(Tensor x, Tensor mask) | ||
{ | ||
using (var scope = torch.NewDisposeScope()) | ||
{ | ||
long b = x.shape[0]; | ||
long n = x.shape[1]; | ||
long c = x.shape[2]; | ||
x = this.norm.forward(x); | ||
var qkv = this.qkv.forward(x); | ||
qkv = qkv.view(b, n, this.numHeads, -1); | ||
var tmp = qkv.split(new long[] { this.keyChannels, this.keyChannels, this.d }, dim: 3); | ||
var q = tmp[0]; | ||
var k = tmp[1]; | ||
var v = tmp[2]; | ||
q = q.permute(0, 2, 1, 3); | ||
k = k.permute(0, 2, 1, 3); | ||
v = v.permute(0, 2, 1, 3); | ||
|
||
var attn = (torch.matmul(q, k.transpose(-2, -1)) * this.scale) + this.attention_biases[.., this.attention_bias_idxs]; | ||
if (!(mask is null)) | ||
{ | ||
long nW = mask.shape[0]; | ||
attn = attn.view(-1, nW, this.numHeads, n, n) + mask.unsqueeze(1).unsqueeze(0); | ||
attn = attn.view(-1, this.numHeads, n, n); | ||
attn = this.softmax.forward(attn); | ||
} | ||
else | ||
{ | ||
attn = this.softmax.forward(attn); | ||
} | ||
|
||
x = torch.matmul(attn, v).transpose(1, 2).reshape(b, n, this.dh); | ||
x = this.proj.forward(x); | ||
|
||
return x.MoveToOuterDisposeScope(); | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.