diff --git a/squeeze.py b/squeeze.py index 59c3563..6df7439 100755 --- a/squeeze.py +++ b/squeeze.py @@ -20,8 +20,9 @@ def reduce_precision_tf(x, npp): """ Reduce the precision of image, the tensorflow version. """ + x_clip = tf.clip_by_value(x, 0., 1.) npp_int = npp - 1 - x_int = tf.rint(tf.multiply(x, npp_int)) + x_int = tf.rint(tf.multiply(x_clip, npp_int)) x_float = tf.div(x_int, npp_int) return x_float diff --git a/utils/visualization.py b/utils/visualization.py index c612447..1c1de77 100755 --- a/utils/visualization.py +++ b/utils/visualization.py @@ -2,23 +2,27 @@ from PIL import Image import numpy as np import pdb +import copy as cp from squeeze import median_filter_np, binary_filter_np IMAGE_SIZE = 28 + filter_m = lambda x: median_filter_np(x, 3) filter_b = lambda x: binary_filter_np(x) filter_mb = lambda x: filter_b(filter_m(x)) filter_bm = lambda x: filter_m(filter_b(x)) + def show_img(pixel_array, mode=None): img = Image.fromarray(pixel_array*255, mode=mode) img.show() -def show_imgs(imgs, width_num=10, height_num=10, x_margin=2, y_margin=2, fpath='/tmp/test.png'): - total_width = width_num * IMAGE_SIZE + (width_num-1)*x_margin - total_height = height_num * IMAGE_SIZE + (height_num-1)*y_margin +adv_x_dict_draw = cp.deepcopy(adv_x_dict)def show_imgs(imgs, width_num=10, height_num=10, x_margin=2, y_margin=2, fpath='/tmp/test.png'): + total_width = int(width_num * IMAGE_SIZE + (width_num-1)*x_margin) + total_height = int(height_num * IMAGE_SIZE + (height_num-1)*y_margin) + new_im = Image.new('RGB', (total_width, total_height), (255,255,255)) @@ -53,17 +57,18 @@ def get_first_example_id_each_class(Y_test): def draw_fgsm_adv_examples(adv_x_dict, Y_test, fpath): + adv_x_dict_draw = cp.deepcopy(adv_x_dict) eps_list = [0,0.1,0.2,0.3,0.4,0.5] width_num=10 selected_example_idx = get_first_example_id_each_class(Y_test) imgs = [] for eps in eps_list: - adv_x_dict[eps] = adv_x_dict[eps][selected_example_idx,:] + adv_x_dict_draw[eps] = adv_x_dict_draw[eps][selected_example_idx,:] for eps in eps_list: - imgs += list(adv_x_dict[eps]) - imgs += list(filter_b(adv_x_dict[eps])) + imgs += list(adv_x_dict_draw[eps]) + imgs += list(filter_b(adv_x_dict_draw[eps])) show_imgs(imgs, width_num=width_num, height_num=len(imgs)/width_num, fpath=fpath)