Commit a75a67b

Committed Nov 21, 2016
add free parameter and free pretrain
1 parent bfc8823 · commit a75a67b

File tree: 3 files changed, +32 -14 lines
 

models/seqgan.py (+18 -5)

@@ -19,6 +19,7 @@ def __init__(self, sequence_length, vocab_size, emb_dim, hidden_dim,
         self.hidden_dim = hidden_dim
         self.sequence_length = sequence_length
         self.start_token = start_token
+        self.x0 = None
         self.reward_gamma = reward_gamma
         self.g_params = []
         self.d_params = []

@@ -104,7 +105,11 @@ def decode_one_step(self, x, train=True):
             y = self.out(h)
             return y
         else:
-            h0 = self.embed(x)
+            if len(x.data.shape) == 2:
+                h0 = x
+            else:
+                h0 = self.embed(x)
+
             h = self.lstm1(h0)
             if hasattr(self, "lstm2"):
                 h = self.lstm2(h)

@@ -115,13 +120,18 @@ def decode_one_step(self, x, train=True):
             y = self.out(h)
             return y

-    def generate(self, batch_size, train=False, pool=None):
+    def generate(self, batch_size, train=False, pool=None, random_input=False):
         """
         :return: (batch_size, self.seq_length)
         """

         self.reset_state()
-        x = chainer.Variable(self.xp.asanyarray([self.start_token] * batch_size, 'int32'), volatile=True)
+        if random_input:
+            self.x0 = np.random.normal(scale=1, size=(batch_size, self.emb_dim))
+            x = chainer.Variable(self.xp.asanyarray(self.x0, 'float32'), volatile=True)
+        else:
+            x = chainer.Variable(self.xp.asanyarray([self.start_token] * batch_size, 'int32'), volatile=True)
+
         gen_x = np.zeros((batch_size, self.sequence_length), 'int32')

         for i in range(self.sequence_length):

@@ -169,7 +179,7 @@ def pretrain_step(self, x_input):

         return accum_loss / self.sequence_length

-    def reinforcement_step(self, x_input, rewards, g_steps):
+    def reinforcement_step(self, x_input, rewards, g_steps, random_input=False):
         """
         :param x_input: (batch_size, seq_length)
         :param rewards: (batch_size, seq_length)

@@ -181,7 +191,10 @@ def reinforcement_step(self, x_input, rewards, g_steps):
         accum_loss = 0
         for j in range(self.sequence_length):
             if j == 0:
-                x = chainer.Variable(self.xp.asanyarray([self.start_token] * batch_size, 'int32'))
+                if random_input:
+                    x = chainer.Variable(self.xp.asanyarray(self.x0, 'float32'))
+                else:
+                    x = chainer.Variable(self.xp.asanyarray([self.start_token] * batch_size, 'int32'))
             else:
                 x = chainer.Variable(self.xp.asanyarray(x_input[:, j - 1], 'int32'))
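What the new random_input path does: instead of embedding the integer start_token, generate draws a Gaussian "free parameter" of shape (batch_size, emb_dim), stores it in self.x0, and feeds it to decode_one_step, which now bypasses self.embed whenever its input is already 2-D; reinforcement_step then replays the same self.x0 so the policy-gradient update sees the noise vector that produced the samples. A minimal standalone NumPy sketch of that dispatch (embed_table is a hypothetical stand-in for self.embed, not the repo's API):

import numpy as np

batch_size, emb_dim, vocab_size = 4, 32, 50
embed_table = np.random.randn(vocab_size, emb_dim).astype('float32')  # stand-in for self.embed

def first_step_input(random_input, start_token=0):
    # Mirrors the commit's first-step dispatch (sketch only).
    if random_input:
        # "free parameter": Gaussian noise replaces the start-token embedding
        x = np.random.normal(scale=1, size=(batch_size, emb_dim)).astype('float32')
    else:
        x = np.asarray([start_token] * batch_size, 'int32')
    # decode_one_step's new branch: a 2-D input is already an embedding and
    # skips the lookup; 1-D token ids still go through the table.
    return x if x.ndim == 2 else embed_table[x]

assert first_step_input(True).shape == (batch_size, emb_dim)
assert first_step_input(False).shape == (batch_size, emb_dim)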
models/text_cnn.py (+6 -1)

@@ -39,7 +39,12 @@ def __init__(self, num_classes, vocab_size,

     def forward(self, x_input, ratio=0.5, train=True):

-        batch_size, seq_length = x_input.shape
+        try:
+            batch_size, seq_length = x_input.shape
+        except:
+            batch_size = len(x_input)
+            seq_length = len(x_input[0])
+
         x = chainer.Variable(self.xp.asarray(x_input, 'int32'))

         # embedding
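The try/except lets TextCNN.forward accept a plain list of token-id lists as well as a NumPy array: only arrays expose a .shape attribute, so the fallback measures the batch with len(). The commit uses a bare except; a narrower except AttributeError would be the more idiomatic guard, as in this standalone sketch:

import numpy as np

def batch_shape(x_input):
    # Equivalent of the commit's fallback: ndarrays expose .shape,
    # plain Python lists do not, so measure those with len().
    try:
        batch_size, seq_length = x_input.shape
    except AttributeError:
        batch_size = len(x_input)
        seq_length = len(x_input[0])
    return batch_size, seq_length

assert batch_shape(np.zeros((8, 20), 'int32')) == (8, 20)
assert batch_shape([[0] * 20 for _ in range(8)]) == (8, 20)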
oracle_test/run_sequence_gan.py (+8 -8)

@@ -153,6 +153,7 @@ def significance_test(session, target_lstm, data_loader, output_file):
 # generator
 generator = SeqGAN(seq_length, vocab_size, gen_emb_dim, gen_hidden_dim, start_token, oracle=True).to_gpu()
 if args.gen:
+    print(args.gen)
     serializers.load_hdf5(args.gen, generator)

 # discriminator

@@ -212,10 +213,9 @@ def significance_test(session, target_lstm, data_loader, output_file):
         summary = sess.run(target_loss_summary, feed_dict={loss_: test_loss})
         summary_writer.add_summary(summary, test_count)

-    with open(os.path.join(out_dir, "models", "gen_pretrain.model"), "wb") as f:
-        pickle.dump(generator, f)
-    with open(os.path.join(out_dir, "models", "gen_pretrain.opt"), 'wb') as f:
-        pickle.dump(gen_optimizer, f)
+    serializers.save_hdf5(os.path.join(out_dir, "models", "gen_pretrain.model"), generator)
+    serializers.save_hdf5(os.path.join(out_dir, "models", "gen_pretrain.opt"), gen_optimizer)

 else:
     test_count = gen_pretrain_epoch
     test_loss = generator.target_loss(target_lstm, 1000, gen_batch_size, sess)

@@ -258,8 +258,8 @@ def significance_test(session, target_lstm, data_loader, output_file):
         summary_writer.add_summary(summary, dis_train_count)
         summary = sess.run(dis_acc_summary, feed_dict={loss_: np.mean(sum_train_accuracy)})
         summary_writer.add_summary(summary, dis_train_count)
-    with open(os.path.join(out_dir, "models", "dis_pretrain.model"), "wb") as f:
-        pickle.dump(discriminator, f)
+    serializers.save_hdf5(os.path.join(out_dir, "models", "dis_pretrain.model"), discriminator)
+    serializers.save_hdf5(os.path.join(out_dir, "models", "dis_pretrain.opt"), dis_optimizer)

 # roll out generator
 rollout_generator = copy.deepcopy(generator)

@@ -275,10 +275,10 @@ def significance_test(session, target_lstm, data_loader, output_file):
 print('total batch: ', epoch)

 for step in range(g_steps):
-    samples = generator.generate(gen_batch_size, train=True)
+    samples = generator.generate(gen_batch_size, train=True, random_input=True)
     rewards = rollout_generator.get_rewards(samples, discriminator, rollout_num=16, pool=pool, gpu=args.gpu)
     print(rewards[:30])
-    loss = generator.reinforcement_step(samples, rewards, g_steps=g_steps)
+    loss = generator.reinforcement_step(samples, rewards, g_steps=g_steps, random_input=True)
     gen_optimizer.zero_grads()
     loss.backward()
     gen_optimizer.update()
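Checkpointing switches from pickle.dump of whole Python objects to Chainer's serializers.save_hdf5, which writes only the link's parameters and persistent values, and which matches the serializers.load_hdf5(args.gen, generator) call used to resume from a pretrained model (the "free pretrain" of the commit title). A minimal round trip, assuming a Chainer 1.x-era install with h5py, with L.Linear as a hypothetical stand-in for the real models:

import chainer.links as L
from chainer import serializers

model = L.Linear(3, 2)                                 # stand-in for SeqGAN / TextCNN
serializers.save_hdf5("gen_pretrain.model", model)     # writes parameters, not the object

restored = L.Linear(3, 2)                              # rebuild the same architecture first...
serializers.load_hdf5("gen_pretrain.model", restored)  # ...then fill in the saved weights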
