Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
submission/data
# nvidia-docker-compose
nvidia-docker-compose.yml
# JetBrains IDE
.idea/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# vscode
.vscode/
71 changes: 48 additions & 23 deletions featurevis/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, weight=1, isotropic=False):
@varargin
def __call__(self, x):
# Using the definitions from Wikipedia.
diffs_y = torch.abs(x[:, :, 1:] - x[:, :, -1:])
diffs_y = torch.abs(x[:, :, 1:] - x[:, :, :-1])
diffs_x = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
if self.isotropic:
tv = torch.sqrt(diffs_y[:, :, :, :-1] ** 2 +
Expand Down Expand Up @@ -163,22 +163,27 @@ class RandomCrop():
""" Take a random crop of the input image.

Arguments:
height (int): Height of the crop.
width (int): Width of the crop
c_height (int): Height of the crop.
c_width (int): Width of the crop
n_crops (int): Number of crops taken from each image in the batch
x (4-d tensor): A batch of images to take crops from (n x c x h x w)
Returns:
crops (5-d tensor): n x n_crops x c x c_height x c_width
"""
def __init__(self, height, width):
self.height = height
self.width = width
def __init__(self, c_height, c_width, n_crops=1):
self.height = c_height
self.width = c_width
self.n_crops = n_crops

@varargin
def __call__(self, x):
crop_y = torch.randint(0, max(0, x.shape[-2] - self.height) + 1, (1,),
dtype=torch.int32).item()
crop_x = torch.randint(0, max(0, x.shape[-1] - self.width) + 1, (1,),
dtype=torch.int32).item()
cropped_x = x[..., crop_y: crop_y + self.height, crop_x: crop_x + self.width]
crop_y = torch.randint(0, max(0, x.shape[-2] - self.height) + 1, (self.n_crops,),
dtype=torch.int32)
crop_x = torch.randint(0, max(0, x.shape[-1] - self.width) + 1, (self.n_crops,),
dtype=torch.int32)
crops = torch.stack([x[..., cy: cy + self.height, cx: cx + self.width] for cy, cx in zip(crop_y, crop_x)]).transpose(0, 1)
return crops

return cropped_x


class BatchedCrops():
Expand Down Expand Up @@ -288,7 +293,6 @@ class Identity():
def __call__(self, x):
return x


############################## GRADIENT OPERATIONS #######################################
class ChangeNorm():
""" Change the norm of the input.
Expand Down Expand Up @@ -375,20 +379,36 @@ class MultiplyBy():
const: Number x will be multiplied by
decay_factor: Compute const every iteration as `const + decay_factor * (iteration
- 1)`. Ignored if None.
every_n_iterations: apply deca_factor every n iterations
"""
def __init__(self, const, decay_factor=None):

def __init__(self, const, decay_factor=None, every_n_iterations=1):
self.const = const
self.decay_factor = decay_factor
self.every_n_iterations = every_n_iterations

@varargin
def __call__(self, x, iteration=None):
if self.decay_factor is None:
const = self.const
else:
const = self.const + self.decay_factor * (iteration - 1)
const = self.const + self.decay_factor * ((iteration - 1) // self.every_n_iterations)

return const * x

class Slicing():
"""
Slice x by one certain index.
"""
def __init__(self, idx):
self.idx = idx

@varargin
def __call__(self, x, iteration=None):
if type(x) == 'tuple':
return x[self.idx]
else:
return x[:, self.idx]

########################### POST UPDATE OPERATIONS #######################################
class GaussianBlur():
Expand Down Expand Up @@ -438,21 +458,26 @@ def __call__(self, x, iteration=None):

return final_x


class ChangeStd():
class ChangeStats():
""" Change the standard deviation of input.

Arguments:
std (float or tensor): Desired std. If tensor, it should be the same length as x.
mean (float or tensor): Desired mean. If tensor, it should be the same length as x.
"""
def __init__(self, std):
def __init__(self, std, mean=None):
self.std = std

self.mean = mean

@varargin
def __call__(self, x):
x_std = torch.std(x.view(len(x), -1), dim=-1)
fixed_std = x * (self.std / (x_std + 1e-9)).view(len(x), *[1, ] * (x.dim() - 1))
return fixed_std
x_std = torch.std(x.view(len(x), -1), dim=-1, keepdim=True)
if self.mean is None:
fixed_im = x * (self.std / (x_std + 1e-9)).view(len(x), *[1, ] * (x.dim() - 1))
else:
x_mean = torch.mean(x, (-1, -2), keepdim=True)
fixed_im = (x - x_mean) * (self.std / (x_std + 1e-9)).view(len(x), *[1, ] * (x.dim() - 1)) + self.mean
return fixed_im


####################################### LOSS #############################################
Expand Down Expand Up @@ -498,4 +523,4 @@ def __init__(self, target):

@varargin
def __call__(self, x):
return -F.poisson_nll_loss(x, self.target, log_input=False)
return -F.poisson_nll_loss(x, self.target, log_input=False)