Skip to content

Commit c512a16

Browse files
committed
Changed List to HashSet to ensure that there are no duplicates
1 parent fc7286c commit c512a16

File tree

5 files changed

+97
-4
lines changed

5 files changed

+97
-4
lines changed

Microsoft.ML.sln

+7
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "netstandard2.0", "netstanda
114114
pkg\Microsoft.ML\build\netstandard2.0\Microsoft.ML.targets = pkg\Microsoft.ML\build\netstandard2.0\Microsoft.ML.targets
115115
EndProjectSection
116116
EndProject
117+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Sweeper.Tests", "test\Microsoft.ML.Sweeper.Tests\Microsoft.ML.Sweeper.Tests.csproj", "{3DEB504D-7A07-48CE-91A2-8047461CB3D4}"
118+
EndProject
117119
Global
118120
GlobalSection(SolutionConfigurationPlatforms) = preSolution
119121
Debug|Any CPU = Debug|Any CPU
@@ -216,6 +218,10 @@ Global
216218
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug|Any CPU.Build.0 = Debug|Any CPU
217219
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.ActiveCfg = Release|Any CPU
218220
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.Build.0 = Release|Any CPU
221+
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
222+
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Debug|Any CPU.Build.0 = Debug|Any CPU
223+
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Release|Any CPU.ActiveCfg = Release|Any CPU
224+
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Release|Any CPU.Build.0 = Release|Any CPU
219225
EndGlobalSection
220226
GlobalSection(SolutionProperties) = preSolution
221227
HideSolutionNode = FALSE
@@ -253,6 +259,7 @@ Global
253259
{362A98CF-FBF7-4EBB-A11B-990BBF845B15} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
254260
{487213C9-E8A9-4F94-85D7-28A05DBBFE3A} = {DEC8F776-49F7-4D87-836C-FE4DC057D08C}
255261
{9252A8EB-ABFB-440C-AB4D-1D562753CE0F} = {487213C9-E8A9-4F94-85D7-28A05DBBFE3A}
262+
{3DEB504D-7A07-48CE-91A2-8047461CB3D4} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
256263
EndGlobalSection
257264
GlobalSection(ExtensibilityGlobals) = postSolution
258265
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}

src/Microsoft.ML.Core/Prediction/ISweeper.cs

+5
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ public override string ToString()
174174
{
175175
return string.Join(" ", _parameterValues.Select(kvp => string.Format("{0}={1}", kvp.Value.Name, kvp.Value.ValueText)).ToArray());
176176
}
177+
178+
public override int GetHashCode()
179+
{
180+
return _hash;
181+
}
177182
}
178183

179184
/// <summary>

src/Microsoft.ML.Sweeper/Algorithms/Grid.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ protected SweeperBase(ArgumentsBase args, IHostEnvironment env, IValueGenerator[
6464
SweepParameters = sweepParameters;
6565
}
6666

67-
public virtual ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns)
67+
public virtual ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns = null)
6868
{
6969
var prevParamSets = previousRuns?.Select(r => r.ParameterSet).ToList() ?? new List<ParameterSet>();
70-
var result = new List<ParameterSet>();
70+
var result = new HashSet<ParameterSet>();
7171
for (int i = 0; i < maxSweeps; i++)
7272
{
7373
ParameterSet paramSet;
@@ -150,12 +150,12 @@ public RandomGridSweeper(IHostEnvironment env, Arguments args, IValueGenerator[]
150150
}
151151
}
152152

153-
public override ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns)
153+
public override ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns = null)
154154
{
155155
if (_nGridPoints == 0)
156156
return base.ProposeSweeps(maxSweeps, previousRuns);
157157

158-
var result = new List<ParameterSet>();
158+
var result = new HashSet<ParameterSet>();
159159
var prevParamSets = (previousRuns != null)
160160
? previousRuns.Select(r => r.ParameterSet).ToList()
161161
: new List<ParameterSet>();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<TargetFramework>netcoreapp2.0</TargetFramework>
5+
<DefineConstants>CORECLR</DefineConstants>
6+
<IsPackable>false</IsPackable>
7+
</PropertyGroup>
8+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
9+
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
10+
<PlatformTarget>AnyCPU</PlatformTarget>
11+
</PropertyGroup>
12+
<ItemGroup>
13+
<ProjectReference Include="..\..\src\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj" />
14+
<ProjectReference Include="..\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj" />
15+
</ItemGroup>
16+
</Project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
using Microsoft.ML.Runtime;
2+
using Microsoft.ML.Runtime.CommandLine;
3+
using Microsoft.ML.Runtime.Data;
4+
using Microsoft.ML.Runtime.RunTests;
5+
using Microsoft.ML.Runtime.Sweeper;
6+
using System;
7+
using System.IO;
8+
using Xunit;
9+
10+
namespace Microsoft.ML.Sweeper.Tests
11+
{
12+
public class SweeperTest
13+
{
14+
[Fact]
15+
public void UniformRandomSweeperReturnsDistinctValuesWhenProposeSweep()
16+
{
17+
DiscreteValueGenerator valueGenerator = CreateDiscreteValueGenerator();
18+
19+
using (var writer = new StreamWriter(new MemoryStream()))
20+
using (var env = new TlcEnvironment(42, outWriter: writer, errWriter: writer))
21+
{
22+
var sweeper = new UniformRandomSweeper(env,
23+
new SweeperBase.ArgumentsBase(),
24+
new[] { valueGenerator });
25+
26+
var results = sweeper.ProposeSweeps(3);
27+
Assert.NotNull(results);
28+
29+
int length = results.Length;
30+
Assert.Equal(2, length);
31+
}
32+
}
33+
34+
[Fact]
35+
public void RandomGridSweeperReturnsDistinctValuesWhenProposeSweep()
36+
{
37+
DiscreteValueGenerator valueGenerator = CreateDiscreteValueGenerator();
38+
39+
using (var writer = new StreamWriter(new MemoryStream()))
40+
using (var env = new TlcEnvironment(42, outWriter: writer, errWriter: writer))
41+
{
42+
var sweeper = new RandomGridSweeper(env,
43+
new RandomGridSweeper.Arguments(),
44+
new[] { valueGenerator });
45+
46+
var results = sweeper.ProposeSweeps(3);
47+
Assert.NotNull(results);
48+
49+
int length = results.Length;
50+
Assert.Equal(2, length);
51+
}
52+
}
53+
54+
private static DiscreteValueGenerator CreateDiscreteValueGenerator()
55+
{
56+
var args = new DiscreteParamArguments()
57+
{
58+
Name = "TestParam",
59+
Values = new string[] { "one", "two" }
60+
};
61+
62+
return new DiscreteValueGenerator(args);
63+
}
64+
}
65+
}

0 commit comments

Comments
 (0)