diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..5685766
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,112 @@
+submission/data
+# nvidia-docker-compose
+nvidia-docker-compose.yml
+# JetBrains IDE
+.idea/
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# vscode
+.vscode/
diff --git a/featurevis/ops.py b/featurevis/ops.py
index 676020c..8f49ba9 100644
--- a/featurevis/ops.py
+++ b/featurevis/ops.py
@@ -23,7 +23,7 @@ def __init__(self, weight=1, isotropic=False):
     @varargin
     def __call__(self, x):
         # Using the definitions from Wikipedia.
-        diffs_y = torch.abs(x[:, :, 1:] - x[:, :, -1:])
+        diffs_y = torch.abs(x[:, :, 1:] - x[:, :, :-1])
         diffs_x = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
         if self.isotropic:
             tv = torch.sqrt(diffs_y[:, :, :, :-1] ** 2 +
@@ -163,22 +163,27 @@ class RandomCrop():
     """ Take a random crop of the input image.
 
     Arguments:
-        height (int): Height of the crop.
-        width (int): Width of the crop
+        c_height (int): Height of the crop.
+        c_width (int): Width of the crop.
+        n_crops (int): Number of crops taken from each image in the batch.
+        x (4-d tensor): A batch of images to take crops from (n x c x h x w).
+    Returns:
+        crops (5-d tensor): n x n_crops x c x c_height x c_width
     """
-    def __init__(self, height, width):
-        self.height = height
-        self.width = width
+    def __init__(self, c_height, c_width, n_crops=1):
+        self.height = c_height
+        self.width = c_width
+        self.n_crops = n_crops
 
     @varargin
     def __call__(self, x):
-        crop_y = torch.randint(0, max(0, x.shape[-2] - self.height) + 1, (1,),
-                               dtype=torch.int32).item()
-        crop_x = torch.randint(0, max(0, x.shape[-1] - self.width) + 1, (1,),
-                               dtype=torch.int32).item()
-        cropped_x = x[..., crop_y: crop_y + self.height, crop_x: crop_x + self.width]
+        crop_y = torch.randint(0, max(0, x.shape[-2] - self.height) + 1, (self.n_crops,),
+                               dtype=torch.int32)
+        crop_x = torch.randint(0, max(0, x.shape[-1] - self.width) + 1, (self.n_crops,),
+                               dtype=torch.int32)
+        crops = torch.stack([x[..., cy: cy + self.height, cx: cx + self.width] for cy, cx in zip(crop_y, crop_x)]).transpose(0, 1)
+        return crops
 
-        return cropped_x
 
 
 class BatchedCrops():
@@ -288,7 +293,6 @@ class Identity():
     def __call__(self, x):
         return x
 
-
 ############################## GRADIENT OPERATIONS #######################################
 class ChangeNorm():
     """ Change the norm of the input.
@@ -375,20 +379,36 @@ class MultiplyBy():
         const: Number x will be multiplied by
         decay_factor: Compute const every iteration as `const + decay_factor * (iteration - 1)`.
             Ignored if None.
+        every_n_iterations: Apply decay_factor every n iterations.
     """
-    def __init__(self, const, decay_factor=None):
+
+    def __init__(self, const, decay_factor=None, every_n_iterations=1):
         self.const = const
         self.decay_factor = decay_factor
+        self.every_n_iterations = every_n_iterations
 
     @varargin
     def __call__(self, x, iteration=None):
         if self.decay_factor is None:
             const = self.const
         else:
-            const = self.const + self.decay_factor * (iteration - 1)
+            const = self.const + self.decay_factor * ((iteration - 1) // self.every_n_iterations)
 
         return const * x
 
 
+class Slicing():
+    """
+    Slice x at a given index: returns x[idx] if x is a tuple, else x[:, idx].
+    """
+    def __init__(self, idx):
+        self.idx = idx
+
+    @varargin
+    def __call__(self, x, iteration=None):
+        if isinstance(x, tuple):
+            return x[self.idx]
+        else:
+            return x[:, self.idx]
 ########################### POST UPDATE OPERATIONS #######################################
 class GaussianBlur():
@@ -438,21 +458,26 @@ def __call__(self, x, iteration=None):
 
         return final_x
 
-
-class ChangeStd():
+class ChangeStats():
     """ Change the standard deviation of input.
 
     Arguments:
         std (float or tensor): Desired std. If tensor, it should be the same length as x.
+        mean (float or tensor): Desired mean. If tensor, it should be the same length as x.
""" - def __init__(self, std): + def __init__(self, std, mean=None): self.std = std - + self.mean = mean + @varargin def __call__(self, x): - x_std = torch.std(x.view(len(x), -1), dim=-1) - fixed_std = x * (self.std / (x_std + 1e-9)).view(len(x), *[1, ] * (x.dim() - 1)) - return fixed_std + x_std = torch.std(x.view(len(x), -1), dim=-1, keepdim=True) + if self.mean is None: + fixed_im = x * (self.std / (x_std + 1e-9)).view(len(x), *[1, ] * (x.dim() - 1)) + else: + x_mean = torch.mean(x, (-1, -2), keepdim=True) + fixed_im = (x - x_mean) * (self.std / (x_std + 1e-9)).view(len(x), *[1, ] * (x.dim() - 1)) + self.mean + return fixed_im ####################################### LOSS ############################################# @@ -498,4 +523,4 @@ def __init__(self, target): @varargin def __call__(self, x): - return -F.poisson_nll_loss(x, self.target, log_input=False) + return -F.poisson_nll_loss(x, self.target, log_input=False) \ No newline at end of file