Skip to content

Commit 795ae2e

Browse files
committed
Make GraphSage test more robust
Before it failed on MacOS Discovered by Iavkan in #199
1 parent 632893b commit 795ae2e

File tree

1 file changed

+22
-24
lines changed

1 file changed

+22
-24
lines changed

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.junit.jupiter.api.BeforeEach;
2626
import org.junit.jupiter.api.Test;
2727
import org.junit.jupiter.params.ParameterizedTest;
28-
import org.junit.jupiter.params.provider.CsvSource;
2928
import org.junit.jupiter.params.provider.ValueSource;
3029
import org.neo4j.gds.Orientation;
3130
import org.neo4j.gds.api.Graph;
@@ -318,33 +317,32 @@ void testConvergence() {
318317
assertThat(trainMetrics.ranIterationsPerEpoch()).containsExactly(2);
319318
}
320319

321-
@ParameterizedTest
322-
@CsvSource({
323-
"0.01, false, 10",
324-
"1.0, true, 7"
325-
})
326-
void batchesPerIteration(double batchSamplingRatio, boolean expectedConvergence, int expectedRanEpochs) {
327-
var trainer = new GraphSageModelTrainer(
328-
configBuilder.modelName("convergingModel:)")
329-
.maybeBatchSamplingRatio(batchSamplingRatio)
330-
.embeddingDimension(12)
331-
.aggregator(AggregatorType.POOL)
332-
.epochs(10)
333-
.tolerance(1e-10)
334-
.sampleSizes(List.of(5, 3))
335-
.batchSize(5)
336-
.maxIterations(100)
337-
.randomSeed(42L)
338-
.build(),
320+
@Test
321+
void batchesPerIteration() {
322+
configBuilder.modelName("convergingModel:)")
323+
.embeddingDimension(2)
324+
.aggregator(AggregatorType.POOL)
325+
.epochs(10)
326+
.tolerance(1e-5)
327+
.sampleSizes(List.of(1))
328+
.batchSize(5)
329+
.maxIterations(100)
330+
.randomSeed(42L);
331+
332+
var trainResultWithoutSampling = new GraphSageModelTrainer(
333+
configBuilder.maybeBatchSamplingRatio(1.0).build(),
339334
Pools.DEFAULT,
340335
ProgressTracker.NULL_TRACKER
341-
);
336+
).train(unweightedGraph, features);
342337

343-
var trainResult = trainer.train(unweightedGraph, features);
338+
var trainResultWithSampling = new GraphSageModelTrainer(
339+
configBuilder.maybeBatchSamplingRatio(0.01).build(),
340+
Pools.DEFAULT,
341+
ProgressTracker.NULL_TRACKER
342+
).train(unweightedGraph, features);
344343

345-
var trainMetrics = trainResult.metrics();
346-
assertThat(trainMetrics.didConverge()).isEqualTo(expectedConvergence);
347-
assertThat(trainMetrics.ranEpochs()).isEqualTo(expectedRanEpochs);
344+
// reason: sampling results in more stochastic gradient descent and different losses
345+
assertThat(trainResultWithoutSampling.metrics().epochLosses().get(0)).isNotEqualTo(trainResultWithSampling.metrics().epochLosses().get(0));
348346
}
349347

350348
@ParameterizedTest

0 commit comments

Comments
 (0)