@@ -269,6 +269,222 @@ def get_transformer_config():
269
269
}
270
270
271
271
272
def get_CNN_RNN_config() -> dict:
    """Configuration template for training a basic CNN-RNN hybrid model.

    Returns:
        A newly built nested dict with four top-level sections, in this
        order: ``"model"``, ``"training"``, ``"optimization"`` and ``"data"``.
        The dict literal is evaluated on every call, so the result is not
        shared between callers.
    """
    return {
        "model": {
            # CNN parameters
            "conv_layers": [32, 64, 128],  # Number of filters in each conv layer
            "kernel_size": 3,
            "pool_size": 2,
            "cnn_activation": "relu",

            # RNN parameters
            "rnn_type": "LSTM",  # LSTM or GRU
            "hidden_size": 128,
            "num_rnn_layers": 2,
            "bidirectional": True,

            # General architecture
            "input_channels": 1,
            "dropout_rate": 0.5,
            "final_dense_layers": [256, 128],
            "output_size": 10,  # Number of classes
        },

        "training": {
            "num_epochs": 100,
            "batch_size": 32,
            "shuffle": True,
            "validation_split": 0.2,
            "early_stopping_patience": 10,
            "save_best_only": True,
            "monitor_metric": "val_accuracy",
        },

        "optimization": {
            "optimizer": "adam",
            "learning_rate": 0.001,
            "weight_decay": 1e-4,
            "lr_scheduler": "reduce_on_plateau",
            "lr_patience": 5,
            "lr_factor": 0.1,
            "min_lr": 1e-6,
            "clip_grad_norm": 1.0,
        },

        "data": {
            "dataset": "MNIST",
            "data_dir": "./data",
            "num_workers": 4,
            "pin_memory": True,
            "normalize": True,
            # Augmentation settings; units are not established here
            # (presumably degrees for rotation and fractions for the
            # zoom/shift values -- confirm against the data pipeline).
            "augmentation": {
                "random_rotation": 10,
                "random_zoom": 0.1,
                "width_shift": 0.1,
                "height_shift": 0.1,
            },
        },
    }
330
+
331
+
332
def get_decision_tree_config() -> dict:
    """Configuration template for a decision-tree model.

    Returns:
        A newly built nested dict with four top-level sections, in this
        order: ``"model"``, ``"training"``, ``"preprocessing"`` and
        ``"logging"``. ``input_dim`` and ``output_dim`` are left as ``None``
        and are meant to be filled in by the caller. The dict literal is
        evaluated on every call, so the result is not shared between callers.
    """
    return {
        "model": {
            "type": "decision_tree",
            "input_dim": None,  # Set dynamically
            "output_dim": None,  # Set dynamically
            "max_depth": 5,
            "min_samples_split": 2,
            "min_samples_leaf": 1,
            # NOTE(review): if this value is forwarded to scikit-learn,
            # recent releases no longer accept "auto" for tree
            # ``max_features`` -- confirm against the consumer.
            "max_features": "auto",
            "criterion": "gini",
            "task": "classification",
        },
        "training": {
            "batch_size": 32,
            "epochs": 10,
            "validation_split": 0.2,
            "early_stopping": True,
            "patience": 3,
        },
        "preprocessing": {
            "scaling": None,  # Trees don't require scaling
            "handle_missing": "median",
            "handle_categorical": "label_encoding",
        },
        "logging": {
            "tensorboard": True,
            "log_interval": 10,
            "metrics": ["accuracy", "precision", "recall", "f1"],
            "feature_importance": True,
        },
    }
