This repository was archived by the owner on May 12, 2020. It is now read-only.

Commit ea588ea

Committed May 11, 2017
Batch normalization added and training works
1 parent f404dd0 commit ea588ea

File tree: README.md, args.py, data.py, gan.py, nets.py

5 files changed: +29 −27 lines
 

README.md

+4 −3
```diff
@@ -1,8 +1,5 @@
 # Keras-GAN-Animeface-Character
 
-WORK IN PROGRESS.
-DOESN'T WORK YET!!
-
 GAN example for Keras. Cuz MNIST is too small and there
 should an example on something more realistic.
 
@@ -21,6 +18,7 @@ should an example on something more realistic.
 * https://github.com/tdrussell/IllustrationGAN
 * I used slow implementation for the sake of simplicity. However, the correct way is:
 * https://ctmakro.github.io/site/on_learning/fast_gan_in_keras.html
+* https://github.com/shekkizh/neuralnetworks.thought-experiments/blob/master/Generative%20Models/GAN/Readme.md
 
 
 ## How to run this example
@@ -101,3 +99,6 @@ What I experienced during my training of GAN.
   If it stays there for too long, it isn't good, I think.
 * In case you're seeing high G loss, it could mean it can't keep up with discriminator.
   You might need to increase LR. (Must be slower than discriminator though)
+* One final piece of the training I was missing was the parameter in BatchNormalization.
+  I found about it in this link:
+  https://github.com/shekkizh/neuralnetworks.thought-experiments/blob/master/Generative%20Models/GAN/Readme.md
```
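The BatchNormalization parameter referred to above is the layer's momentum. Below is a minimal sketch of what it controls, assuming the Keras 2.x API; the reading of *why* a lower value helps GAN training is an interpretation, not something the commit itself states.

```python
# Sketch, assuming Keras 2.x. The momentum argument controls the update of the
# layer's running statistics (the values used at inference time):
#   moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean
#   moving_var  = momentum * moving_var  + (1 - momentum) * batch_var
# The Keras default is momentum=0.99; this commit uses 0.3 (see args.py below),
# so the running statistics adapt to the current batches much faster.
from keras.layers import BatchNormalization

bn = BatchNormalization(momentum=0.3)   # vs. BatchNormalization(), i.e. momentum=0.99
```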

args.py

+7 −1
```diff
@@ -9,7 +9,6 @@ class Args :
 
     # images size we will work on. (sz, sz, 3)
     sz = 64
-    ch = 1
 
     # alpha, used by leaky relu of D and G networks.
     alpha_D = 0.2
@@ -44,3 +43,10 @@ class Args :
     # Same as default in Keras, but good for GAN, says
     # https://github.com/gheinrich/DIGITS-GAN/blob/master/examples/weight-init/README.md#experiments-with-lenet-on-mnist
     kernel_initializer = 'glorot_uniform'
+
+    # Since DCGAN paper, everybody uses 0.5 and for me, it works the best too.
+    # I tried 0.9, 0.1.
+    adam_beta = 0.5
+
+    # BatchNormalization matters too.
+    bn_momentum = 0.3
```
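For reference, a rough post-commit snapshot of the hyperparameters touched above, written as the plain class-attribute namespace the rest of the code reads from. Only the fields visible in these hunks are shown; the real args.py contains more settings.

```python
# Rough snapshot of the fields visible in the hunks above; illustrative only.
class Args:
    sz = 64                                # images are (sz, sz, 3) after this commit
    alpha_D = 0.2                          # LeakyReLU slope for the discriminator
    kernel_initializer = 'glorot_uniform'
    adam_beta = 0.5                        # Adam beta_1, per the DCGAN paper
    bn_momentum = 0.3                      # BatchNormalization moving-average momentum
```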

data.py

+5 −4
```diff
@@ -15,9 +15,11 @@ def normalize4gan(im):
     Convert colorspace and
     cale the input in [-1, 1] range, as described in ganhacks
     '''
-    im = cv2.cvtColor(im, cv2.COLOR_RGB2YCR_CB).astype(np.float32)
+    #im = cv2.cvtColor(im, cv2.COLOR_RGB2YCR_CB).astype(np.float32)
+    # HSV... not helpful.
+    im = im.astype(np.float32)
     im /= 128.0
-    im -= 1 # now in [-1, 1]
+    im -= 1.0 # now in [-1, 1]
     return im
 
 
@@ -30,8 +32,7 @@ def denormalize4gan(im):
     '''
     im += 1.0 # in [0, 2]
     im *= 127.0 # in [0, 255]
-    #im = cv2.cvtColor(im.astype(np.uint8), cv2.COLOR_YCR_CB2RGB)
-    return im[:,:,0]
+    return im.astype(np.uint8)
 
 
 
```
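With the colorspace conversion commented out, the two helpers are now a plain round trip between uint8 RGB images and the [-1, 1] float range of the generator's tanh output. A self-contained sketch of the post-commit behaviour, reconstructed from the hunks above; the dummy-image check at the end is only an illustration.

```python
import numpy as np

def normalize4gan(im):
    # uint8 RGB in [0, 255] -> float32 in roughly [-1, 1], as described in ganhacks.
    im = im.astype(np.float32)
    im /= 128.0
    im -= 1.0
    return im

def denormalize4gan(im):
    # float in [-1, 1] -> uint8 for saving previews; note this mutates its input.
    im += 1.0      # now in [0, 2]
    im *= 127.0    # now in roughly [0, 255]
    return im.astype(np.uint8)

# Illustration: round-trip a dummy 64x64 RGB image.
img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
back = denormalize4gan(normalize4gan(img))
assert back.shape == img.shape and back.dtype == np.uint8
```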

gan.py

+8 −14
```diff
@@ -85,7 +85,7 @@ def dump_batch(imgs, cnt, ofname):
 
 
 def build_networks():
-    shape = (Args.sz, Args.sz, Args.ch)
+    shape = (Args.sz, Args.sz, 3)
 
     # Learning rate is important.
     # Optimizers are important too, try experimenting them yourself to fit your dataset.
@@ -115,8 +115,8 @@ def build_networks():
     # now same lr, as we are using history to train D multiple times.
     # I don't exactly understand how decay parameter in Adam works. Certainly not exponential.
     # Actually faster than exponential, when I look at the code and plot it in Excel.
-    dopt = Adam(lr=0.00005, beta_1=0.5)
-    opt = Adam(lr=0.00005, beta_1=0.5)
+    dopt = Adam(lr=0.0002, beta_1=Args.adam_beta)
+    opt = Adam(lr=0.0001, beta_1=Args.adam_beta)
 
     # too slow
     # Another thing about LR.
@@ -160,7 +160,7 @@ def train_autoenc( dataf ):
 
     opt = Adam(lr=0.001)
 
-    shape = (Args.sz, Args.sz, Args.ch)
+    shape = (Args.sz, Args.sz, 3)
     enc = build_enc( shape )
     enc.compile(optimizer=opt, loss='mse')
     enc.summary()
@@ -201,8 +201,8 @@ def load_weights(model, wf):
     try:
         model.load_weights(wf)
     except:
-        print("failed to load weight", wf)
-        raise
+        print("failed to load weight, network changed or corrupt hdf5", wf)
+        sys.exit(1)
 
 
 
@@ -218,12 +218,6 @@ def train_gan( dataf ) :
     logger.on_train_begin() # initialize csv file
     with h5py.File( dataf, 'r' ) as f :
         faces = f.get( 'faces' )
-
-        if Args.ch == 1:
-            faces = np.array(faces[:,:,:,0])
-            faces = np.expand_dims(faces, 3)
-            print("xxxxxxxxxxxxxx", faces.shape)
-
         run_batches(gen, disc, gan, faces, logger, range(50000))
     logger.on_train_end()
 
@@ -290,7 +284,7 @@ def end_of_batch_task(batch, gen, disc, reals, fakes):
         dump_batch(reals, 4, "reals.png")
         dump_batch(fakes, 4, "fakes.png") # to check how noisy the image is
         frame = gen.predict(_bits)
-        animf = os.path.join(Args.anim_dir, "frame_{:08d}.png".format(batch))
+        animf = os.path.join(Args.anim_dir, "frame_{:08d}.png".format(int(batch/10)))
         dump_batch(frame, 4, animf)
         dump_batch(frame, 4, "frame.png")
 
@@ -309,7 +303,7 @@ def end_of_batch_task(batch, gen, disc, reals, fakes):
 
 
 def generate( genw, cnt ):
-    shape = (Args.sz, Args.sz, Args.ch)
+    shape = (Args.sz, Args.sz, 3)
     gen = build_gen( shape )
     gen.compile(optimizer='sgd', loss='mse')
     load_weights(gen, Args.genw)
```
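The optimizer change follows the common DCGAN recipe: Adam with beta_1 = 0.5 and a somewhat higher learning rate for the discriminator than for the combined model. A minimal sketch of that wiring, assuming the Keras 2.x API; the compile calls are shown only as typical usage, not as the repo's exact code.

```python
# Sketch, assuming Keras 2.x. Values match the hunks above: Adam with
# beta_1 = Args.adam_beta (0.5), discriminator lr 0.0002, combined GAN lr 0.0001.
from keras.optimizers import Adam

dopt = Adam(lr=0.0002, beta_1=0.5)   # discriminator optimizer
gopt = Adam(lr=0.0001, beta_1=0.5)   # generator / combined GAN optimizer

# Typical wiring (disc and gan here are hypothetical Keras Models built elsewhere):
# disc.compile(optimizer=dopt, loss='binary_crossentropy')
# disc.trainable = False              # freeze D weights inside the combined model
# gan.compile(optimizer=gopt, loss='binary_crossentropy')
```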

nets.py

+5 −5
```diff
@@ -47,7 +47,7 @@ def conv2d( x, filters, shape=(4, 4), **kwargs ) :
         kernel_initializer=Args.kernel_initializer,
         **kwargs )( x )
     #x = MaxPooling2D()( x )
-    x = BatchNormalization()( x )
+    x = BatchNormalization(momentum=Args.bn_momentum)( x )
     x = LeakyReLU(alpha=Args.alpha_D)( x )
     return x
 
@@ -113,7 +113,7 @@ def deconv2d( x, filters, shape=(4, 4) ) :
     #x = bilinear2x( x, filters )
     #x = Conv2D( filters, shape, padding='same' )( x )
 
-    x = BatchNormalization()( x )
+    x = BatchNormalization(momentum=Args.bn_momentum)( x )
     x = LeakyReLU(alpha=Args.alpha_G)( x )
     return x
 
@@ -127,7 +127,7 @@ def deconv2d( x, filters, shape=(4, 4) ) :
 
     x= Conv2DTranspose( 512, (4, 4),
         kernel_initializer=Args.kernel_initializer )(x)
-    x = BatchNormalization()( x )
+    x = BatchNormalization(momentum=Args.bn_momentum)( x )
     x = LeakyReLU(alpha=Args.alpha_G)( x )
     # 4x4
     x = deconv2d( x, 256 )
@@ -140,11 +140,11 @@ def deconv2d( x, filters, shape=(4, 4) ) :
     # Extra layer
     x = Conv2D( 64, (3, 3), padding='same',
         kernel_initializer=Args.kernel_initializer )( x )
-    x = BatchNormalization()( x )
+    x = BatchNormalization(momentum=Args.bn_momentum)( x )
     x = LeakyReLU(alpha=Args.alpha_G)( x )
     # 32x32
 
-    x= Conv2DTranspose( Args.ch, (4, 4), padding='same', activation='tanh',
+    x= Conv2DTranspose( 3, (4, 4), padding='same', activation='tanh',
         strides=(2, 2), kernel_initializer=Args.kernel_initializer )(x)
     # 64x64
 
```
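Every BatchNormalization call in both networks now passes the same momentum, so the change amounts to threading Args.bn_momentum through the conv/deconv building blocks. A sketch of one such upsampling block, assuming the Keras 2.x functional API; the helper name and default values here are illustrative, not the repo's exact signature.

```python
from keras.layers import Conv2DTranspose, BatchNormalization, LeakyReLU

def deconv_block(x, filters, bn_momentum=0.3, alpha=0.2,
                 kernel_initializer='glorot_uniform'):
    # Upsample 2x, then normalize with an explicit momentum (the Keras default
    # is 0.99, which is what these layers used before this commit).
    x = Conv2DTranspose(filters, (4, 4), strides=(2, 2), padding='same',
                        kernel_initializer=kernel_initializer)(x)
    x = BatchNormalization(momentum=bn_momentum)(x)
    x = LeakyReLU(alpha=alpha)(x)
    return x
```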
