Skip to content

Display changes and CPU compatibility #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 121 additions & 29 deletions datasets/cifar.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,34 @@
import os.path as osp
import pickle
import numpy as np
import os

import torch
import pandas as pd
from torch.utils.data import Dataset
from datasets import transform as T
from torchvision import transforms

from datasets.randaugment import RandomAugment
from datasets.sampler import RandomSampler, BatchSampler

class_to_idx = {
"chinee apple": 0,
"lantana": 1,
"parkinsonia": 2,
"parthenium": 3,
"prickly acacia": 4,
"rubber vine": 5,
"siam weed": 6,
"snake weed": 7,
"negative": 8,
}

def cleanup(files):
    """Strip the macOS Finder metadata entry from a directory listing.

    Mutates *files* in place when ".DS_Store" is present and returns the
    same list object, so it composes with ``sorted(os.listdir(...))``.
    """
    try:
        files.remove(".DS_Store")
    except ValueError:
        # Listing had no Finder metadata entry — nothing to do.
        pass
    return files

def load_data_train(L=250, dataset='CIFAR10', dspth='./data'):
if dataset == 'CIFAR10':
@@ -126,7 +146,52 @@ def compute_mean_var():
print('mean: ', mean)
print('var: ', var)

class CustomDataset(Dataset):
    """Image-folder dataset whose integer labels come from a CSV file.

    The CSV must contain a "Species" column; each species name is lowercased
    and mapped to an integer through the module-level ``class_to_idx`` table.
    NOTE(review): ``Image`` (PIL) does not appear in this file's visible
    import block — confirm it is imported elsewhere in the module.
    """

    def __init__(self, im_folder, label_file):
        self.im_folder = im_folder
        self.labels_df = pd.read_csv(label_file)
        # Sort for a deterministic file order; drop Finder metadata entries.
        self.total_imgs = cleanup(sorted(os.listdir(self.im_folder)))
        species_names = self.labels_df["Species"].to_list()
        self.labels = [class_to_idx[name.lower()] for name in species_names]

    def __len__(self):
        # One sample per image file found in the folder.
        return len(self.total_imgs)

    def __getitem__(self, idx):
        img_path = os.path.join(self.im_folder, self.total_imgs[idx])
        rgb_image = Image.open(img_path).convert("RGB")
        return rgb_image, torch.as_tensor(self.labels[idx])

class DeepWeeds(CustomDataset):
    """DeepWeeds dataset with FixMatch-style weak/strong augmentation.

    In training mode ``__getitem__`` yields ``(weak_aug, strong_aug, label)``
    — the triple the training loop unpacks from the labeled and unlabeled
    loaders; in eval mode it yields ``(transformed_image, label)``.

    Args:
        im_folder: directory of image files (passed to CustomDataset).
        label_file: CSV with a "Species" column (passed to CustomDataset).
        is_train: selects the augmenting train pipelines vs. the plain
            eval pipeline.
    """

    def __init__(self, im_folder, label_file, is_train=True):
        super(DeepWeeds, self).__init__(im_folder, label_file)
        self.is_train = is_train
        # ImageNet normalization statistics.
        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

        if is_train:
            self.trans_weak = T.Compose([
                T.Resize((224, 224)),
                T.PadandRandomCrop(border=4, cropsize=(224, 224)),
                T.RandomHorizontalFlip(p=0.5),
                T.Normalize(mean, std),
                T.ToTensor(),
            ])
            # Strong pipeline = weak pipeline + RandAugment(2 ops, magnitude 10).
            self.trans_strong = T.Compose([
                T.Resize((224, 224)),
                T.PadandRandomCrop(border=4, cropsize=(224, 224)),
                T.RandomHorizontalFlip(p=0.5),
                RandomAugment(2, 10),
                T.Normalize(mean, std),
                T.ToTensor(),
            ])
        else:
            self.trans = T.Compose([
                T.Resize((224, 224)),
                T.Normalize(mean, std),
                T.ToTensor(),
            ])

    def __getitem__(self, idx):
        # BUG FIX: the pipelines above were built but never applied — the
        # inherited __getitem__ returns a raw (PIL image, label) pair, while
        # the FixMatch training loop unpacks (weak, strong, label) triples
        # from both train loaders. Apply the transforms here.
        image, label = super(DeepWeeds, self).__getitem__(idx)
        # Project T.* transforms appear to operate on ndarrays (as the Cifar
        # dataset feeds them numpy data) — TODO confirm against datasets/transform.py.
        im = np.array(image)
        if self.is_train:
            return self.trans_weak(im), self.trans_strong(im), label
        return self.trans(im), label

class Cifar(Dataset):
def __init__(self, dataset, data, labels, is_train=True):
@@ -175,36 +240,63 @@ def __len__(self):


