This repository aims to provide curated, high-quality benchmarks for Graph Neural Networks, implemented with PyTorch Lightning and Hydra. Only datasets large enough to yield reliable performance estimates are included.
Built with [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template).

Included benchmarks:
- Open Graph Benchmark (graph property prediction)
- Image classification from superpixels (MNIST, FashionMNIST, CIFAR10)
Install dependencies

```bash
# clone project
git clone https://github.com/ashleve/graph_classification
cd graph_classification

# create conda environment
bash setup_conda.sh
conda activate env_name
```

Train model with default configuration
```bash
# default
python run.py

# train on CPU
python run.py trainer.gpus=0

# train on GPU
python run.py trainer.gpus=1
```

Train model with a chosen experiment configuration from configs/experiment/
```bash
python run.py experiment=GAT/gat_ogbg_molpcba
python run.py experiment=GraphSAGE/graphsage_mnist_sp75
python run.py experiment=GraphSAGE/graphsage_cifar10_sp100
```
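Each value accepted by `experiment=` corresponds to a YAML file under `configs/experiment/` (standard Hydra config-group layout), so the available experiments can be listed directly:

```bash
# each file here, taken as a relative path without the .yaml suffix,
# is a valid value for experiment=...
find configs/experiment -name "*.yaml"
```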
You can override any parameter from the command line like this

```bash
python run.py trainer.max_epochs=20 datamodule.batch_size=64
```
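Because configuration is composed by Hydra, overrides stack on top of a chosen experiment config, and (assuming `run.py` is a standard `@hydra.main` entrypoint, as in lightning-hydra-template) Hydra's multirun flag `-m` can sweep over comma-separated values. A sketch with illustrative values; the available keys depend on the configs in this repository:

```bash
# combine an experiment config with additional overrides
python run.py experiment=GraphSAGE/graphsage_mnist_sp75 trainer.max_epochs=50

# sweep over several batch sizes with Hydra's multirun mode
python run.py -m experiment=GraphSAGE/graphsage_mnist_sp75 datamodule.batch_size=32,64,128
```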
Coming soon...

| Architecture | MNIST-sp75 | FashionMNIST-sp75 | CIFAR10-sp100 | ogbg-molhiv | ogbg-molpcba |
|---|---|---|---|---|---|
| GCN | 0.955 ± 0.014 | 0.835 ± 0.016 | 0.518 ± 0.007 | 0.755 ± 0.019 | 0.231 ± 0.003 |
| GIN | 0.966 ± 0.008 | 0.861 ± 0.012 | 0.512 ± 0.020 | 0.757 ± 0.025 | 0.240 ± 0.001 |
| GAT | 0.976 ± 0.008 | 0.889 ± 0.003 | 0.617 ± 0.005 | 0.751 ± 0.026 | 0.234 ± 0.003 |
| GraphSAGE | 0.981 ± 0.005 | 0.897 ± 0.012 | 0.629 ± 0.012 | 0.761 ± 0.025 | 0.256 ± 0.004 |