Skip to content

Commit 3faaff3

Browse files
author
Rustam Zaitov
committed
[DigitDetection] port ViewController
1 parent 369ae20 commit 3faaff3

File tree

6 files changed

+189
-34
lines changed

6 files changed

+189
-34
lines changed

ios10/DigitDetection/DigitDetection/DrawView.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
namespace DigitDetection
1010
{
1111
// 2 points can give a line and this class is just for that purpose, it keeps a record of a line
12-
class Line
12+
public class Line
1313
{
1414
public CGPoint Start { get; }
1515
public CGPoint End { get; }
@@ -49,7 +49,7 @@ UIColor Color {
4949
}
5050

5151
// we will keep touches made by user in view in these as a record so we can draw them
52-
readonly List<Line> lines = new List<Line> ();
52+
public List<Line> Lines { get; } = new List<Line> ();
5353
CGPoint lastPoint;
5454

5555
public DrawView (IntPtr handle)
@@ -66,7 +66,7 @@ public override void TouchesMoved (NSSet touches, UIEvent evt)
6666
{
6767
var newPoint = ((UITouch)touches.First ()).LocationInView (this);
6868
// keep all lines drawn by user as touch in record so we can draw them in view
69-
lines.Add (new Line (lastPoint, newPoint));
69+
Lines.Add (new Line (lastPoint, newPoint));
7070

7171
lastPoint = newPoint;
7272

@@ -82,7 +82,7 @@ public override void Draw (CGRect rect)
8282

8383
drawPath.LineCapStyle = CGLineCap.Round;
8484

85-
foreach (var line in lines) {
85+
foreach (var line in Lines) {
8686
drawPath.MoveTo (line.Start);
8787
drawPath.AddLineTo (line.End);
8888
}
@@ -93,7 +93,7 @@ public override void Draw (CGRect rect)
9393
Color.SetStroke ();
9494
}
9595

96-
public CGContext GetViewContext ()
96+
public CGBitmapContext GetViewContext ()
9797
{
9898
// our network takes in only grayscale images as input
9999
var colorSpace = CGColorSpace.CreateDeviceGray ();

ios10/DigitDetection/DigitDetection/MNISTData.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ namespace DigitDetection
55
{
66
public class MNISTData : IDisposable
77
{
8-
byte [] labels;
9-
byte [] images;
8+
public byte [] Labels { get; private set; }
9+
public byte [] images { get; private set; } // TODO: rename
1010

1111
nuint sizeBias;
1212
nuint sizeWeights;

ios10/DigitDetection/DigitDetection/NeuralNetworkLayers/MnistDeepConvNeuralNetwork.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ public MnistDeepConvNeuralNetwork (IMTLCommandQueue commandQueueIn)
9292
/// <param name="inputImage">Image coming in on which the network will run</param>
9393
/// <param name="imageNum">If the test set is being used we will get a value between 0 and 9999 for which of the 10,000 images is being evaluated</param>
9494
/// <param name="correctLabel">The correct label for the inputImage while testing</param>
95-
public override uint Forward (MPSImage inputImage = null, int imageNum = 9999, uint correctLabel = 10)
95+
public override uint Forward (MPSImage inputImage = null, int imageNum = 9999, int correctLabel = 10)
9696
{
9797
uint label = 99;
9898

ios10/DigitDetection/DigitDetection/NeuralNetworkLayers/MnistFullLayerNeuralNetwork.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ public class MnistFullLayerNeuralNetwork
1515
// TODO: convert protected fields to props
1616

1717
// MPSImageDescriptors for different layers outputs to be put in
18-
protected readonly MPSImageDescriptor sid = MPSImageDescriptor.GetImageDescriptor (MPSImageFeatureChannelFormat.Unorm8, 28, 28, 1);
18+
public readonly MPSImageDescriptor sid = MPSImageDescriptor.GetImageDescriptor (MPSImageFeatureChannelFormat.Unorm8, 28, 28, 1);
1919
protected readonly MPSImageDescriptor did = MPSImageDescriptor.GetImageDescriptor (MPSImageFeatureChannelFormat.Float16, 1, 1, 10);
2020

2121
// MPSImages and layers declared
22-
protected MPSImage srcImage;
22+
public MPSImage srcImage { get; protected set; } // TODO: rename
2323
protected MPSImage dstImage;
2424
MPSCnnFullyConnected layer;
2525
protected MPSCnnSoftMax softmax;
@@ -54,7 +54,7 @@ public MnistFullLayerNeuralNetwork (IMTLCommandQueue commandQueueIn)
5454
/// <param name="inputImage">Image coming in on which the network will run</param>
5555
/// <param name="imageNum">If the test set is being used we will get a value between 0 and 9999 for which of the 10,000 images is being evaluated</param>
5656
/// <param name="correctLabel">The correct label for the inputImage while testing</param>
57-
public virtual uint Forward (MPSImage inputImage = null, int imageNum = 9999, uint correctLabel = 10)
57+
public virtual uint Forward (MPSImage inputImage = null, int imageNum = 9999, int correctLabel = 10)
5858
{
5959
uint label = 99;
6060

Lines changed: 162 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,169 @@
1-
using Foundation;
2-
using System;
3-
using UIKit;
4-
5-
namespace DigitDetection
1+
using System;
2+
using System.Runtime.InteropServices;
3+
4+
using UIKit;
5+
using Foundation;
6+
using Metal;
7+
using MetalPerformanceShaders;
8+
9+
namespace DigitDetection
610
{
711
public partial class ViewController : UIViewController
812
{
9-
public ViewController (IntPtr handle)
13+
// some properties used to control the app and store appropriate values
14+
// we will start with the simple 1 layer
15+
bool deep;
16+
17+
IMTLCommandQueue commandQueue;
18+
IMTLDevice device;
19+
20+
// Networks we have
21+
MnistFullLayerNeuralNetwork neuralNetwork;
22+
MnistDeepConvNeuralNetwork neuralNetworkDeep;
23+
MnistFullLayerNeuralNetwork runningNet;
24+
25+
// loading MNIST Test Set here
26+
readonly MNISTData Mnistdata = new MNISTData ();
27+
28+
// MNIST dataset image parameters
29+
nuint mnistInputWidth = 28;
30+
int mnistInputHeight = 28;
31+
int mnistInputNumPixels = 784;
32+
33+
public ViewController (IntPtr handle)
1034
: base (handle)
1135
{
1236
}
13-
}
37+
38+
public override void ViewDidLoad ()
39+
{
40+
base.ViewDidLoad ();
41+
42+
// Load default device.
43+
device = MTLDevice.SystemDefault;
44+
45+
// Make sure the current device supports MetalPerformanceShaders.
46+
if (!MPSKernel.Supports (device)) {
47+
Console.WriteLine ("Metal Performance Shaders not Supported on current Device");
48+
return;
49+
}
50+
51+
// Create new command queue.
52+
commandQueue = device.CreateCommandQueue ();
53+
54+
// initialize the networks we shall use to detect digits
55+
neuralNetwork = new MnistFullLayerNeuralNetwork (commandQueue);
56+
neuralNetworkDeep = new MnistDeepConvNeuralNetwork (commandQueue);
57+
58+
runningNet = neuralNetwork;
59+
}
60+
61+
partial void TappedDeepButton (UIButton sender)
62+
{
63+
// switch network to be used between the deep and the single layered
64+
if (deep) {
65+
sender.SetTitle ("Use Deep Net", UIControlState.Normal);
66+
runningNet = neuralNetwork;
67+
} else {
68+
sender.SetTitle ("Use Single Layer", UIControlState.Normal);
69+
runningNet = neuralNetworkDeep;
70+
}
71+
72+
deep = !deep;
73+
}
74+
75+
partial void TappedClear (UIButton sender)
76+
{
77+
// clear the digitview
78+
DigitView.Lines.Clear ();
79+
DigitView.SetNeedsDisplay ();
80+
PredictionLabel.Hidden = true;
81+
}
82+
83+
partial void TappedTestSet (UIButton sender)
84+
{
85+
// placeholder to count number of correct detections on the test set
86+
var correctDetections = 0;
87+
var total = 10000f;
88+
AccuracyLabel.Hidden = false;
89+
90+
Atomics.Reset ();
91+
92+
// validate NeuralNetwork was initialized properly
93+
if (runningNet == null)
94+
throw new InvalidProgramException ();
95+
96+
for (int i = 0; i < total; i++) {
97+
Inference (i, Mnistdata.Labels.Length);
98+
99+
if (i % 100 == 0) {
100+
AccuracyLabel.Text = $"{i / 100}% Done";
101+
// this command helps update the UI in the loop regularly
102+
NSRunLoop.Current.RunUntil (NSRunLoopMode.Default, NSDate.DistantPast);
103+
}
104+
}
105+
// display accuracy of the network on the MNIST test set
106+
correctDetections = Atomics.GetCount ();
107+
108+
AccuracyLabel.Hidden = false;
109+
AccuracyLabel.Text = $"Accuracy = {(correctDetections * 100) / total}%";
110+
}
111+
112+
partial void TappedDetectDigit (UIButton sender)
113+
{
114+
// get the digitView context so we can get the pixel values from it to input to the network
115+
var context = DigitView.GetViewContext ();
116+
117+
// validate NeuralNetwork was initialized properly
118+
if (runningNet == null)
119+
throw new InvalidProgramException ();
120+
121+
// putting input into MTLTexture in the MPSImage
122+
var region = new MTLRegion (new MTLOrigin (0, 0, 0), new MTLSize ((nint)mnistInputWidth, mnistInputHeight, 1));
123+
runningNet.srcImage.Texture.ReplaceRegion (region,
124+
level: 0,
125+
slice: 0,
126+
pixelBytes: context.Data,
127+
bytesPerRow: mnistInputWidth,
128+
bytesPerImage: 0);
129+
// run the network forward pass
130+
var label = runningNet.Forward ();
131+
132+
// show the prediction
133+
PredictionLabel.Text = $"{label}";
134+
PredictionLabel.Hidden = false;
135+
}
136+
137+
/// <summary>
138+
/// This function runs the inference network on the test set
139+
/// </summary>
140+
/// <param name="imageNum">If the test set is being used we will get a value between 0 and 9999 for which of the 10,000 images is being evaluated</param>
141+
/// <param name="correctLabel">The correct label for the inputImage while testing</param>
142+
void Inference (int imageNum, int correctLabel)
143+
{
144+
// get the correct image pixels from the test set
145+
146+
int startIndex = imageNum * mnistInputNumPixels;
147+
var mnist_input_image = new byte [784];
148+
Array.Copy (Mnistdata.images, startIndex, mnist_input_image, 0, mnist_input_image.Length);
149+
var imageHandle = GCHandle.Alloc (mnist_input_image, GCHandleType.Pinned);
150+
var imagePtr = imageHandle.AddrOfPinnedObject ();
151+
152+
153+
// create a source image for the network to forward
154+
var inputImage = new MPSImage (device, runningNet.sid);
155+
156+
// put image in source texture (input layer)
157+
inputImage.Texture.ReplaceRegion (region: new MTLRegion (new MTLOrigin (0, 0, 0), new MTLSize ((nint)mnistInputWidth, mnistInputHeight, 1)),
158+
level: 0,
159+
slice: 0,
160+
pixelBytes: imagePtr,
161+
bytesPerRow: mnistInputWidth,
162+
bytesPerImage: 0);
163+
imageHandle.Free (); // TODO: request overload with managed array
164+
165+
// run the network forward pass
166+
runningNet.Forward (inputImage, imageNum, correctLabel);
167+
}
168+
}
14169
}

ios10/DigitDetection/DigitDetection/ViewController.designer.cs

Lines changed: 16 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)