def get_train_loader(dataset, batch_size, mu, n_iters_per_epoch, L, root='data',
                     deepweeds_im_folder="/home/ubuntu/Home/data/blueriver/DeepWeeds/images",
                     deepweeds_label_file="/home/ubuntu/Home/data/blueriver/DeepWeeds/labels/labels.csv"):
    """Build the labeled (dl_x) and unlabeled (dl_u) training DataLoaders.

    For CIFAR datasets the labeled/unlabeled split comes from
    ``load_data_train``; otherwise the DeepWeeds dataset is split randomly.

    Args:
        dataset: dataset name; anything starting with "CIFAR" uses the CIFAR
            path, everything else uses DeepWeeds.
        batch_size: labeled batch size (unlabeled batches are mu * batch_size).
        mu: ratio of unlabeled to labeled samples per iteration.
        n_iters_per_epoch: iterations per epoch; samplers draw with
            replacement so every epoch sees exactly this many batches.
        L: number of labeled samples (CIFAR path only).
        root: CIFAR data directory.
        deepweeds_im_folder / deepweeds_label_file: DeepWeeds locations
            (defaults preserve the previously hard-coded paths).

    Returns:
        (dl_x, dl_u): labeled and unlabeled DataLoaders. In both cases each
        labeled batch has ``batch_size`` samples and each unlabeled batch has
        ``batch_size * mu`` samples.
    """
    # BUG FIX: the previous version left a duplicate of the pre-refactor CIFAR
    # body after this if/else, which unconditionally rebuilt the CIFAR loaders
    # (clobbering the branch results) and raised NameError on the non-CIFAR
    # path where data_x/label_x were never bound.
    if dataset.startswith("CIFAR"):
        data_x, label_x, data_u, label_u = load_data_train(L=L, dataset=dataset, dspth=root)
        ds_x = Cifar(
            dataset=dataset,
            data=data_x,
            labels=label_x,
            is_train=True
        )
        ds_u = Cifar(
            dataset=dataset,
            data=data_u,
            labels=label_u,
            is_train=True
        )
    else:
        ds = DeepWeeds(deepweeds_im_folder, deepweeds_label_file, is_train=True)
        size = len(ds)
        # NOTE(review): 5% of the data goes to the *unlabeled* split and 95%
        # to the labeled one; for semi-supervised training that looks
        # inverted — confirm intent.
        ds_u, ds_x = torch.utils.data.random_split(ds, [round(size * 0.05), round(size * 0.95)])

    # Labeled loader: each epoch draws n_iters_per_epoch * batch_size samples
    # with replacement, yielded as batches of batch_size.
    sampler_x = RandomSampler(ds_x, replacement=True, num_samples=n_iters_per_epoch * batch_size)
    batch_sampler_x = BatchSampler(sampler_x, batch_size, drop_last=True)
    dl_x = torch.utils.data.DataLoader(
        ds_x,
        batch_sampler=batch_sampler_x,
        num_workers=2,
        pin_memory=True
    )

    # Unlabeled loader: mu times as many samples, yielded as batches of
    # batch_size * mu so labeled and unlabeled iterators stay in lockstep.
    sampler_u = RandomSampler(ds_u, replacement=True, num_samples=mu * n_iters_per_epoch * batch_size)
    batch_sampler_u = BatchSampler(sampler_u, batch_size * mu, drop_last=True)
    dl_u = torch.utils.data.DataLoader(
        ds_u,
        batch_sampler=batch_sampler_u,
        num_workers=2,
        pin_memory=True
    )

    return dl_x, dl_u


Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
24 changes: 13 additions & 11 deletions train.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
import sys

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
@@ -21,7 +22,8 @@
from utils import accuracy, setup_default_logging, interleave, de_interleave

from utils import AverageMeter

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device: {}".format(device))

# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

@@ -31,9 +33,9 @@ def set_model(args):
k=args.wresnet_k, n=args.wresnet_n) # wresnet-28-2

model.train()
model.cuda()
criteria_x = nn.CrossEntropyLoss().cuda()
criteria_u = nn.CrossEntropyLoss(reduction='none').cuda()
model.to(device)
criteria_x = nn.CrossEntropyLoss().to(device)
criteria_u = nn.CrossEntropyLoss(reduction='none').to(device)
return model, criteria_x, criteria_u


@@ -65,18 +67,18 @@ def train_one_epoch(epoch,

epoch_start = time.time() # start time
dl_x, dl_u = iter(dltrain_x), iter(dltrain_u)
for it in range(n_iters):
for it in tqdm(range(n_iters), desc='Epoch {}'.format(epoch)):
ims_x_weak, ims_x_strong, lbs_x = next(dl_x)
ims_u_weak, ims_u_strong, lbs_u_real = next(dl_u)

lbs_x = lbs_x.cuda()
lbs_u_real = lbs_u_real.cuda()
lbs_x = lbs_x.to(device)
lbs_u_real = lbs_u_real.to(device)

# --------------------------------------

bt = ims_x_weak.size(0)
mu = int(ims_u_weak.size(0) // bt)
imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong], dim=0).cuda()
imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong], dim=0).to(device)
imgs = interleave(imgs, 2 * mu + 1)
logits = model(imgs)
logits = de_interleave(logits, 2 * mu + 1)
@@ -145,7 +147,7 @@ def evaluate(ema, dataloader, criterion):
# using EMA params to evaluate performance
ema.apply_shadow()
ema.model.eval()
ema.model.cuda()
ema.model.to(device)

loss_meter = AverageMeter()
top1_meter = AverageMeter()
@@ -154,8 +156,8 @@ def evaluate(ema, dataloader, criterion):
# matches = []
with torch.no_grad():
for ims, lbs in dataloader:
ims = ims.cuda()
lbs = lbs.cuda()
ims = ims.to(device)
lbs = lbs.to(device)
logits = ema.model(ims)
loss = criterion(logits, lbs)
scores = torch.softmax(logits, dim=1)