Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 107 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,108 @@
*.pyc
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
Pipfile.lock

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Project specific
saved_models/
*.npz
*.pth
*.pt
*.pkl
*.pickle
test_images/
VOCdevkit/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# OS
.DS_Store
Thumbs.db
2 changes: 1 addition & 1 deletion Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ ipython = "*"
scipy = "*"

[requires]
python_version = "3.6"
python_version = "3.8"
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Approximate Convolutional Sparse Coding (ACSC)

A pytorch implementation of a ACSC model based on **Lerned Convolutional Sparse Coding** model proposed [here](https://arxiv.org/abs/1711.00328) and or [here](https://ieeexplore.ieee.org/abstract/document/8462313).
A PyTorch implementation of an ACSC model based on **Learned Convolutional Sparse Coding** model proposed [here](https://arxiv.org/abs/1711.00328) and [here](https://ieeexplore.ieee.org/abstract/document/8462313).


## ACSC block description
Expand Down
12 changes: 6 additions & 6 deletions analyze_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,21 @@ def plot_dict(model, save_path):
my_subplot(cd, [kers_per_row, kers_per_col], 'conv-dictionary', save_path)


def evaluate_thrshold(model, save_path, name):
thrshold_avg = [float(model.softthrsh0.thrshold.mean())]
def evaluate_threshold(model, save_path, name):
threshold_avg = [float(model.softthrsh0.threshold.mean())]

for thrsh in model.softthrsh1:
thrshold_avg.append(float(thrsh.thrshold.mean()))
threshold_avg.append(float(thrsh.threshold.mean()))

plt.plot(range(len(thrshold_avg)), thrshold_avg, '*')
plt.plot(range(len(threshold_avg)), threshold_avg, '*')
plt.savefig(os.path.join(save_path, name))
plt.clf()

def evaluate_csc(model, img_n, save_path, im_name):
"""Plot CSC
"""
sparse_code_delta = []
for csc, csc_res, lista_iter in model.forward_enc_generataor(img_n.unsqueeze(0)):
for csc, csc_res, lista_iter in model.forward_enc_generator(img_n.unsqueeze(0)):
_, depth, rows, cols = csc.shape
sc_per_col = int(np.sqrt(depth))
sc_per_row = sc_per_col + (depth - sc_per_col**2)
Expand Down Expand Up @@ -94,7 +94,7 @@ def evaluate(args):

plot_dict(model, log_dir)
evaluate_csc(model, testset[7][0], log_dir, testset.image_filenames[7])
evaluate_thrshold(model, log_dir, 'thrshold')
evaluate_threshold(model, log_dir, 'threshold')

def main():
"""Run test on trained model.
Expand Down
52 changes: 15 additions & 37 deletions common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def to_np(_x): return _x.data.cpu().numpy()

def I(_x): return _x

def normilize(_x, _val=255, shift=0):
def normalize(_x, _val=255, shift=0):
return (_x - shift)/ _val

def count_parameters(model):
Expand Down Expand Up @@ -48,21 +48,6 @@ def init_model_dir(path, name):
os.mkdir(full_path)
return full_path

'''
Either string defining an activation function or module (e.g. nn.ReLU)
'''
if isinstance(act_fun, str):
if act_fun == 'LeakyReLU':
return nn.LeakyReLU(0.2, inplace=True)
elif act_fun == 'ELU':
return nn.ELU()
elif act_fun == 'none':
return nn.Sequential()
else:
assert False
else:
return act_fun()


def flip(x, dim):
dim = x.dim() + dim if dim < 0 else dim
Expand Down Expand Up @@ -113,7 +98,7 @@ def delete_pixels(ins, is_training, sample_prob=0.3):
return ins * mask + (1 - mask)
return ins

def reconsturction_loss(distance='l1', use_cuda=True):
def reconstruction_loss(distance='l1', use_cuda=True):

if distance == 'l1':
dist = nn.L1Loss()
Expand All @@ -124,8 +109,6 @@ def reconsturction_loss(distance='l1', use_cuda=True):
else:
raise ValueError(f"unidentified value {distance}")

#if use_cuda:
# dist = dist.cuda()
return dist

def get_criterion(losses_types, factors, use_cuda=True):
Expand All @@ -138,13 +121,10 @@ def get_criterion(losses_types, factors, use_cuda=True):
"""
losses = []
for loss_type in losses_types:
losses.append(reconsturction_loss(loss_type))

#if use_cuda:
# losses = [l.cuda() for l in losses]
losses.append(reconstruction_loss(loss_type))

def total_loss(results, targets):
"""Cacluate total loss
"""Calculate total loss
total_loss = sum_i losses_i(results_i, targets_i)
Args:
results(tensor): nn outputs.
Expand Down Expand Up @@ -181,20 +161,19 @@ def clean(save_path, save_count=10):
print('removing', f)
os.remove(f)

def save_train(path, model, optimizer, schedular=None, epoch=None):
def save_train(path, model, optimizer, scheduler=None, epoch=None):
state = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
}
#TODO(hillel): fix this so we can save schedular state
#if schedular is not None:
# state['schedular'] = schedular.state_dict()
if scheduler is not None:
state['scheduler'] = scheduler.state_dict()
if epoch is not None:
state['epoch'] = epoch
torch.save(state, os.path.join(path, 'epoch_{}'.format(epoch)))
return os.path.join(path, 'epoch_{}'.format(epoch))

def load_train(path, model, optimizer, schedular=None):
def load_train(path, model, optimizer, scheduler=None):
state = torch.load(path)

pretrained = state['model']
Expand All @@ -205,12 +184,12 @@ def load_train(path, model, optimizer, schedular=None):
except Exception as e:
print(f'did not restore optimizer due to error {e}')
else:
print('Optimizer not inilized since no data for it exists in supplied path')
if schedular is not None:
if 'schedular' in state:
schedular.load_state_dict(state['schedular'])
print('Optimizer not initialized since no data for it exists in supplied path')
if scheduler is not None:
if 'scheduler' in state:
scheduler.load_state_dict(state['scheduler'])
else:
print('Schedular not inilized since no data for it exists in supplied path')
print('Scheduler not initialized since no data for it exists in supplied path')
if 'epoch' in state:
e = state['epoch']
else:
Expand All @@ -224,10 +203,9 @@ def load_eval(path, model):

state = torch.load(path, map_location='cpu')
pretrained = state['model']
current = model.state_dict()

# very dangerous!!!
pretrained = {k:v for k, v in zip(current.keys(), pretrained.values())}
# Load state dict with strict=False to allow for model architecture changes
# This will warn about missing or unexpected keys
model.load_state_dict(pretrained, strict=False)
model.eval()

9 changes: 4 additions & 5 deletions convsparse_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def forward_enc(self, inputs):
csc = self.softthrsh1[lyr](csc + sc_residual)
return csc

def forward_enc_generataor(self, inputs):
"""forwar encoder generator
Use for debug and anylize model.
def forward_enc_generator(self, inputs):
"""forward encoder generator
Use for debug and analyze model.
"""
csc = self.softthrsh0(self.encode_conv0(inputs))

Expand Down Expand Up @@ -135,9 +135,8 @@ def __init__(self, _lambd):
self._lambd = _lambd

@property
def thrshold(self):
def threshold(self):
return self._lambd
# self._lambd.register_hook(print)

def forward(self, inputs):
""" sign(inputs) * (abs(inputs) - thrshold)"""
Expand Down
5 changes: 2 additions & 3 deletions datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import division
from torch.autograd import Variable
import torch.utils.data as data
from functools import partial
import torch
Expand All @@ -14,7 +13,7 @@ def is_image_file(filename):

def load_img(filepath, convert='L'):
img = np.array(Image.open(filepath).convert(convert))
img = Variable(torch.from_numpy(img[None,...]),requires_grad=False).float()
img = torch.from_numpy(img[None,...]).float()
return img

def find_file_in_folder(folder, file_name):
Expand Down Expand Up @@ -110,7 +109,7 @@ def __init__(self, npz_path, key, pre_transform, inputs_transform, use_cuda=True
self._inputs_transform = inputs_transform

def __getitem__(self, index):
_targets = Variable(torch.from_numpy(self._targets[index]).float(), requires_grad=False)
_targets = torch.from_numpy(self._targets[index]).float()
_inputs = self._inputs_transform(_targets)
if self._use_cuda:
_targets = _targets.cuda()
Expand Down
11 changes: 11 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Core dependencies
torch>=1.7.0
torchvision>=0.8.0
numpy>=1.19.0
matplotlib>=3.3.0
scipy>=1.5.0
Pillow>=8.3.2
ipython>=7.0.0

# Development dependencies
pybm3d>=3.0.0
Loading