|
25 | 25 | import org.junit.jupiter.api.BeforeEach;
|
26 | 26 | import org.junit.jupiter.api.Test;
|
27 | 27 | import org.junit.jupiter.params.ParameterizedTest;
|
28 |
| -import org.junit.jupiter.params.provider.CsvSource; |
29 | 28 | import org.junit.jupiter.params.provider.ValueSource;
|
30 | 29 | import org.neo4j.gds.Orientation;
|
31 | 30 | import org.neo4j.gds.api.Graph;
|
@@ -318,33 +317,32 @@ void testConvergence() {
|
318 | 317 | assertThat(trainMetrics.ranIterationsPerEpoch()).containsExactly(2);
|
319 | 318 | }
|
320 | 319 |
|
321 |
| - @ParameterizedTest |
322 |
| - @CsvSource({ |
323 |
| - "0.01, false, 10", |
324 |
| - "1.0, true, 7" |
325 |
| - }) |
326 |
| - void batchesPerIteration(double batchSamplingRatio, boolean expectedConvergence, int expectedRanEpochs) { |
327 |
| - var trainer = new GraphSageModelTrainer( |
328 |
| - configBuilder.modelName("convergingModel:)") |
329 |
| - .maybeBatchSamplingRatio(batchSamplingRatio) |
330 |
| - .embeddingDimension(12) |
331 |
| - .aggregator(AggregatorType.POOL) |
332 |
| - .epochs(10) |
333 |
| - .tolerance(1e-10) |
334 |
| - .sampleSizes(List.of(5, 3)) |
335 |
| - .batchSize(5) |
336 |
| - .maxIterations(100) |
337 |
| - .randomSeed(42L) |
338 |
| - .build(), |
| 320 | + @Test |
| 321 | + void batchesPerIteration() { |
| 322 | + configBuilder.modelName("convergingModel:)") |
| 323 | + .embeddingDimension(2) |
| 324 | + .aggregator(AggregatorType.POOL) |
| 325 | + .epochs(10) |
| 326 | + .tolerance(1e-5) |
| 327 | + .sampleSizes(List.of(1)) |
| 328 | + .batchSize(5) |
| 329 | + .maxIterations(100) |
| 330 | + .randomSeed(42L); |
| 331 | + |
| 332 | + var trainResultWithoutSampling = new GraphSageModelTrainer( |
| 333 | + configBuilder.maybeBatchSamplingRatio(1.0).build(), |
339 | 334 | Pools.DEFAULT,
|
340 | 335 | ProgressTracker.NULL_TRACKER
|
341 |
| - ); |
| 336 | + ).train(unweightedGraph, features); |
342 | 337 |
|
343 |
| - var trainResult = trainer.train(unweightedGraph, features); |
| 338 | + var trainResultWithSampling = new GraphSageModelTrainer( |
| 339 | + configBuilder.maybeBatchSamplingRatio(0.01).build(), |
| 340 | + Pools.DEFAULT, |
| 341 | + ProgressTracker.NULL_TRACKER |
| 342 | + ).train(unweightedGraph, features); |
344 | 343 |
|
345 |
| - var trainMetrics = trainResult.metrics(); |
346 |
| - assertThat(trainMetrics.didConverge()).isEqualTo(expectedConvergence); |
347 |
| - assertThat(trainMetrics.ranEpochs()).isEqualTo(expectedRanEpochs); |
| 344 | + // reason: sampling results in more stochastic gradient descent and different losses |
| 345 | + assertThat(trainResultWithoutSampling.metrics().epochLosses().get(0)).isNotEqualTo(trainResultWithSampling.metrics().epochLosses().get(0)); |
348 | 346 | }
|
349 | 347 |
|
350 | 348 | @ParameterizedTest
|
|
0 commit comments