diff --git a/pastis/optimization/callbacks.py b/pastis/optimization/callbacks.py index cc43cb229..36bd2baef 100644 --- a/pastis/optimization/callbacks.py +++ b/pastis/optimization/callbacks.py @@ -52,6 +52,8 @@ class Callback(object): True alpha, to be used by `analysis_function`. verbose : bool, optional Verbosity. + restart_iter : int + The number of iterations to restart from. Attributes ---------- @@ -124,7 +126,8 @@ def __init__(self, lengths, ploidy, counts=None, multiscale_factor=1, history=None, analysis_function=None, frequency=None, on_training_begin=None, on_training_end=None, on_epoch_end=None, directory=None, struct_true=None, - alpha_true=None, verbose=False): + alpha_true=None, verbose=False, restart_iter=None): + self.restart_iter = restart_iter self.ploidy = ploidy self.multiscale_factor = multiscale_factor self.lengths = decrease_lengths_res(lengths, multiscale_factor) @@ -171,6 +174,8 @@ def __init__(self, lengths, ploidy, counts=None, multiscale_factor=1, self.opt_type = None self.alpha_loop = None self.epoch = -1 + if self.restart_iter is not None: + self.epoch = self.restart_iter - 1 self.time = '0:00:00.0' self.structures = None self.alpha = None @@ -195,7 +200,6 @@ def _check_frequency(self, frequency, last_epoch=False): def _print(self, last_epoch=False): """Prints loss every given number of epochs.""" - if self._check_frequency(self.frequency['print'], last_epoch): info_dict = {'At iterate': ' ' * (6 - len(str(self.epoch))) + str( self.epoch), 'f= ': '%.6g' % self.obj['obj'], @@ -273,6 +277,8 @@ def on_training_begin(self, opt_type=None, alpha_loop=None): self.opt_type = opt_type self.alpha_loop = alpha_loop self.epoch = -1 + if self.restart_iter is not None: + self.epoch = self.restart_iter - 1 self.seconds = 0 self.time = '0:00:00.0' self.structures = None diff --git a/pastis/optimization/pastis_algorithms.py b/pastis/optimization/pastis_algorithms.py index 48a1cd8c9..33d895245 100644 --- a/pastis/optimization/pastis_algorithms.py +++ b/pastis/optimization/pastis_algorithms.py @@ -205,7 +205,8 @@ def infer(counts_raw, lengths, ploidy, outdir='', alpha=None, seed=0, struct_draft_fullres=None, draft=False, simple_diploid=False, callback_freq=None, callback_function=None, reorienter=None, alpha_true=None, struct_true=None, input_weight=None, - exclude_zeros=False, null=False, mixture_coefs=None, verbose=True): + exclude_zeros=False, null=False, mixture_coefs=None, verbose=True, + restart_struct=None, restart_iter=None): """Infer 3D structures with PASTIS via Poisson model. Optimize 3D structure from Hi-C contact counts data for diploid @@ -298,6 +299,13 @@ def infer(counts_raw, lengths, ploidy, outdir='', alpha=None, seed=0, For diploid organisms: whether this optimization is inferring a "simple diploid" structure in which homologs are assumed to be identical and completely overlapping with one another. + restart_struct : str + The structure to restart inference from (used as the initialization). + Number of beads per homolog of each chromosome, or hiclib .bed file with + lengths data. + restart_iter : int + The number of iterations to restart from (corresponds to the + restart_struct file). Returns ------- @@ -429,12 +437,18 @@ def infer(counts_raw, lengths, ploidy, outdir='', alpha=None, seed=0, print( 'INITIALIZATION: initializing with true structure', flush=True) init = struct_true - struct_init = initialize( - counts=counts, lengths=lengths, init=init, ploidy=ploidy, - random_state=random_state, - alpha=alpha_init if alpha_ is None else alpha_, - bias=bias, multiscale_factor=multiscale_factor, reorienter=reorienter, - mixture_coefs=mixture_coefs, verbose=verbose) + + if restart_struct is None: + struct_init = initialize( + counts=counts, lengths=lengths, init=init, ploidy=ploidy, + random_state=random_state, + alpha=alpha_init if alpha_ is None else alpha_, + bias=bias, multiscale_factor=multiscale_factor, reorienter=reorienter, + mixture_coefs=mixture_coefs, verbose=verbose) + else: + if not outdir.endswith("/"): + outdir += "/" + struct_init = np.loadtxt(outdir + restart_struct) # HOMOLOG-SEPARATING CONSTRAINT if ploidy == 1 and (hsc_lambda > 0 or mhs_lambda > 0): @@ -493,7 +507,8 @@ def infer(counts_raw, lengths, ploidy, outdir='', alpha=None, seed=0, multiscale_factor=multiscale_factor, analysis_function=callback_function, frequency=callback_freq, directory=outdir, - struct_true=struct_true, alpha_true=alpha_true) + struct_true=struct_true, alpha_true=alpha_true, + restart_iter=restart_iter) # INFER STRUCTURE pm = PastisPM(counts=counts, lengths=lengths, ploidy=ploidy, @@ -605,7 +620,7 @@ def pastis_poisson(counts, lengths, ploidy, outdir='', chromosomes=None, piecewise_step1_accuracy=1, alpha_true=None, struct_true=None, init='mds', input_weight=None, exclude_zeros=False, null=False, mixture_coefs=None, - verbose=True): + verbose=True, restart_struct=None, restart_iter=None): """Infer 3D structures with PASTIS via Poisson model. Infer 3D structure from Hi-C contact counts data for haploid or diploid @@ -681,6 +696,16 @@ def pastis_poisson(counts, lengths, ploidy, outdir='', chromosomes=None, homolog-separating constraint specificying the expected mean inter- homolog count for each chromosome, scaled by beta and biases. If not supplied, `mhs_k` will be estimated from the counts data. + save_freq : int + Number of iterations used as an interval to save a checkpoint of the + structure. + restart_struct : str + The structure to restart inference from (used as the initialization). + Number of beads per homolog of each chromosome, or hiclib .bed file with + lengths data. + restart_iter : int + The number of iterations to restart from (corresponds to the + restart_struct file). Returns ------- @@ -731,7 +756,8 @@ def pastis_poisson(counts, lengths, ploidy, outdir='', chromosomes=None, callback_function=callback_function, callback_freq=callback_freq, alpha_true=alpha_true, struct_true=struct_true, input_weight=input_weight, exclude_zeros=exclude_zeros, - null=null, mixture_coefs=mixture_coefs, verbose=verbose) + null=null, mixture_coefs=mixture_coefs, verbose=verbose, + restart_struct=restart_struct, restart_iter=restart_iter) else: from .piecewise_whole_genome import infer_piecewise diff --git a/pastis/script/pastis-poisson b/pastis/script/pastis-poisson index 7017bb6db..f4cc9439d 100644 --- a/pastis/script/pastis-poisson +++ b/pastis/script/pastis-poisson @@ -51,6 +51,15 @@ parser.add_argument("--multiscale_rounds", default=1, type=int, " should be inferred during multiscale optimization." " Values of 1 or 0 disable multiscale" " optimization.") +parser.add_argument("--save_freq", type=int, default=None, + help="The number of iterations used as an interval to save" + " a checkpoint of the structure at.") +parser.add_argument("--restart_struct", type=str, default=None, + help="The structure to restart inference from (used as the" + " initialization.") +parser.add_argument("--restart_iter", type=int, default=None, + help="The number of iterations to restart from (corresponds" + " to the restart_struct file).") # Optimization convergence parser.add_argument("--max_iter", default=30000, type=int,