364
+
365
+
366
def get_naive_bayes_config() -> dict:
    """Configuration template for a naive Bayes model.

    Returns:
        A newly built nested dict with four top-level sections, in this
        order: ``"model"``, ``"training"``, ``"preprocessing"`` and
        ``"logging"``. ``input_dim`` and ``output_dim`` are left as ``None``
        and are meant to be filled in from the data. The dict literal is
        evaluated on every call, so the result is not shared between callers.
    """
    return {
        "model": {
            "type": "naive_bayes",
            "input_dim": None,  # Set dynamically based on data
            "output_dim": None,  # Set dynamically based on data
            "var_smoothing": 1e-9,  # Smoothing parameter
            "priors": None,  # Class priors, None for automatic
            "task": "classification",
        },
        "training": {
            "batch_size": 32,
            "epochs": 10,
            "learning_rate": 0.001,
            "early_stopping": True,
            "patience": 5,
            "validation_split": 0.2,
        },
        "preprocessing": {
            "scaling": "standard",
            "handle_missing": "mean",
            "feature_selection": None,
        },
        "logging": {
            "tensorboard": True,
            "log_interval": 10,
            "metrics": ["accuracy", "precision", "recall", "f1"],
        },
    }
395
+
396
+
397
def get_Thermodynamic_Diffusion_config() -> dict:
    """Configuration template for training thermodynamic diffusion models.

    Returns:
        A newly built nested dict with four top-level sections, in this
        order: ``"model"``, ``"training"``, ``"optimization"`` and ``"data"``.
        The dict literal is evaluated on every call, so the result is not
        shared between callers.
    """
    return {
        "model": {
            # Model architecture
            "image_size": 28,
            "channels": 1,
            "time_embedding_dim": 256,
            "model_channels": 64,  # Base channel multiplier
            "channel_multipliers": [1, 2, 4, 8],  # For different U-Net levels
            "num_res_blocks": 2,  # Number of residual blocks per level
            "attention_levels": [2, 3],  # Which levels to apply attention
            "dropout_rate": 0.1,
            "num_heads": 4,  # Number of attention heads

            # Diffusion process
            "num_timesteps": 1000,
            "beta_schedule": "linear",  # Options: linear, cosine, quadratic
            "beta_start": 0.0001,
            "beta_end": 0.02,

            # Ising model parameters
            "temperature_schedule": "linear",  # How temperature changes during diffusion
            "initial_temperature": 0.1,
            "final_temperature": 2.0,
            "coupling_constant": 1.0,  # J in Ising model
        },

        "training": {
            # Training process
            "num_epochs": 500,
            "save_interval": 5000,
            "eval_interval": 1000,
            "log_interval": 100,
            "sample_interval": 1000,
            "num_samples": 64,  # Number of samples to generate during evaluation

            # Loss configuration
            "loss_type": "l2",  # Options: l1, l2, huber
            "loss_weight_type": "simple",  # Options: simple, snr, truncated

            # Sampling configuration
            "sampling_steps": 250,  # Steps for fast sampling (< num_timesteps)
            "clip_samples": True,
            "clip_range": [-1, 1],
        },

        "optimization": {
            # Optimizer configuration
            "optimizer": "AdamW",  # Options: Adam, AdamW, RMSprop
            "learning_rate": 2e-4,
            "weight_decay": 1e-4,
            "eps": 1e-8,
            "betas": (0.9, 0.999),  # Kept as a tuple, unlike the list-valued ranges

            # Learning rate scheduling
            "lr_schedule": "cosine",  # Options: cosine, linear, constant
            "warmup_steps": 5000,
            "min_lr": 1e-6,

            # Gradient configuration
            "grad_clip": 1.0,
            "ema_decay": 0.9999,  # Exponential moving average of model weights
            "update_ema_interval": 1,
        },

        "data": {
            # Data configuration
            "dataset": "MNIST",  # Dataset name
            "data_dir": "./data",
            "train_batch_size": 128,
            "eval_batch_size": 256,

            # Data processing
            "num_workers": 4,
            "pin_memory": True,
            "persistence": True,

            # Augmentation
            "random_flip": False,
            "random_rotation": False,
            "normalize": True,
            "rescale": [-1, 1],  # Range to rescale images to

            # Caching
            "cache_size": 5000,  # Number of batches to cache in memory
            "prefetch_factor": 2,
        },
    }
486
+
487
+
272
488
def get_cnn_config ():
273
489
"""Configuration for CNN models."""
274
490
return {
0 commit comments