# Obtain a session handle from ``ed`` (presumably Edward's shared TensorFlow
# session -- confirm). NOTE(review): ``sess`` is not referenced anywhere in the
# visible snippet.
sess = ed.get_session()
Starting evaluation...
100%|█████████████████████████████████████████████| 40/40 [01:36<00:00, 2.53s/it]
---------- Summary Image 001 ------------
Starting evaluation...
100%|█████████████████████████████████████████████| 40/40 [01:44<00:00, 2.61s/it]
---------- Summary Image 002 ------------
Starting evaluation...
100%|█████████████████████████████████████████████| 40/40 [01:57<00:00, 3.59s/it]
---------- Summary Image 003 ------------
Starting evaluation...
100%|█████████████████████████████████████████████| 40/40 [02:16<00:00, 3.34s/it]
---------- Summary Image 004 ------------
Starting evaluation...
100%|█████████████████████████████████████████████| 40/40 [02:25<00:00, 3.56s/it]
---------- Summary Image 005 ------------
Starting evaluation...
100%|█████████████████████████████████████████████| 40/40 [02:45<00:00, 4.00s/it]
---------- Summary Image 006 ------------
Starting evaluation...
100%|█████████████████████████████████████████████| 40/40 [02:54<00:00, 4.19s/it]
---------- Summary Image 007 ------------
Starting evaluation...
100%|█████████████████████████████████████████████| 40/40 [03:11<00:00, 4.58s/it]
---------- Summary Image 008 ------------
Starting evaluation...
100%|████████████████████████████████████████████| 40/40 [03:26<00:00, 5.02s/it]
---------- Summary Image 009 ------------
Starting evaluation...
100%|████████████████████████████████████████████| 40/40 [03:38<00:00, 5.58s/it]
---------- Summary Image 010 ------------
Starting evaluation...
100%|████████████████████████████████████████████| 40/40 [03:51<00:00, 5.77s/it]
# Evaluate each image of the inference batch against its own slice of the HMC
# samples. ``x_ad[k:k + 1]`` keeps a leading axis of length 1 on the ground
# truth; ``samples_to_check[:, k, :]`` selects index k along axis 1
# (presumably laid out as (num_samples, batch, dim) -- confirm).
for img_idx in range(inference_batch_size):
    compare_vae_hmc_loss(
        P=model.decode_op,
        Q=model.encode_op,
        DiscL=model.discriminator_l_op,
        x_gt=x_ad[img_idx:img_idx + 1],
        samples_to_check=samples_to_check[:, img_idx, :],
        config=config,
    )
def compare_vae_hmc_loss(P, Q, DiscL, x_gt, samples_to_check, config):
    """Score the HMC samples for one image against its ground truth.

    Saves the first ``sample_to_vis`` entries of ``x_samples_to_check`` and
    their element-wise mean as PNGs under ``./out/``. Then, for every sample,
    it computes three distances to ``x_gt``:

    - sigmoid cross-entropy (``recon_loss``);
    - L2 over all elements (``l2_loss``);
    - L2 between discriminator-layer features (``l_latent_loss``).

    For each distance it keeps a running total and the lowest-scoring sample.

    Args:
        P, Q, DiscL: the caller passes ``model.decode_op``,
            ``model.encode_op`` and ``model.discriminator_l_op``. None of them
            is used in the visible code.
        x_gt: ground truth for one image; the caller passes ``x_ad[i:i+1]``.
        samples_to_check: HMC samples for this image.
        config: not used in the visible code.

    Returns:
        A dict with the three totals, the three best losses and the
        corresponding best samples. The original returned ``None``, so callers
        that ignore the result are unaffected.
    """
    print("Starting evaluation...")
    # NOTE(review): this expression is elided in the posted snippet --
    # presumably the image-space samples derived from ``samples_to_check``.
    x_samples_to_check = ...
    # NOTE(review): sample_to_vis, img_num, l_th_x_gt and l_th_layer_samples
    # are not defined in the posted snippet; they must come from elided code.

    # The plots depend only on x_samples_to_check, never on the sample being
    # scored, so write them once instead of re-writing the same files on
    # every iteration of the scoring loop.
    for j in range(sample_to_vis):
        plot_save(x_samples_to_check[j],
                  './out/{}_mcmc_sample_{}.png'.format(img_num, j + 1))
    avg_img = np.mean(x_samples_to_check, axis=0)
    plot_save(avg_img, './out/{}_mcmcMean.png'.format(img_num))

    # The snippet used these before assigning them; start totals at zero and
    # "best" losses at +inf so the first sample always becomes the best.
    total_recon_loss = total_l2_loss = total_latent_loss = 0.0
    best_recon_loss = best_l2_loss = best_latent_loss = float('inf')
    best_recon_sample = best_l2_sample = best_latent_sample = None

    for i, sample in enumerate(tqdm(x_samples_to_check)):
        r_loss = recon_loss(x_gt, sample)
        l_loss = l2_loss(x_gt, sample)
        lat_loss = l_latent_loss(l_th_x_gt, l_th_layer_samples[i:i+1])
        total_recon_loss += r_loss
        total_l2_loss += l_loss
        total_latent_loss += lat_loss
        if r_loss < best_recon_loss:
            best_recon_sample = sample
            best_recon_loss = r_loss
        if l_loss < best_l2_loss:
            best_l2_sample = sample
            best_l2_loss = l_loss
        if lat_loss < best_latent_loss:
            best_latent_sample = sample
            best_latent_loss = lat_loss

    return {
        'total_recon_loss': total_recon_loss,
        'total_l2_loss': total_l2_loss,
        'total_latent_loss': total_latent_loss,
        'best_recon_loss': best_recon_loss,
        'best_l2_loss': best_l2_loss,
        'best_latent_loss': best_latent_loss,
        'best_recon_sample': best_recon_sample,
        'best_l2_sample': best_l2_sample,
        'best_latent_sample': best_latent_sample,
    }


