Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions pastis/optimization/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class Callback(object):
True alpha, to be used by `analysis_function`.
verbose : bool, optional
Verbosity.
restart_iter : int, optional
The iteration number to restart from. When given, the epoch counter is
initialized to `restart_iter - 1` instead of -1.

Attributes
----------
Expand Down Expand Up @@ -124,7 +126,8 @@ def __init__(self, lengths, ploidy, counts=None, multiscale_factor=1,
history=None, analysis_function=None, frequency=None,
on_training_begin=None, on_training_end=None,
on_epoch_end=None, directory=None, struct_true=None,
alpha_true=None, verbose=False):
alpha_true=None, verbose=False, restart_iter=None):
self.restart_iter = restart_iter
self.ploidy = ploidy
self.multiscale_factor = multiscale_factor
self.lengths = decrease_lengths_res(lengths, multiscale_factor)
Expand Down Expand Up @@ -171,6 +174,8 @@ def __init__(self, lengths, ploidy, counts=None, multiscale_factor=1,
self.opt_type = None
self.alpha_loop = None
self.epoch = -1
if self.restart_iter is not None:
self.epoch = self.restart_iter - 1
self.time = '0:00:00.0'
self.structures = None
self.alpha = None
Expand All @@ -195,7 +200,6 @@ def _check_frequency(self, frequency, last_epoch=False):

def _print(self, last_epoch=False):
"""Prints loss every given number of epochs."""

if self._check_frequency(self.frequency['print'], last_epoch):
info_dict = {'At iterate': ' ' * (6 - len(str(self.epoch))) + str(
self.epoch), 'f= ': '%.6g' % self.obj['obj'],
Expand Down Expand Up @@ -273,6 +277,8 @@ def on_training_begin(self, opt_type=None, alpha_loop=None):
self.opt_type = opt_type
self.alpha_loop = alpha_loop
self.epoch = -1
if self.restart_iter is not None:
self.epoch = self.restart_iter - 1
self.seconds = 0
self.time = '0:00:00.0'
self.structures = None
Expand Down
46 changes: 36 additions & 10 deletions pastis/optimization/pastis_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ def infer(counts_raw, lengths, ploidy, outdir='', alpha=None, seed=0,
struct_draft_fullres=None, draft=False, simple_diploid=False,
callback_freq=None, callback_function=None, reorienter=None,
alpha_true=None, struct_true=None, input_weight=None,
exclude_zeros=False, null=False, mixture_coefs=None, verbose=True):
exclude_zeros=False, null=False, mixture_coefs=None, verbose=True,
restart_struct=None, restart_iter=None):
"""Infer 3D structures with PASTIS via Poisson model.

Optimize 3D structure from Hi-C contact counts data for diploid
Expand Down Expand Up @@ -298,6 +299,13 @@ def infer(counts_raw, lengths, ploidy, outdir='', alpha=None, seed=0,
For diploid organisms: whether this optimization is inferring a "simple
diploid" structure in which homologs are assumed to be identical and
completely overlapping with one another.
restart_struct : str, optional
The structure to restart inference from (used as the initialization).
Name of a text file located in `outdir` that holds the structure; it is
read with `numpy.loadtxt`.
restart_iter : int
The number of iterations to restart from (corresponds to the
restart_struct file).

Returns
-------
Expand Down Expand Up @@ -429,12 +437,18 @@ def infer(counts_raw, lengths, ploidy, outdir='', alpha=None, seed=0,
print(
'INITIALIZATION: initializing with true structure', flush=True)
init = struct_true
struct_init = initialize(
counts=counts, lengths=lengths, init=init, ploidy=ploidy,
random_state=random_state,
alpha=alpha_init if alpha_ is None else alpha_,
bias=bias, multiscale_factor=multiscale_factor, reorienter=reorienter,
mixture_coefs=mixture_coefs, verbose=verbose)

if restart_struct is None:
struct_init = initialize(
counts=counts, lengths=lengths, init=init, ploidy=ploidy,
random_state=random_state,
alpha=alpha_init if alpha_ is None else alpha_,
bias=bias, multiscale_factor=multiscale_factor, reorienter=reorienter,
mixture_coefs=mixture_coefs, verbose=verbose)
else:
if not outdir.endswith("/"):
outdir += "/"
struct_init = np.loadtxt(outdir + restart_struct)

# HOMOLOG-SEPARATING CONSTRAINT
if ploidy == 1 and (hsc_lambda > 0 or mhs_lambda > 0):
Expand Down Expand Up @@ -493,7 +507,8 @@ def infer(counts_raw, lengths, ploidy, outdir='', alpha=None, seed=0,
multiscale_factor=multiscale_factor,
analysis_function=callback_function,
frequency=callback_freq, directory=outdir,
struct_true=struct_true, alpha_true=alpha_true)
struct_true=struct_true, alpha_true=alpha_true,
restart_iter=restart_iter)

# INFER STRUCTURE
pm = PastisPM(counts=counts, lengths=lengths, ploidy=ploidy,
Expand Down Expand Up @@ -605,7 +620,7 @@ def pastis_poisson(counts, lengths, ploidy, outdir='', chromosomes=None,
piecewise_step1_accuracy=1, alpha_true=None,
struct_true=None, init='mds', input_weight=None,
exclude_zeros=False, null=False, mixture_coefs=None,
verbose=True):
verbose=True, restart_struct=None, restart_iter=None):
"""Infer 3D structures with PASTIS via Poisson model.

Infer 3D structure from Hi-C contact counts data for haploid or diploid
Expand Down Expand Up @@ -681,6 +696,16 @@ def pastis_poisson(counts, lengths, ploidy, outdir='', chromosomes=None,
homolog-separating constraint specificying the expected mean inter-
homolog count for each chromosome, scaled by beta and biases. If
not supplied, `mhs_k` will be estimated from the counts data.
save_freq : int
Number of iterations used as an interval to save a checkpoint of the
structure.
restart_struct : str, optional
The structure to restart inference from (used as the initialization).
Name of a text file located in `outdir` that holds the structure; it is
read with `numpy.loadtxt`.
restart_iter : int
The number of iterations to restart from (corresponds to the
restart_struct file).

Returns
-------
Expand Down Expand Up @@ -731,7 +756,8 @@ def pastis_poisson(counts, lengths, ploidy, outdir='', chromosomes=None,
callback_function=callback_function, callback_freq=callback_freq,
alpha_true=alpha_true, struct_true=struct_true,
input_weight=input_weight, exclude_zeros=exclude_zeros,
null=null, mixture_coefs=mixture_coefs, verbose=verbose)
null=null, mixture_coefs=mixture_coefs, verbose=verbose,
restart_struct=restart_struct, restart_iter=restart_iter)
else:
from .piecewise_whole_genome import infer_piecewise

Expand Down
9 changes: 9 additions & 0 deletions pastis/script/pastis-poisson
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ parser.add_argument("--multiscale_rounds", default=1, type=int,
" should be inferred during multiscale optimization."
" Values of 1 or 0 disable multiscale"
" optimization.")
# Checkpoint / restart options. Each one defaults to None when the flag is
# not given on the command line.
parser.add_argument("--save_freq", type=int, default=None,
                    help="The number of iterations used as an interval to save"
                         " a checkpoint of the structure at.")
# The help text previously opened a parenthesis without closing it.
parser.add_argument("--restart_struct", type=str, default=None,
                    help="The structure to restart inference from (used as the"
                         " initialization).")
parser.add_argument("--restart_iter", type=int, default=None,
                    help="The number of iterations to restart from (corresponds"
                         " to the restart_struct file).")

# Optimization convergence
parser.add_argument("--max_iter", default=30000, type=int,
Expand Down