Skip to content

Commit 7dcdab8

Browse files
authored
Merge pull request #117 from cbovar/Dropout
Dropout fix in Flow
2 parents 077c9bb + 9baba59 commit 7dcdab8

File tree

6 files changed

+49
-15
lines changed

6 files changed

+49
-15
lines changed

src/ConvNetSharp.Flow/Net.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ public T Backward(Volume<T> y)
3535
throw new NotImplementedException();
3636
}
3737

38+
/// <summary>
39+
/// Creates a dictionary containing input and evaluates the output Op
40+
/// </summary>
41+
/// <param name="input">input</param>
42+
/// <param name="isTraining">isTraining has no use here.</param>
43+
/// <returns></returns>
3844
public Volume<T> Forward(Volume<T> input, bool isTraining = false)
3945
{
4046
this._dico["input"] = input;

src/ConvNetSharp.Flow/Ops/DropoutGradient.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public override Volume<T> Evaluate(Session<T> session)
3333

3434
var dropoutOutput = this._dropout.Evaluate(session);
3535
var dropoutInput = this._dropout.Parents[0].Evaluate(session);
36-
var dropoutInputGradient = this._dropout.Derivate.Evaluate(session);
36+
var dropoutOutputGradient = this._dropout.Derivate.Evaluate(session);
3737
var droupoutProb = this._dropout.DropoutProbability.Evaluate(session);
3838

3939
if (this.Result == null || !Equals(this._lastInputShape, dropoutInput.Shape))
@@ -44,7 +44,7 @@ public override Volume<T> Evaluate(Session<T> session)
4444
this.Result = BuilderInstance<T>.Volume.SameAs(dropoutInput.Shape);
4545
}
4646

47-
dropoutOutput.DoDropoutGradient(dropoutInput, this.Result, dropoutInputGradient, droupoutProb);
47+
dropoutOutput.DoDropoutGradient(dropoutInput, dropoutOutputGradient, this.Result, droupoutProb);
4848

4949
return base.Evaluate(session);
5050
}

src/ConvNetSharp.Flow/Session.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,12 @@ public void UpdatePlaceHolder(Op<T> fun, Dictionary<string, Volume<T>> dictionar
126126
{
127127
if (op is PlaceHolder<T> placeHolder)
128128
{
129-
placeHolder.SetValue(dictionary[placeHolder.Name]);
129+
if (!dictionary.TryGetValue(placeHolder.Name, out var volume))
130+
{
131+
throw new Exception($"Cannot find key '{placeHolder.Name}' in the provided dictionary");
132+
}
133+
134+
placeHolder.SetValue(volume);
130135
}
131136

132137
if (op is Variable<T> variable && variable.IsLearnable)

src/ConvNetSharp.Volume.Tests/VolumeTests.cs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,19 +1184,42 @@ public void ToArray()
11841184
}
11851185
}
11861186

1187+
/// <summary>
1188+
/// Dropout should let the volume go thru if drop probability is 0
1189+
/// </summary>
1190+
[TestMethod]
1191+
public void DropoutWith0Dropprob()
1192+
{
1193+
var volume = NewVolume(RandomUtilities.RandomDoubleArray(100), new Shape(100));
1194+
var result = NewVolume(new double[100], new Shape(100));
1195+
var dropprob = (T)Convert.ChangeType(0.0, typeof(T));
1196+
1197+
// Forward
1198+
volume.DoDropout(result, dropprob);
1199+
Assert.IsTrue(volume.ToArray().SequenceEqual(result.ToArray()));
1200+
1201+
// Backward
1202+
var inputGradient = BuilderInstance<T>.Volume.SameAs(volume.Storage, volume.Shape);
1203+
var outputActivationGradient = NewVolume(new double[100].Populate(1.0), new Shape(100));
1204+
volume.DoDropoutGradient(volume, outputActivationGradient, inputGradient, dropprob);
1205+
1206+
Assert.IsTrue(inputGradient.ToArray().SequenceEqual(outputActivationGradient.ToArray()));
1207+
}
1208+
11871209
[TestMethod]
11881210
public void Dropout()
11891211
{
11901212
var volume = NewVolume(new double[100].Populate(1.0), new Shape(100));
11911213
var result = NewVolume(new double[100], new Shape(100));
11921214

1193-
var dropProb = 0.5;
1215+
var dropProb = 0.0;
11941216
volume.DoDropout(result, (T)Convert.ChangeType(dropProb, typeof(T)));
11951217

11961218
var array = result.Storage.ToArray();
11971219
var c = array.Count(o => o.Equals(Ops<T>.Zero));
1198-
Assert.IsTrue(c > 0);
1220+
Assert.IsTrue(dropProb > 0 ? c > 0 : c >= 0);
11991221

1222+
// Check magnitude scale up
12001223
var nonZeroEntry = array.First(o => !o.Equals(Ops<T>.Zero));
12011224
AssertNumber.AreEqual(1.0 / (1 - dropProb), nonZeroEntry, 1e-6);
12021225

src/ConvNetSharp.Volume/Double/Volume.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,13 +266,13 @@ public override void DoDivide(Volume<double> other, Volume<double> result)
266266

267267
public override void DoDropout(Volume<double> result, double dropProbability)
268268
{
269-
if (dropProbability > 0.0)
269+
if (((NcwhVolumeStorage<double>)this.Storage).Dropped == null || ((NcwhVolumeStorage<double>)this.Storage).Dropped.Length != this.Shape.TotalLength)
270270
{
271-
if (((NcwhVolumeStorage<double>)this.Storage).Dropped == null || ((NcwhVolumeStorage<double>)this.Storage).Dropped.Length != this.Shape.TotalLength)
272-
{
273-
((NcwhVolumeStorage<double>)this.Storage).Dropped = new bool[this.Shape.TotalLength];
274-
}
271+
((NcwhVolumeStorage<double>)this.Storage).Dropped = new bool[this.Shape.TotalLength];
272+
}
275273

274+
if (dropProbability > 0.0)
275+
{
276276
// do dropout
277277
this.Storage.Map((x, i) =>
278278
{

src/ConvNetSharp.Volume/Single/Volume.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -267,13 +267,13 @@ public override void DoDivide(Volume<float> other, Volume<float> result)
267267

268268
public override void DoDropout(Volume<float> result, float dropProbability)
269269
{
270-
if (dropProbability > 0.0f)
270+
if (((NcwhVolumeStorage<float>)this.Storage).Dropped == null || ((NcwhVolumeStorage<float>)this.Storage).Dropped.Length != this.Shape.TotalLength)
271271
{
272-
if (((NcwhVolumeStorage<float>)this.Storage).Dropped == null || ((NcwhVolumeStorage<float>)this.Storage).Dropped.Length != this.Shape.TotalLength)
273-
{
274-
((NcwhVolumeStorage<float>)this.Storage).Dropped = new bool[this.Shape.TotalLength];
275-
}
272+
((NcwhVolumeStorage<float>)this.Storage).Dropped = new bool[this.Shape.TotalLength];
273+
}
276274

275+
if (dropProbability > 0.0f)
276+
{
277277
// do dropout
278278
this.Storage.Map((x, i) =>
279279
{

0 commit comments

Comments
 (0)