def l2_loss(x_gt, x_hmc):
    """Return the Euclidean norm of ``x_gt - x_hmc`` taken over all elements.

    Computed with NumPy rather than ``tf.norm(...).eval()``. In TF1 graph mode
    every ``tf.norm`` call adds new ops to the default graph, so the graph
    (and the cost of each later ``.eval()``) grew with every sample scored.

    Both branches of the former ``jernej_Q_P`` check were identical, so the
    flag is no longer consulted.

    Assumes both arguments are array-like values (e.g. NumPy arrays), not
    symbolic ``tf.Tensor`` objects -- TODO confirm against the elided code.
    """
    diff = np.asarray(x_gt, dtype=float) - np.asarray(x_hmc, dtype=float)
    # With ord=None and axis=None, numpy returns the 2-norm of diff.ravel(),
    # matching tf.norm's default over all elements.
    return np.linalg.norm(diff)
def recon_loss(x_gt, x_hmc):
    """Return the per-row summed sigmoid cross-entropy of logits vs ``x_gt``.

    When the module-level flag ``jernej_Q_P`` is true, ``x_hmc`` itself holds
    the logits; otherwise the logits are ``x_hmc[1]``. As before, the result is
    summed over axis 1, so it is an array with one entry per row of ``x_gt``.

    Computed with NumPy rather than by building
    ``tf.nn.sigmoid_cross_entropy_with_logits`` ops. In TF1 graph mode each
    call added new nodes to the default graph, so every later ``.eval()`` got
    slower as more samples were scored.

    Assumes the logits and ``x_gt`` are array-like values, not symbolic
    ``tf.Tensor`` objects -- TODO confirm.
    """
    logits = x_hmc if jernej_Q_P else x_hmc[1]
    x = np.asarray(logits, dtype=float)
    z = np.asarray(x_gt, dtype=float)
    # Numerically stable form documented for
    # tf.nn.sigmoid_cross_entropy_with_logits:
    #     max(x, 0) - x * z + log(1 + exp(-|x|))
    per_elem = np.maximum(x, 0.0) - x * z + np.log1p(np.exp(-np.abs(x)))
    return np.sum(per_elem, axis=1)
def l_latent_loss(l_th_x_gt, l_th_x_hmc):
    """Return the Euclidean norm of the feature difference over all elements.

    Computed with NumPy rather than ``tf.norm(...).eval()``. The TF version
    added new ops to the default graph on every call, making each later
    ``.eval()`` progressively slower.

    Assumes both arguments are array-like values (not symbolic
    ``tf.Tensor`` objects) -- TODO confirm.
    """
    diff = np.asarray(l_th_x_gt, dtype=float) - np.asarray(l_th_x_hmc, dtype=float)
    return np.linalg.norm(diff)
--
You received this message because you are subscribed to the Google Groups "Discuss" group.
To unsubscribe from this group and stop receiving emails from it, send an email to discuss+unsubscribe@tensorflow.org.
To post to this group, send email to dis...@tensorflow.org.
To view this discussion on the web visit https://groups.google.com/a/tensorflow.org/d/msgid/discuss/836b4e69-1323-462b-8be3-5a00851aaa3a%40tensorflow.org.