From 5e48a30e9ae260a98613967ca419f66e29e7e0b6 Mon Sep 17 00:00:00 2001 From: Adit Date: Tue, 6 Jun 2023 18:19:35 +0200 Subject: [PATCH 1/3] Minor changes to make it runnable --- TM/__init__.py | 0 TM/data/additional_transforms.py | 28 ++++++++++++++++++++++++++++ TM/data/preprocess.py | 2 +- 3 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 TM/__init__.py create mode 100644 TM/data/additional_transforms.py diff --git a/TM/__init__.py b/TM/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/TM/data/additional_transforms.py b/TM/data/additional_transforms.py new file mode 100644 index 0000000..4f6c2e0 --- /dev/null +++ b/TM/data/additional_transforms.py @@ -0,0 +1,28 @@ +# Copyright 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from PIL import ImageEnhance + +transformtypedict=dict(Brightness=ImageEnhance.Brightness, Contrast=ImageEnhance.Contrast, Sharpness=ImageEnhance.Sharpness, Color=ImageEnhance.Color) + + + +class ImageJitter(object): + def __init__(self, transformdict): + self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict] + + + def __call__(self, img): + out = img + randtensor = torch.rand(len(self.transforms)) + + for i, (transformer, alpha) in enumerate(self.transforms): + r = alpha*(randtensor[i]*2.0 -1.0) + 1 + out = transformer(out).enhance(r).convert('RGB') + + return out diff --git a/TM/data/preprocess.py b/TM/data/preprocess.py index 1b4de62..d7e63e6 100644 --- a/TM/data/preprocess.py +++ b/TM/data/preprocess.py @@ -5,7 +5,7 @@ ''' from anndata import read_h5ad -import scanpy.api as sc +import scanpy as sc import pandas as pd from collections import Counter import numpy as np From f02b9b8e9f112c191b256e689ef3eea23ff7fe11 Mon Sep 17 00:00:00 2001 From: Adit Date: Tue, 6 Jun 2023 18:22:33 +0200 Subject: [PATCH 2/3] Actually don't need addl transforms, that's just images --- TM/data/__init__.py | 1 - TM/data/additional_transforms.py | 28 ---------------------------- 2 files changed, 29 deletions(-) delete mode 100644 TM/data/additional_transforms.py diff --git a/TM/data/__init__.py b/TM/data/__init__.py index d77d51f..d5e29d0 100755 --- a/TM/data/__init__.py +++ b/TM/data/__init__.py @@ -1,4 +1,3 @@ from . import datamgr from . import dataset -from . import additional_transforms from . import feature_loader diff --git a/TM/data/additional_transforms.py b/TM/data/additional_transforms.py deleted file mode 100644 index 4f6c2e0..0000000 --- a/TM/data/additional_transforms.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2017-present, Facebook, Inc. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -import torch -from PIL import ImageEnhance - -transformtypedict=dict(Brightness=ImageEnhance.Brightness, Contrast=ImageEnhance.Contrast, Sharpness=ImageEnhance.Sharpness, Color=ImageEnhance.Color) - - - -class ImageJitter(object): - def __init__(self, transformdict): - self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict] - - - def __call__(self, img): - out = img - randtensor = torch.rand(len(self.transforms)) - - for i, (transformer, alpha) in enumerate(self.transforms): - r = alpha*(randtensor[i]*2.0 -1.0) + 1 - out = transformer(out).enhance(r).convert('RGB') - - return out From 28f4d9f0ebdbb74a0770ef344005c8c9705ccb05 Mon Sep 17 00:00:00 2001 From: Adit Shah Date: Tue, 6 Jun 2023 19:43:54 +0200 Subject: [PATCH 3/3] Remove import --- TM/data/datamgr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/TM/data/datamgr.py b/TM/data/datamgr.py index f4551cd..40bd1b4 100755 --- a/TM/data/datamgr.py +++ b/TM/data/datamgr.py @@ -4,7 +4,6 @@ from PIL import Image import numpy as np import torchvision.transforms as transforms -import data.additional_transforms as add_transforms from data.dataset import SimpleDataset, SetDataset, EpisodicBatchSampler from abc import abstractmethod