PyTorch CPU MNIST Digit Recognition Project

This project demonstrates various machine learning approaches for recognizing handwritten digits from the MNIST dataset. It includes implementations using PyTorch (CNN), K-Nearest Neighbors (KNN), and XGBoost algorithms.

[MNIST example image]

Table of Contents

  • Project Overview
  • Features
  • Algorithms Implemented
  • Requirements
  • Setup with Python uv
  • Alternative Setup with pip
  • Usage
  • Expected Results
  • Development Notes
  • License
  • Acknowledgments

Project Overview

The MNIST dataset consists of 70,000 grayscale images of handwritten digits (0-9), each 28x28 pixels. This project showcases different machine learning techniques to classify these digits with high accuracy.

Features

  • Multiple algorithm implementations for comparison
  • Visualization of prediction results
  • Comprehensive performance evaluation metrics
  • Modular code structure for easy experimentation
  • Support for both CPU and GPU training (PyTorch)
  • Automatic dataset downloading

Algorithms Implemented

1. Convolutional Neural Network (PyTorch)

A deep learning approach using convolutional layers for feature extraction and classification.
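The actual network is defined in mnist_under_pytorch_train.py. Purely as an illustration, a small CNN for 28x28 MNIST inputs can be written as follows; the class name, layer sizes, and structure here are assumptions, not the repository's exact architecture:

import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    # Illustrative CNN for 28x28 grayscale MNIST digits; not the repository's exact model.
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),   # 1x28x28 -> 32x28x28
            nn.ReLU(),
            nn.MaxPool2d(2),                              # -> 32x14x14
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # -> 64x14x14
            nn.ReLU(),
            nn.MaxPool2d(2),                              # -> 64x7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

# Quick shape check on a dummy batch of 8 images.
print(SimpleCNN()(torch.zeros(8, 1, 28, 28)).shape)  # torch.Size([8, 10])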

2. K-Nearest Neighbors (KNN)

A traditional machine learning approach using similarity-based classification.
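As a rough sketch of this approach with scikit-learn; the built-in 8x8 digits dataset stands in for MNIST here, and k=3 is an assumption, not necessarily what mnist_under_knn.py uses:

# Illustrative only: scikit-learn's built-in 8x8 digits stand in for MNIST;
# mnist_under_knn.py works on the real 28x28 MNIST data.
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_digits(return_X_y=True)  # flattened pixel vectors and labels
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

knn = KNeighborsClassifier(n_neighbors=3)  # k is an assumption
knn.fit(X_train, y_train)
print("accuracy:", accuracy_score(y_test, knn.predict(X_test)))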

3. XGBoost

A gradient boosting approach using decision trees for classification.
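A minimal sketch with the xgboost scikit-learn API; the stand-in dataset and hyperparameters below are assumptions, not the settings used in mnist_under_xgboost.py:

# Illustrative only: small stand-in dataset and assumed hyperparameters;
# mnist_under_xgboost.py trains on the full MNIST data with its own settings.
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

clf = XGBClassifier(n_estimators=200, max_depth=6, learning_rate=0.3)
clf.fit(X_train, y_train)
print("accuracy:", clf.score(X_test, y_test))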

Requirements

  • Python 3.12+
  • uv package manager (recommended) or pip
  • At least 4GB RAM (8GB+ recommended)
  • Internet connection (for first-time dataset download)

Setup with Python uv

This project uses uv for fast dependency management. First, install uv if you haven't already:

pip install uv

Clone the repository and navigate to the project directory:

git clone <repository-url>
cd python-uv-pytorch-cpu

Install all dependencies using uv:

uv sync

Activate the virtual environment:

uv run python --version  # uv run executes commands inside the project environment, no manual activation needed

Alternatively, spawn a shell with the virtual environment active:

uv run bash  # On Windows, use: uv run cmd

Alternative Setup with pip

If you prefer to use pip instead of uv:

# Create a virtual environment
python -m venv venv

# Activate it (Windows)
venv\Scripts\activate
# Activate it (macOS/Linux)
source venv/bin/activate

# Install dependencies
pip install torch torchvision scikit-learn xgboost matplotlib seaborn notebook

Usage

PyTorch CNN Implementation

  1. Train the model:

    python mnist_under_pytorch_train.py

    This will download the MNIST dataset (if not already present) and train the CNN model. The trained model will be saved to the output/ directory.

  2. Test the model:

    python mnist_under_pytorch_test.py

    This will load the trained model and evaluate its performance on the test set, showing visualizations of predictions.
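As a rough sketch of what this evaluation step involves; the checkpoint filename below is hypothetical (use whatever mnist_under_pytorch_train.py actually saves), and SimpleCNN refers to the illustrative model sketched under Algorithms Implemented:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Test split of MNIST, downloaded to dataset/mnist if not already present.
test_loader = DataLoader(
    datasets.MNIST("dataset/mnist", train=False, download=True,
                   transform=transforms.ToTensor()),
    batch_size=256,
)

# Hypothetical checkpoint path; SimpleCNN is the illustrative model sketched above.
model = SimpleCNN()
model.load_state_dict(torch.load("output/mnist_cnn.pth", map_location="cpu"))
model.eval()

correct = total = 0
with torch.no_grad():
    for images, labels in test_loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"test accuracy: {correct / total:.4f}")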

KNN Implementation

Run the KNN classifier:

python mnist_under_knn.py

This will train a KNN model on a subset of the MNIST data and display accuracy metrics along with prediction visualizations.

XGBoost Implementation

Run the XGBoost classifier:

python mnist_under_xgboost.py

This will train an XGBoost model and show detailed performance metrics including a classification report and confusion matrix.
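Reports of this kind typically come from scikit-learn's metrics module. A minimal sketch, continuing from the illustrative XGBoost snippet above (clf, X_test, and y_test are defined there); the script's own reporting code may differ:

# Continues the illustrative XGBoost sketch from "Algorithms Implemented".
from sklearn.metrics import classification_report, confusion_matrix

y_pred = clf.predict(X_test)
print(classification_report(y_test, y_pred, digits=4))
print(confusion_matrix(y_test, y_pred))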

Expected Results

Typical accuracy results for each algorithm:

Algorithm     | Expected Accuracy | Training Time
PyTorch CNN   | ~99%              | 5-15 minutes (GPU) / 30-60 minutes (CPU)
KNN           | ~97%              | < 1 minute
XGBoost       | ~98%              | 1-3 minutes

These results may vary slightly depending on random seeds, hardware, and specific implementation details.

Development Notes

  • All scripts will automatically download the MNIST dataset to the dataset/mnist directory on first run
  • GPU acceleration is supported for PyTorch (automatically detected); both points are sketched in the snippet after this list
  • Visualization functions display sample predictions to help evaluate model performance visually
  • Model checkpoints are saved in the output/ directory
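A minimal sketch of the first two notes above, assuming the dataset/mnist path mentioned there and standard torchvision/PyTorch calls:

import torch
from torchvision import datasets, transforms

# Downloads MNIST into dataset/mnist on first run, matching the note above.
train_set = datasets.MNIST("dataset/mnist", train=True, download=True,
                           transform=transforms.ToTensor())

# Use a GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{len(train_set)} training images, running on {device}")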

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments

  • MNIST dataset from http://yann.lecun.com/exdb/mnist/
  • PyTorch for the deep learning framework
  • Scikit-learn for traditional ML algorithms
  • XGBoost for gradient boosting implementation
