We fine-tune the Segment Anything Model (SAM) for the human part segmentation task.
This work has two main goals:
- Evaluating SAM, the first foundation model for segmentation, on downstream tasks.
- Using the fine-tuned model to let collaborative robots recognize human parts and operate safely (future work).
sam_FineTune.py implements fine-tuning of the SAM mask decoder.
sam_forward.py implements a batched SAM forward pass using torch and the segment_anything.modeling.sam.Sam class.
sam_forward_SamPredictor.py implements an unbatched SAM forward pass using the segment_anything.SamPredictor class.
visualize.py visualizes mask labels, SAM mask predictions, and fine-tuned SAM mask predictions.
A single random pixel is sampled from the annotation label and passed as a point prompt to the prompt encoder in utils.sam_forward.SamForward.
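The pixel sampling described above can be sketched roughly as follows. This is a minimal illustration, not the code in utils.sam_forward.SamForward: the function name `sample_point_prompt` is hypothetical, and the (x, y) coordinate order and the label value 1 for foreground are assumed to follow the convention of `SamPredictor.predict`.

```python
import numpy as np


def sample_point_prompt(mask, rng=None):
    """Pick one random foreground pixel from a binary mask.

    Returns (point_coords, point_labels) with shapes (1, 2) and (1,).
    Coordinates are in (x, y) order; label 1 marks a foreground point.
    (Hypothetical helper, assumed convention.)
    """
    rng = np.random.default_rng() if rng is None else rng
    ys, xs = np.nonzero(mask)  # row and column indices of foreground pixels
    if ys.size == 0:
        raise ValueError("mask has no foreground pixels to sample from")
    i = rng.integers(ys.size)  # index of the chosen pixel
    point_coords = np.array([[xs[i], ys[i]]], dtype=np.float32)
    point_labels = np.array([1], dtype=np.int64)
    return point_coords, point_labels


# Toy annotation: a 2x3 foreground rectangle inside a 6x8 image.
mask = np.zeros((6, 8), dtype=bool)
mask[2:4, 3:6] = True
coords, labels = sample_point_prompt(mask, np.random.default_rng(0))
```

Passing a seeded generator makes the sampled prompt reproducible, which can help when comparing the base and fine-tuned models on the same points.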
- Install the Python packages.
Refer to requirements.txt, or if you're using a conda environment:
conda create -n sam python=3.10
conda activate sam
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install opencv-python pycocotools matplotlib onnxruntime onnx numpy scipy
Then install PyTorch.
- Move to your working directory and clone this repo.
git clone https://github.com/hyeonbeenlee/segment-anything-fine-tuning.git
cd segment-anything-fine-tuning
- Download the PASCAL VOC 2010 train/val/test datasets.
curl -O http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar
curl -O http://host.robots.ox.ac.uk:8080/eval/downloads/VOC2010test.tar
mkdir -p data/trainval
mkdir -p data/test
tar xvf VOCtrainval_03-May-2010.tar -C data/trainval
tar xvf VOC2010test.tar -C data/test
- Download the PASCAL-Part annotations.
curl -O http://roozbehm.info/pascal-parts/trainval.tar.gz
mkdir -p data/annotations
tar xvzf trainval.tar.gz -C data/annotations
- Run the data processing script, dataprocess.py.
python dataprocess.py
- Download the pretrained ViT-H SAM base model.
curl -O https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
mkdir -p model
mv sam_vit_h_4b8939.pth model
Now you're good to go!
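To confirm the setup steps finished, a small check like the one below may help. It is only a sketch: it looks for the directories and the checkpoint file that the commands above create, relative to the repo root, and it does not inspect whatever dataprocess.py writes. The name `missing_paths` is hypothetical.

```python
from pathlib import Path

# Paths created by the setup commands above, relative to the repo root.
EXPECTED = [
    "data/trainval",
    "data/test",
    "data/annotations",
    "model/sam_vit_h_4b8939.pth",
]


def missing_paths(root="."):
    """Return the expected paths that do not exist under root."""
    root = Path(root)
    return [p for p in EXPECTED if not (root / p).exists()]


if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All expected paths are present.")
```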
- Batch sizes greater than 1 cause an error (due to multi-prompt inputs); see https://github.com/facebookresearch/segment-anything/issues/277. Temporarily using gradient accumulation over unit-sized batches.
- Multiprocessing image loading was not working properly. A custom dataset was not implemented, so the entire training data was loaded into system memory at once (~300 GB). Fixed on Jul 12th, 2023.
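The gradient accumulation workaround mentioned above can be illustrated with a toy model. This is a sketch under assumptions, not the training loop in sam_FineTune.py: `nn.Linear` stands in for the mask decoder, the data is random, and `accumulate_gradients` is a hypothetical helper. It shows that summing per-sample gradients, each scaled by 1/N, gives the gradient of the mean loss over a batch of N.

```python
import torch
from torch import nn


def accumulate_gradients(model, loss_fn, inputs, targets):
    """Backpropagate one sample at a time, summing scaled gradients in .grad."""
    n = len(inputs)
    model.zero_grad()
    for x, y in zip(inputs, targets):
        # Batch of size 1: add the leading batch dimension back.
        loss = loss_fn(model(x.unsqueeze(0)), y.unsqueeze(0))
        # Scale so the accumulated sum equals the mean over n samples.
        (loss / n).backward()


torch.manual_seed(0)
model = nn.Linear(4, 1)  # stand-in for the trainable module (assumption)
loss_fn = nn.MSELoss()
inputs, targets = torch.randn(4, 4), torch.randn(4, 1)

accumulate_gradients(model, loss_fn, inputs, targets)
accumulated = model.weight.grad.clone()

# Reference: gradient of the same loss computed on the full batch at once.
model.zero_grad()
loss_fn(model(inputs), targets).backward()
full_batch = model.weight.grad.clone()

# The two gradients should agree up to floating-point error.
print(torch.allclose(accumulated, full_batch, atol=1e-6))
```

In a real loop, a single `optimizer.step()` followed by `optimizer.zero_grad()` would come after every N accumulated samples.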
- Layerwise LR decay of 0.8
- Drop-path with rate of 0.4
- Decreasing the LR by a factor of 10 at iterations 60000, 86666...
- The code trains only the first mask output, so the other two outputs of the multi-mask prediction go unused.
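The list above does not say how the learning-rate settings are wired in, so the following is only a sketch of the arithmetic behind the layer-wise decay and step schedule items. Both function names are hypothetical, the base LR of 1e-3 is a placeholder, and the decay formula (last layer at full rate, each earlier layer multiplied by the decay once more) is one common formulation assumed here. Only the two milestones that the list names are used.

```python
def layerwise_lr_scales(num_layers, decay=0.8):
    """Per-layer LR multipliers, ordered from the first layer to the last.

    The last layer keeps the full rate (1.0); each earlier layer is
    scaled by `decay` once more. (Assumed formulation.)
    """
    return [decay ** (num_layers - 1 - i) for i in range(num_layers)]


def step_lr(base_lr, iteration, milestones=(60000, 86666), factor=10.0):
    """LR after dividing by `factor` at every milestone already reached."""
    passed = sum(iteration >= m for m in milestones)
    return base_lr / factor ** passed


base_lr = 1e-3  # hypothetical base learning rate
scales = layerwise_lr_scales(num_layers=4)
for it in (0, 60000, 86666):
    lrs = [step_lr(base_lr, it) * s for s in scales]
    print(it, ["%.2e" % lr for lr in lrs])
```

With PyTorch, the per-layer multipliers would typically be applied through optimizer parameter groups, one group per layer with its own `lr`.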
Built on https://github.com/facebookresearch/segment-anything with minimal changes.
Thanks to @zuck