@@ -153,6 +153,7 @@ def significance_test(session, target_lstm, data_loader, output_file):
153
153
# generator
154
154
generator = SeqGAN (seq_length , vocab_size , gen_emb_dim , gen_hidden_dim , start_token , oracle = True ).to_gpu ()
155
155
if args .gen :
156
+ print (args .gen )
156
157
serializers .load_hdf5 (args .gen , generator )
157
158
158
159
# discriminator
@@ -212,10 +213,9 @@ def significance_test(session, target_lstm, data_loader, output_file):
212
213
summary = sess .run (target_loss_summary , feed_dict = {loss_ : test_loss })
213
214
summary_writer .add_summary (summary , test_count )
214
215
215
- with open (os .path .join (out_dir , "models" , "gen_pretrain.model" ), "wb" ) as f :
216
- pickle .dump (generator , f )
217
- with open (os .path .join (out_dir , "models" , "gen_pretrain.opt" ), 'wb' ) as f :
218
- pickle .dump (gen_optimizer , f )
216
+ serializers .save_hdf5 (os .path .join (out_dir , "models" , "gen_pretrain.model" ), generator )
217
+ serializers .save_hdf5 (os .path .join (out_dir , "models" , "gen_pretrain.opt" ), gen_optimizer )
218
+
219
219
else :
220
220
test_count = gen_pretrain_epoch
221
221
test_loss = generator .target_loss (target_lstm , 1000 , gen_batch_size , sess )
@@ -258,8 +258,8 @@ def significance_test(session, target_lstm, data_loader, output_file):
258
258
summary_writer .add_summary (summary , dis_train_count )
259
259
summary = sess .run (dis_acc_summary , feed_dict = {loss_ : np .mean (sum_train_accuracy )})
260
260
summary_writer .add_summary (summary , dis_train_count )
261
- with open (os .path .join (out_dir , "models" , "dis_pretrain.model" ), "wb" ) as f :
262
- pickle . dump ( discriminator , f )
261
+ serializers . save_hdf5 (os .path .join (out_dir , "models" , "dis_pretrain.model" ), discriminator )
262
+ serializers . save_hdf5 ( os . path . join ( out_dir , "models" , "dis_pretrain.opt" ), dis_optimizer )
263
263
264
264
# roll out generator
265
265
rollout_generator = copy .deepcopy (generator )
@@ -275,10 +275,10 @@ def significance_test(session, target_lstm, data_loader, output_file):
275
275
print ('total batch: ' , epoch )
276
276
277
277
for step in range (g_steps ):
278
- samples = generator .generate (gen_batch_size , train = True )
278
+ samples = generator .generate (gen_batch_size , train = True , random_input = True )
279
279
rewards = rollout_generator .get_rewards (samples , discriminator , rollout_num = 16 , pool = pool , gpu = args .gpu )
280
280
print (rewards [:30 ])
281
- loss = generator .reinforcement_step (samples , rewards , g_steps = g_steps )
281
+ loss = generator .reinforcement_step (samples , rewards , g_steps = g_steps , random_input = True )
282
282
gen_optimizer .zero_grads ()
283
283
loss .backward ()
284
284
gen_optimizer .update ()
0 commit comments