diff --git a/.gitignore b/.gitignore index d7c4706..2337d2a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +*.DS_Store + # Data files data/ # Checkpoints diff --git a/stepback/plotting.py b/stepback/plotting.py index 04b646a..77d3999 100644 --- a/stepback/plotting.py +++ b/stepback/plotting.py @@ -44,34 +44,15 @@ def plot_stability(R, if isinstance(score, str): score = [score] - base_df = R.base_df.copy() - id_df = R.id_df.copy() - - grouped = base_df.groupby(['name', xaxis]) - max_epoch = grouped['epoch'].max() - assert len(max_epoch.unique()) == 1, "It seems that different setups ran for different number of epochs." - - if cutoff is None: - cutoff_epoch = (max_epoch[0], max_epoch[0]) - else: - cutoff_epoch = (cutoff, max_epoch[0]) fig, axs = plt.subplots(len(score),1,figsize=figsize) for j, s in enumerate(score): - # filter epochs - sub_df = base_df[(base_df.epoch >= cutoff_epoch[0]) & (base_df.epoch <= cutoff_epoch[1])] - # select the columns to group by - grouping_cols = [c for c in id_df.columns if c not in ignore_columns] - # group by all id_cols - df = sub_df.groupby(grouping_cols)[[s, s+'_std']].mean() # use dropna=False if we would have nan values - # move xaxis out of grouping - df = df.reset_index(level=xaxis) - # make xaxis float - df[xaxis] = df[xaxis].astype('float') - - # get method and learning rate with best score - # best_ind, best_x = df.index[df[s].argmax()], df[xaxis][df[s].argmax()] + df = R.build_sweep_df(score=s, + xaxis=xaxis, + ignore_columns=ignore_columns, + cutoff=cutoff + ) ax = axs.ravel()[j] if len(score) > 1 else axs # .unique(level=) might be useful at some point diff --git a/stepback/record.py b/stepback/record.py index c943a1d..3d05828 100644 --- a/stepback/record.py +++ b/stepback/record.py @@ -189,6 +189,38 @@ def _build_base_df(self, agg='mean'): return df + def build_sweep_df(self, score='val_score', xaxis='lr', ignore_columns=list(), cutoff=None): + + base_df = self.base_df.copy() + id_df = self.id_df.copy() + + grouped = base_df.groupby(['name', xaxis]) + max_epoch = grouped['epoch'].max() + assert len(max_epoch.unique()) == 1, "It seems that different setups ran for different number of epochs." + + if cutoff is None: + cutoff_epoch = (max_epoch[0], max_epoch[0]) + else: + cutoff_epoch = (cutoff, max_epoch[0]) + + # filter epochs + sub_df = base_df[(base_df.epoch >= cutoff_epoch[0]) + & + (base_df.epoch <= cutoff_epoch[1])] + # select the columns to group by + grouping_cols = [c for c in id_df.columns if c not in ignore_columns] + # group by all id_cols + df = sub_df.groupby(grouping_cols)[[score, score+'_std']].mean() + # move xaxis out of grouping + df = df.reset_index(level=xaxis) + # make xaxis float + df[xaxis] = df[xaxis].astype('float') + + # get method and learning rate with best score + # best_ind, best_x = df.index[df[s].argmax()], df[xaxis][df[s].argmax()] + + return df + #============ DATABASE ================================= #=======================================================