diff --git a/glass.py b/glass.py index b9c3a32..b27be04 100644 --- a/glass.py +++ b/glass.py @@ -66,6 +66,8 @@ def load( svd=0, step=20, limit=392, + es_epoch=10, + tta=False, **kwargs, ): @@ -73,6 +75,9 @@ def load( self.layers_to_extract_from = layers_to_extract_from self.input_shape = input_shape self.device = device + self.es_epoch = es_epoch + print(f"early stopping epochs: {self.es_epoch}") + self.tta = tta self.forward_modules = torch.nn.ModuleDict({}) feature_aggregator = common.NetworkFeatureAggregator( @@ -199,8 +204,9 @@ def trainer(self, training_data, val_data, name): ckpt_path = glob.glob(self.ckpt_dir + '/ckpt_best*') ckpt_path_save = os.path.join(self.ckpt_dir, "ckpt.pth") if len(ckpt_path) != 0: - LOGGER.info("Start testing, ckpt file found!") - return 0., 0., 0., 0., 0., -1. + # LOGGER.info("Start testing, ckpt file found!") + # return 0., 0., 0., 0., 0., -1. + LOGGER.info("Ckpt file found, retrain!") def update_state_dict(): state_dict["discriminator"] = OrderedDict({ @@ -258,6 +264,10 @@ def update_state_dict(): pbar_str1 = "" best_record = None + best_score = -1 + epoch_counter = 0 + best_state = None + with mlflow.start_run(): mlflow.log_param("meta_epochs", self.meta_epochs) mlflow.log_param("eval_epochs", self.eval_epochs) @@ -293,31 +303,25 @@ def update_state_dict(): if (i_epoch + 1) % self.eval_epochs == 0: images, scores, segmentations, labels_gt, masks_gt = self.predict(val_data) - image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max = self._evaluate(images, scores, segmentations, + image_auroc, image_ap, img_threshold, img_f1_max = self._evaluate(images, scores, segmentations, labels_gt, masks_gt, name) mlflow.log_metric("img_auroc", image_auroc, step=i_epoch) - mlflow.log_metric("pixel_auroc", pixel_auroc, step=i_epoch) mlflow.log_metric("img_threshold", img_threshold, step=i_epoch) mlflow.log_metric("img_f1_max", img_f1_max, step=i_epoch) - # self.logger.logger.add_scalar("i-auroc", image_auroc, i_epoch) 
- # self.logger.logger.add_scalar("i-ap", image_ap, i_epoch) - # self.logger.logger.add_scalar("p-auroc", pixel_auroc, i_epoch) - # self.logger.logger.add_scalar("p-ap", pixel_ap, i_epoch) - # self.logger.logger.add_scalar("p-pro", pixel_pro, i_epoch) - eval_path = './results/eval/' + name + '/' train_path = './results/training/' + name + '/' if best_record is None: - best_record = [image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_f1_max, i_epoch] + best_record = [image_auroc, image_ap, img_f1_max, i_epoch] ckpt_path_best = os.path.join(self.ckpt_dir, "ckpt_best_{}.pth".format(i_epoch)) torch.save(state_dict, ckpt_path_best) shutil.rmtree(eval_path, ignore_errors=True) shutil.copytree(train_path, eval_path) - elif image_auroc + pixel_auroc > best_record[0] + best_record[2]: - best_record = [image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_f1_max, i_epoch] + # elif image_auroc + pixel_auroc > best_record[0] + best_record[2]: + elif image_auroc > best_record[0]: + best_record = [image_auroc, image_ap, img_f1_max, i_epoch] os.remove(ckpt_path_best) ckpt_path_best = os.path.join(self.ckpt_dir, "ckpt_best_{}.pth".format(i_epoch)) torch.save(state_dict, ckpt_path_best) @@ -325,12 +329,24 @@ def update_state_dict(): shutil.copytree(train_path, eval_path) pbar_str1 = f" IAUC:{round(image_auroc * 100, 2)}({round(best_record[0] * 100, 2)})" \ - f" PAUC:{round(pixel_auroc * 100, 2)}({round(best_record[2] * 100, 2)})" \ - f" IF1-max:{round(img_f1_max * 100, 2)}({round(best_record[5] * 100, 2)})" \ + f" PAUC: Do not have PAUC)" \ + f" IF1-max:{round(img_f1_max * 100, 2)}({round(best_record[2] * 100, 2)})" \ f" E:{i_epoch}({best_record[-1]})" + pbar_str += pbar_str1 pbar.set_description_str(pbar_str) + # current_score = image_auroc*1 + pixel_auroc * 0 + current_score = image_auroc*0.5 + img_f1_max*0.5 + if current_score - best_score > 0.1: + best_score = current_score + epoch_counter = 0 + else: + epoch_counter += 1 + if epoch_counter > self.es_epoch: + 
LOGGER.info(f"Early stopping triggered at epoch {i_epoch}") + break + torch.save(state_dict, ckpt_path_save) return best_record @@ -362,7 +378,7 @@ def _train_discriminator(self, input_data, cur_epoch, pbar, pbar_str1): true_feats = self._embed(img, evaluation=False)[0] true_feats.requires_grad = True - mask_s_gt = data_item["mask_s"].reshape(-1, 1).to(self.device) + mask_s_gt = data_item["mask_s"].reshape(-1, 1).to(self.device) # feature noise = torch.normal(0, self.noise, true_feats.shape).to(self.device) gaus_feats = true_feats + noise @@ -502,15 +518,21 @@ def tester(self, test_data, name): else: self.load_state_dict(state_dict, strict=False) - images, scores, segmentations, labels_gt, masks_gt = self.predict(test_data) - image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max = self._evaluate(images, scores, segmentations, - labels_gt, masks_gt, name, path='eval') + if self.tta: + print("Do tta") + images, scores, segmentations, labels_gt, masks_gt = self.tta_predict(test_data) + image_auroc, image_ap, img_threshold, img_f1_max = self.tta_evaluate(images, scores, segmentations, + labels_gt, masks_gt, name, path='eval') + else: + images, scores, segmentations, labels_gt, masks_gt = self.predict(test_data) + image_auroc, image_ap, img_threshold, img_f1_max = self._evaluate(images, scores, segmentations, + labels_gt, masks_gt, name, path='eval') epoch = int(ckpt_path[0].split('_')[-1].split('.')[0]) else: - image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max, epoch = 0., 0., 0., 0., 0., 0., 0., -1. + image_auroc, image_ap, img_threshold, img_f1_max, epoch = 0., 0., 0., 0., -1. 
LOGGER.info("No ckpt file found!") - return image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max, epoch + return image_auroc, image_ap, img_threshold, img_f1_max, epoch def _evaluate(self, images, scores, segmentations, labels_gt, masks_gt, name, path='training'): scores = np.squeeze(np.array(scores)) @@ -530,42 +552,26 @@ def _evaluate(self, images, scores, segmentations, labels_gt, masks_gt, name, pa max_scores = np.max(segmentations) norm_segmentations = (segmentations - min_scores) / (max_scores - min_scores + 1e-10) - pixel_scores = metrics.compute_pixelwise_retrieval_metrics(norm_segmentations, masks_gt, path) - pixel_auroc = pixel_scores["auroc"] - pixel_ap = pixel_scores["ap"] - if path == 'eval': - try: - pixel_pro = metrics.compute_pro(np.squeeze(np.array(masks_gt)), norm_segmentations) - - except: - pixel_pro = 0. - else: - pixel_pro = 0. - - else: - pixel_auroc = -1. - pixel_ap = -1. - pixel_pro = -1. - return image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro,img_threshold, img_f1_max - defects = np.array(images) - targets = np.array(masks_gt) + # targets = np.array(masks_gt) for i in range(len(defects)): defect = utils.torch_format_2_numpy_img(defects[i]) - target = utils.torch_format_2_numpy_img(targets[i]) + # target = utils.torch_format_2_numpy_img(targets[i]) mask = cv2.cvtColor(cv2.resize(norm_segmentations[i], (defect.shape[1], defect.shape[0])), cv2.COLOR_GRAY2BGR) mask = (mask * 255).astype('uint8') mask = cv2.applyColorMap(mask, cv2.COLORMAP_JET) - img_up = np.hstack([defect, target, mask]) - img_up = cv2.resize(img_up, (256 * 3, 256)) + # img_up = np.hstack([defect, target, mask]) + # img_up = cv2.resize(img_up, (256 * 3, 256)) + img_up = np.hstack([defect, mask]) + img_up = cv2.resize(img_up, (256 * 2, 256)) full_path = './results/' + path + '/' + name + '/' utils.del_remake_dir(full_path, del_flag=False) cv2.imwrite(full_path + str(i + 1).zfill(3) + '.png', img_up) - return image_auroc, image_ap, 
# Test-time-augmentation (TTA) helpers added to glass.py.  All three take
# ``self`` and are meant to live on the model class that already provides
# ``_predict``, ``forward_modules``, ``pre_proj`` and ``discriminator``.


def tta_evaluate(self, images, scores, segmentations, labels_gt, masks_gt, name, path='training'):
    """Compute image-level metrics for TTA predictions and save visualisations.

    Args:
        images: per-sample input images, as collected by ``tta_predict``.
        scores: per-sample anomaly scores.
        segmentations: per-sample anomaly maps.
        labels_gt: per-sample ground-truth anomaly labels.
        masks_gt: per-sample ground-truth masks; when empty, no images are
            written and only the metrics are returned.
        name: sub-folder name used under ``./results/<path>/``.
        path: results sub-folder, also forwarded to the metric helper.

    Returns:
        ``(image_auroc, image_ap, img_threshold, img_f1_max)``.
    """
    # np.squeeze turns a single score into a 0-d array, on which the builtin
    # min()/max() raise TypeError; keep it 1-d and use the NumPy reductions.
    scores = np.atleast_1d(np.squeeze(np.array(scores)))
    score_min = np.min(scores)
    score_max = np.max(scores)
    norm_scores = (scores - score_min) / (score_max - score_min + 1e-10)

    image_scores = metrics.compute_imagewise_retrieval_metrics(norm_scores, labels_gt, path)
    image_auroc = image_scores["auroc"]
    image_ap = image_scores["ap"]

    img_threshold, img_f1_max = metrics.compute_best_pr_re(labels_gt, norm_scores)

    if len(masks_gt) > 0:
        segmentations = np.array(segmentations)
        seg_min = np.min(segmentations)
        seg_max = np.max(segmentations)
        norm_segmentations = (segmentations - seg_min) / (seg_max - seg_min + 1e-10)
        # tta__predict returns nested Python lists, so np.array() above yields
        # float64; the maps are cast to float32 before going through OpenCV
        # (presumably because some OpenCV routines reject float64 - TODO confirm).
        norm_segmentations = norm_segmentations.astype(np.float32)
        LOGGER.debug(
            "tta_evaluate: segmentations dtype=%s shape=%s -> normalised dtype=%s",
            segmentations.dtype, segmentations.shape, norm_segmentations.dtype,
        )

        defects = np.array(images)
        targets = np.array(masks_gt)

        # The output folder does not depend on the sample, so prepare it once.
        # NOTE(review): assumes utils.del_remake_dir(..., del_flag=False) is
        # idempotent (it used to be called once per image) - confirm.
        full_path = './results/' + path + '/' + name + '/'
        if len(defects) > 0:
            utils.del_remake_dir(full_path, del_flag=False)

        for i in range(len(defects)):
            defect = utils.torch_format_2_numpy_img(defects[i])
            target = utils.torch_format_2_numpy_img(targets[i])

            # Resize the single-channel map to the image size first, then
            # quantise to uint8 so the colour conversion sees a valid type.
            resized_mask = cv2.resize(
                norm_segmentations[i],
                (defect.shape[1], defect.shape[0]),
                interpolation=cv2.INTER_LINEAR,
            )
            mask_8bit = (resized_mask * 255).astype(np.uint8)

            try:
                mask_color = cv2.cvtColor(mask_8bit, cv2.COLOR_GRAY2BGR)
            except Exception:
                LOGGER.error(
                    "Color conversion failed for sample %d: shape=%s dtype=%s",
                    i, mask_8bit.shape, mask_8bit.dtype,
                )
                raise  # bare raise keeps the original traceback
            mask_color = cv2.applyColorMap(mask_color, cv2.COLORMAP_JET)

            if i == 0:
                LOGGER.debug(
                    "tta_evaluate: sample 0 resized=%s heatmap=%s dtype=%s",
                    resized_mask.shape, mask_color.shape, mask_color.dtype,
                )

            # Side-by-side strip: input | ground-truth mask | heatmap.
            img_up = np.hstack([defect, target, mask_color])
            img_up = cv2.resize(img_up, (256 * 3, 256))
            cv2.imwrite(full_path + str(i + 1).zfill(3) + '.png', img_up)

    return image_auroc, image_ap, img_threshold, img_f1_max


def tta_predict(self, test_dataloader):
    """Run TTA inference over a whole dataloader.

    Args:
        test_dataloader: iterable of batches.  Dict batches are read through
            their ``"image"``, ``"is_anomaly"`` and optional ``"mask_gt"`` keys.

    Returns:
        ``(images, scores, masks, labels_gt, masks_gt)`` as flat per-sample
        lists accumulated over all batches.
    """
    self.forward_modules.eval()

    images = []
    scores = []
    masks = []
    labels_gt = []
    masks_gt = []

    with tqdm.tqdm(test_dataloader, desc="Inferring...", leave=False, unit='batch') as data_iterator:
        for data in data_iterator:
            if isinstance(data, dict):
                labels_gt.extend(data["is_anomaly"].numpy().tolist())
                if data.get("mask_gt", None) is not None:
                    masks_gt.extend(data["mask_gt"].numpy().tolist())
                image = data["image"]
            else:
                # Previously ``image`` was left unbound here.
                # NOTE(review): assumes a non-dict batch is the image tensor
                # itself - confirm against the dataloaders in use.
                image = data
            images.extend(image.numpy().tolist())

            batch_scores, batch_masks = self.tta__predict(image)
            scores.extend(batch_scores)
            masks.extend(batch_masks)

    return images, scores, masks, labels_gt, masks_gt


def tta__predict(self, img):
    """Infer scores and masks for one image batch, averaged over flips.

    The batch is passed through ``self._predict`` three times (unchanged,
    horizontally flipped, vertically flipped).  Each returned mask stack is
    flipped back so the maps line up with the original image before the
    scores and masks are averaged across the three passes.

    Args:
        img: image batch tensor whose last two dimensions are flipped.

    Returns:
        ``(scores, masks)`` as (nested) Python lists, one entry per sample.
    """
    self.forward_modules.eval()
    if self.pre_proj > 0:
        self.pre_projection.eval()
    self.discriminator.eval()

    # (forward transform on the image tensor, inverse transform on the masks).
    # A flip is its own inverse, so the same axis is flipped on the way back.
    tta_transforms = (
        (lambda x: x, lambda m: m),
        (lambda x: x.flip(-1), lambda m: np.flip(m, axis=-1)),
        (lambda x: x.flip(-2), lambda m: np.flip(m, axis=-2)),
    )

    all_scores = []
    all_masks = []

    with torch.no_grad():
        for transform, reverse in tta_transforms:
            _scores, _masks = self._predict(transform(img))
            # Stack once with NumPy instead of torch.tensor() on a list of
            # ndarrays, which is slow and triggers a PyTorch warning.
            all_scores.append(np.asarray(_scores))
            all_masks.append(reverse(np.asarray(_masks)))

    avg_scores = np.mean(all_scores, axis=0)
    avg_masks = np.mean(all_masks, axis=0)

    return avg_scores.tolist(), avg_masks.tolist()
_CLASSNAMES = [
    "carpet",
    "grid",
    "leather",
    "tile",
    "wood",
    "bottle",
    "cable",
    "capsule",
    "hazelnut",
    "metal_nut",
    "pill",
    "screw",
    "toothbrush",
    "transistor",
    "zipper",
]

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class DatasetSplit(Enum):
    """Dataset partition; the value is the split's folder name on disk."""

    TRAIN = "train"
    TEST = "test"


class MVTecDataset(torch.utils.data.Dataset):
    """PyTorch dataset for an MVTec-style folder layout without ground-truth masks.

    Images are read from ``<source>/<classname>/<split>/<anomaly>/``.  For the
    training split every item additionally carries a synthetic anomaly: a
    randomly chosen texture image is blended into the sample under a Perlin
    mask, optionally restricted to a foreground mask that is cached in
    ``<source>/<classname>/fg_mask/``.
    """

    def __init__(
            self,
            source,
            anomaly_source_path='/root/dataset/dtd/images',
            dataset_name='mvtec',
            classname='leather',
            resize=288,
            imagesize=288,
            split=DatasetSplit.TRAIN,
            rotate_degrees=0,
            translate=0,
            brightness_factor=0,
            contrast_factor=0,
            saturation_factor=0,
            gray_p=0,
            h_flip_p=0,
            v_flip_p=0,
            distribution=0,
            mean=0.5,
            std=0.1,
            fg=0,
            rand_aug=1,
            scale=0,
            batch_size=8,
            **kwargs,
    ):
        """
        Args:
            source: [str]. Path to the MVTec data folder.
            anomaly_source_path: [str]. Folder searched with ``*/*.jpg`` for
                    the texture images blended in as synthetic anomalies.
            dataset_name: [str]. Prefix of the ``*_distribution.xlsx`` file
                    and of the class names listed in it.
            classname: [str]. Name of the MVTec class provided by this
                    dataset.
            resize: [int]. (Square) Size the loaded image initially gets
                    resized to.
            imagesize: [int]. (Square) Size the resized loaded image gets
                       (center-)cropped to.
            split: [enum-option]. Indicates if training or test split of the
                   data should be used. Has to be an option taken from
                   DatasetSplit, e.g. mvtec.DatasetSplit.TRAIN. No
                   ground-truth masks are loaded for either split; items
                   always carry an all-zero ``mask_gt``.
            rotate_degrees, translate, scale: RandomAffine parameters.
            brightness_factor, contrast_factor, saturation_factor:
                    ColorJitter parameters.
            gray_p, h_flip_p, v_flip_p: probabilities of the random
                    grayscale / horizontal flip / vertical flip transforms.
            distribution: when equal to 1 the image is resized to an explicit
                    ``[resize, resize]`` instead of a single edge length.
            mean, std: parameters of the normal distribution the blending
                    factor ``beta`` is drawn from (clipped to [0.2, 0.8]).
            fg: 0 = no foreground mask, 1 = use foreground masks,
                    2 = look the setting up in the distribution Excel file.
            rand_aug: when truthy, texture images get three random
                    augmentations instead of the regular image transform.
            batch_size: stored on the instance; not used by this class.
        """
        super().__init__()
        self.source = source
        self.split = split
        self.batch_size = batch_size
        self.distribution = distribution
        self.mean = mean
        self.std = std
        self.fg = fg
        self.rand_aug = rand_aug
        self.resize = resize if self.distribution != 1 else [resize, resize]
        self.imgsize = imagesize
        self.imagesize = (3, self.imgsize, self.imgsize)
        self.classname = classname
        self.dataset_name = dataset_name

        if self.distribution != 1 and self.classname in ('toothbrush', 'wood'):
            self.resize = round(self.imgsize * 329 / 288)

        xlsx_path = './datasets/excel/' + self.dataset_name + '_distribution.xlsx'
        if self.fg == 2:  # choose by file
            try:
                df = pd.read_excel(xlsx_path)
                self.class_fg = df.loc[df['Class'] == self.dataset_name + '_' + classname, 'Foreground'].values[0]
            except Exception:
                # Best effort: a missing/unreadable sheet or an unknown class
                # falls back to "use foreground masks".  ``Exception`` (not a
                # bare except) so Ctrl-C and SystemExit still propagate.
                self.class_fg = 1
        elif self.fg == 1:  # with foreground mask
            self.class_fg = 1
        else:  # without foreground mask
            self.class_fg = 0

        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()
        # Only the texture images are used as anomaly sources.  (The former
        # ``+ 0 * <image list>`` term contributed nothing and raised
        # IndexError when the split folder had no sub-folders.)
        self.anomaly_source_paths = sorted(glob.glob(anomaly_source_path + "/*/*.jpg"))

        self.transform_img = [
            transforms.Resize(self.resize),
            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),
            transforms.RandomHorizontalFlip(h_flip_p),
            transforms.RandomVerticalFlip(v_flip_p),
            transforms.RandomGrayscale(gray_p),
            transforms.RandomAffine(rotate_degrees,
                                    translate=(translate, translate),
                                    scale=(1.0 - scale, 1.0 + scale),
                                    interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.imgsize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
        self.transform_img = transforms.Compose(self.transform_img)

        self.transform_mask = [
            transforms.Resize(self.resize),
            transforms.CenterCrop(self.imgsize),
            transforms.ToTensor(),
        ]
        self.transform_mask = transforms.Compose(self.transform_mask)

    def rand_augmenter(self):
        """Return a transform pipeline with three randomly picked augmentations."""
        list_aug = [
            transforms.ColorJitter(contrast=(0.8, 1.2)),
            transforms.ColorJitter(brightness=(0.8, 1.2)),
            transforms.ColorJitter(saturation=(0.8, 1.2), hue=(-0.2, 0.2)),
            transforms.RandomHorizontalFlip(p=1),
            transforms.RandomVerticalFlip(p=1),
            transforms.RandomGrayscale(p=1),
            transforms.RandomAutocontrast(p=1),
            transforms.RandomEqualize(p=1),
            transforms.RandomAffine(degrees=(-45, 45)),
        ]
        # Three distinct augmentations, applied in the order they were drawn.
        aug_idx = np.random.choice(np.arange(len(list_aug)), 3, replace=False)

        transform_aug = [
            transforms.Resize(self.resize),
            list_aug[aug_idx[0]],
            list_aug[aug_idx[1]],
            list_aug[aug_idx[2]],
            transforms.CenterCrop(self.imgsize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]

        transform_aug = transforms.Compose(transform_aug)
        return transform_aug

    def _load_or_create_fg_mask(self, classname, image_path, image_rgb):
        """Return the foreground mask tensor for an image, generating and caching it if absent."""
        # Build the folder from the dataset root rather than splitting
        # ``image_path`` on the class name, which picks the wrong prefix when
        # the class name also occurs earlier in the path.
        fgmask_base_path = os.path.join(self.source, classname, 'fg_mask')
        if not os.path.isdir(fgmask_base_path):
            os.makedirs(fgmask_base_path, exist_ok=True)
            print('/fg_mask does not exist, already create it')

        fgmask_path = os.path.join(fgmask_base_path, os.path.split(image_path)[-1])

        if os.path.exists(fgmask_path):
            with PIL.Image.open(fgmask_path) as mask_file:
                return torch.ceil(self.transform_mask(mask_file)[0])

        target_foreground_mask = self.generate_target_foreground_mask(np.array(image_rgb))
        mask_img = PIL.Image.fromarray(target_foreground_mask)
        mask_img.save(fgmask_path)
        print(f"already created foreground mask for {image_path} in {fgmask_path}")
        return torch.ceil(self.transform_mask(mask_img)[0])

    def __getitem__(self, idx):
        """Return the sample dict for ``idx``.

        Keys: ``image``, ``aug`` (synthetic-anomaly image), ``mask_s``
        (feature-level anomaly mask), ``mask_gt`` (always zeros),
        ``is_anomaly`` and ``image_path``.  Outside the training split
        ``aug`` and ``mask_s`` are the placeholder ``torch.tensor([1])``.
        """
        classname, anomaly, image_path = self.data_to_iterate[idx]
        # convert() returns a loaded copy, so the file can be closed right
        # away instead of lingering until garbage collection.
        with PIL.Image.open(image_path) as image_file:
            image_rgb = image_file.convert("RGB")
        image = self.transform_img(image_rgb)

        mask_fg = mask_s = aug_image = torch.tensor([1])
        if self.split == DatasetSplit.TRAIN:
            with PIL.Image.open(np.random.choice(self.anomaly_source_paths)) as aug_file:
                aug = aug_file.convert("RGB")
            if self.rand_aug:
                transform_aug = self.rand_augmenter()
                aug = transform_aug(aug)
            else:
                aug = self.transform_img(aug)

            if self.class_fg:
                mask_fg = self._load_or_create_fg_mask(classname, image_path, image_rgb)

            mask_all = perlin_mask(image.shape, self.imgsize // 8, 0, 6, mask_fg, 1)
            mask_s = torch.from_numpy(mask_all[0])  # feature-level
            mask_l = torch.from_numpy(mask_all[1])  # image-level

            # Inside the mask the texture is blended with the image using a
            # random factor; outside the mask the image is left untouched.
            beta = np.random.normal(loc=self.mean, scale=self.std)
            beta = np.clip(beta, .2, .8)
            aug_image = image * (1 - mask_l) + (1 - beta) * aug * mask_l + beta * image * mask_l

        # No ground-truth masks are loaded; an all-zero mask of the image's
        # spatial size keeps the item layout uniform.
        mask_gt = torch.zeros([1, *image.size()[1:]])

        return {
            "image": image,
            "aug": aug_image,
            "mask_s": mask_s,
            "mask_gt": mask_gt,
            "is_anomaly": int(anomaly != "good"),
            "image_path": image_path,
        }

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.data_to_iterate)

    def get_image_data(self):
        """Index the images of the selected class and split.

        Returns:
            ``(imgpaths_per_class, data_to_iterate)``: a nested dict
            ``{classname: {anomaly: [image paths]}}`` and a flat list of
            ``[classname, anomaly, image_path]`` entries sorted by class,
            anomaly type and file name.
        """
        classpath = os.path.join(self.source, self.classname, self.split.value)

        imgpaths_per_class = {self.classname: {}}
        for anomaly in sorted(os.listdir(classpath)):
            anomaly_path = os.path.join(classpath, anomaly)
            if not os.path.isdir(anomaly_path):
                # Stray files next to the anomaly folders (readme, .DS_Store,
                # ...) used to crash the nested os.listdir call.
                continue
            anomaly_files = sorted(os.listdir(anomaly_path))
            imgpaths_per_class[self.classname][anomaly] = [os.path.join(anomaly_path, x) for x in anomaly_files]

        data_to_iterate = []
        for classname in sorted(imgpaths_per_class):
            for anomaly in sorted(imgpaths_per_class[classname]):
                for image_path in imgpaths_per_class[classname][anomaly]:
                    data_to_iterate.append([classname, anomaly, image_path])

        return imgpaths_per_class, data_to_iterate

    def generate_target_foreground_mask(self, img):
        """Build a binary (0/255) uint8 foreground mask from an RGB image array.

        The image is Otsu-thresholded on its grayscale version, inverted and
        median-filtered.  NOTE(review): the unconditional inversion presumably
        assumes the object is darker than the background - confirm per class.
        """
        filter_size = 7

        # convert RGB into GRAY scale
        img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

        # With THRESH_OTSU the threshold is chosen automatically, so the
        # explicit value 100 passed here is not the one actually applied.
        _, target_foreground_mask = cv2.threshold(img_gray, 100, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

        target_foreground_mask = cv2.bitwise_not(target_foreground_mask)
        target_foreground_mask = cv2.medianBlur(target_foreground_mask, filter_size)

        target_foreground_mask = target_foreground_mask.astype(np.uint8)

        return target_foreground_mask
will create it and generate foreground masks for normal training images.\n", "## augpath: augmented dtd dataset\n", "## results_path: where the results are saved\n", "## /mlruns: where mlflow tracking results are saved by default\n", @@ -183,7 +184,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 3, "id": "e3df3315", "metadata": {}, "outputs": [ @@ -191,21 +192,94 @@ "name": "stdout", "output_type": "stream", "text": [ - "/home/ting/GLASS/data/mvtec_ad\n", - "capsule fg_mask pill\tscrew\n", - "/home/ting/GLASS/data\n" + "/home/ting/GLASS/data/mvtec_ad/screw\n", + "license.txt readme.txt test train\n", + "/home/ting/GLASS\n" ] } ], "source": [ - "%cd data/mvtec_ad\n", + "%cd data/mvtec_ad/screw\n", "!ls\n", - "%cd ../.." + "%cd ../../.." + ] + }, + { + "cell_type": "markdown", + "id": "54c4c9d3", + "metadata": {}, + "source": [ + "### Without TTA" + ] + }, + { + "cell_type": "markdown", + "id": "6fa61d97", + "metadata": {}, + "source": [ + "Change the es_epoch to set the early stopping epochs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e74e2a3", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "\n", + "datapath=./data/mvtec_ad\n", + "augpath=./data/dtd/images\n", + "classes=('bottle')\n", + "flags=($(for class in \"${classes[@]}\"; do echo '-d '\"${class}\"; done))\n", + "\n", + "python main.py \\\n", + " --results_path ./results \\\n", + " --gpu 0 \\\n", + " --seed 0 \\\n", + " --test ckpt \\\n", + " net \\\n", + " -b wideresnet50 \\\n", + " -le layer2 \\\n", + " -le layer3 \\\n", + " --pretrain_embed_dimension 1536 \\\n", + " --target_embed_dimension 1536 \\\n", + " --patchsize 3 \\\n", + " --meta_epochs 5 \\\n", + " --es_epoch 10 \\\n", + " --eval_epochs 1 \\\n", + " --dsc_layers 2 \\\n", + " --dsc_hidden 1024 \\\n", + " --pre_proj 1 \\\n", + " --mining 1 \\\n", + " --noise 0.015 \\\n", + " --radius 0.75 \\\n", + " --p 0.5 \\\n", + " --step 20 \\\n", + " --limit 392 \\\n", + " dataset 
\\\n", + " --distribution 0 \\\n", + " --mean 0.5 \\\n", + " --std 0.1 \\\n", + " --fg 1 \\\n", + " --rand_aug 1 \\\n", + " --batch_size 8 \\\n", + " --resize 288 \\\n", + " --imagesize 288 \"${flags[@]}\" mvtec $datapath $augpath" + ] + }, + { + "cell_type": "markdown", + "id": "6819fd19", + "metadata": {}, + "source": [ + "### With TTA" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "d290bccc", "metadata": {}, "outputs": [ @@ -213,12 +287,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "INFO:__main__:Command line arguments: main.py --results_path ./results --gpu 0 --seed 0 --test ckpt net -b wideresnet50 -le layer2 -le layer3 --pretrain_embed_dimension 1536 --target_embed_dimension 1536 --patchsize 3 --meta_epochs 5 --eval_epochs 1 --dsc_layers 2 --dsc_hidden 1024 --pre_proj 1 --mining 1 --noise 0.015 --radius 0.75 --p 0.5 --step 20 --limit 392 dataset --distribution 0 --mean 0.5 --std 0.1 --fg 1 --rand_aug 1 --batch_size 16 --resize 288 --imagesize 288 -d capsule -d pill -d screw mvtec ./data/mvtec_ad ./data/dtd/images\n", - "INFO:__main__:Dataset CAPSULE : train=219 test=132\n", - "INFO:__main__:Dataset PILL : train=267 test=167\n", - "INFO:__main__:Dataset SCREW : train=320 test=160\n", - "INFO:__main__:Selecting dataset [mvtec_capsule] (1/3) 2025-02-08 14:00:03\n", - "2025/02/08 14:00:03 INFO mlflow.tracking.fluent: Experiment with name 'GLASS_Training' does not exist. 
Creating a new experiment.\n" + "INFO:__main__:Command line arguments: main.py --results_path ./results --gpu 0 --seed 0 --test ckpt net -b wideresnet50 -le layer2 -le layer3 --pretrain_embed_dimension 1536 --target_embed_dimension 1536 --patchsize 3 --meta_epochs 5 --eval_epochs 1 --dsc_layers 2 --dsc_hidden 1024 --pre_proj 1 --mining 1 --noise 0.015 --radius 0.75 --p 0.5 --step 20 --limit 392 --tta dataset --distribution 0 --mean 0.5 --std 0.1 --fg 1 --rand_aug 1 --batch_size 8 --resize 288 --imagesize 288 -d bottle mvtec ./data/mvtec_ad ./data/dtd/images\n", + "INFO:__main__:Dataset BOTTLE : train=209 test=83\n" ] }, { @@ -226,254 +296,44 @@ "output_type": "stream", "text": [ "\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "epoch:0 discriminator loss:2.06e+00 pt:79.13 pf:17.47 rt:0.90 rg:1.04 rf:3.40 svd:0 sample:219: 0%| | 0/5 [00:23" ] @@ -748,7 +559,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "glass", "language": "python", "name": "python3" }, @@ -762,7 +573,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.9.21" } }, "nbformat": 4,