This project implements an advanced Convolutional Neural Network (CNN) in PyTorch to classify images from the CIFAR-10 dataset into 10 categories: airplane
, automobile
, bird
, cat
, deer
, dog
, frog
, horse
, ship
, and truck
.
- Pretrained Dataset: Utilizes CIFAR-10, a 60,000-image dataset split into 50,000 training and 10,000 test images.
- Customizable CNN Model: Includes multiple convolutional, batch normalization, pooling, and fully connected layers for accurate classification.
- GPU Support: Automatically utilizes CUDA if available for faster computation.
- Data Augmentation: Enhances training with random cropping, flipping, and normalization.
- Training Scheduler: Reduces learning rate dynamically for fine-tuning.
- Python 3.x
- Required packages:
torch
,torchvision
,tqdm
,Pillow
Install dependencies with:
pip install torch torchvision tqdm Pillow
git clone https://github.com/your-username/advanced-cnn-cifar10.git
cd advanced-cnn-cifar10
Run the script to train the model:
python advanced_cnn_cifar10.py
- Adjust
num_epochs
,batch_size
, orlearning_rate
in the script for your requirements. - Trained models are saved in the
checkpoints/
directory.
from predict import predict_image
model_path = 'checkpoints/advanced_cnn.pth'
image_path = 'path/to/your/image.jpg'
prediction = predict_image(image_path, model_path)
print(f'Predicted class: {prediction}')
from predict import predict_batch
batch_predictions = predict_batch('path/to/dataset', model_path)
print(batch_predictions)
- Achieves 85-90% accuracy on the CIFAR-10 test set after 30 epochs.
- Checkpoint and accuracy logs are generated during training.
- Modify the architecture or hyperparameters to experiment with different configurations.
- GPU is highly recommended for faster training.
For more details, check the code files: advanced_cnn_cifar10.py
(training) and predict.py
(inference).