diff --git a/bfast/monitor/python/base.py b/bfast/monitor/python/base.py index 622ec8c..fd598c0 100644 --- a/bfast/monitor/python/base.py +++ b/bfast/monitor/python/base.py @@ -10,7 +10,6 @@ import numpy as np np.warnings.filterwarnings('ignore') np.set_printoptions(suppress=True) -from sklearn import linear_model from bfast.base import BFASTMonitorBase from bfast.monitor.utils import compute_end_history, compute_lam, map_indices @@ -102,7 +101,7 @@ def __init__(self, self._timers = {} self.use_mp = use_mp - def fit(self, data, dates, nan_value=0): + def fit(self, data, dates, nan_value=0, **kwargs): """ Fits the models for the ndarray 'data' Parameters @@ -151,30 +150,14 @@ def fit(self, data, dates, nan_value=0): self.magnitudes = rval[:,:,2].astype(np.float32) self.valids = rval[:,:,3].astype(np.int32) else: - means_global = np.zeros((data.shape[1], data.shape[2]), dtype=np.float32) - magnitudes_global = np.zeros((data.shape[1], data.shape[2]), dtype=np.float32) - breaks_global = np.zeros((data.shape[1], data.shape[2]), dtype=np.int32) - valids_global = np.zeros((data.shape[1], data.shape[2]), dtype=np.int32) - - for i in range(data.shape[1]): - if self.verbose > 0: - print("Processing row {}".format(i)) - - for j in range(data.shape[2]): - y = data[:,i,j] - (pix_break, - pix_mean, - pix_magnitude, - pix_num_valid) = self.fit_single(y) - breaks_global[i,j] = pix_break - means_global[i,j] = pix_mean - magnitudes_global[i,j] = pix_magnitude - valids_global[i,j] = pix_num_valid - - self.breaks = breaks_global - self.means = means_global - self.magnitudes = magnitudes_global - self.valids = valids_global + + rval = np.apply_along_axis(self.fit_single, 0, data) + + #print(rval.shape) + self.breaks = rval[0].astype(np.int32) + self.means = rval[1].astype(np.float32) + self.magnitudes = rval[2].astype(np.float32) + self.valids = rval[3].astype(np.int32) return self @@ -210,7 +193,10 @@ def fit_single(self, y): magnitude = 0.0 if self.verbose > 1: print("WARNING: Not enough observations: ns={ns}, Ns={Ns}".format(ns=ns, Ns=Ns)) - return brk, mean, magnitude, Ns + + rval = np.array([brk, mean, magnitude, Ns]) + + return rval val_inds = val_inds[ns:] val_inds -= self.n @@ -224,30 +210,26 @@ def fit_single(self, y): X_nn_m = X_nn[:, ns:] y_nn_h = y_nn[:ns] y_nn_m = y_nn[ns:] - + # (1) fit linear regression model for history period - model = linear_model.LinearRegression(fit_intercept=False) - model.fit(X_nn_h.T, y_nn_h) + coef = np.linalg.pinv(X_nn_h@X_nn_h.T)@X_nn_h@y_nn_h if self.verbose > 1: - column_names = np.array(["Intercept", - "trend", - "harmonsin1", + column_names = np.array(["harmonsin1", "harmoncos1", "harmonsin2", "harmoncos2", "harmonsin3", "harmoncos3"]) if self.trend: - indxs = np.array([0, 1, 3, 5, 7, 2, 4, 6]) + indxs = np.array([1, 3, 5, 7, 2, 4, 6]) else: - indxs = np.array([0, 2, 4, 6, 1, 3, 5]) - # print(column_names[indxs]) + indxs = np.array([2, 4, 6, 1, 3, 5]) print(column_names[indxs]) - print(model.coef_[indxs]) + print(coef[indxs]) # get predictions for all non-nan points - y_pred = model.predict(X_nn.T) + y_pred = X_nn.T@coef y_error = y_nn - y_pred # (2) evaluate model on monitoring period mosum_nn process @@ -277,14 +259,16 @@ def fit_single(self, y): print("bounds", bounds) breaks = np.abs(mosum) > bounds - first_break = np.where(breaks)[0] + first_break = np.nonzero(breaks)[0] if first_break.shape[0] > 0: - first_break = first_break[0] + first_break = first_break[0].item() else: first_break = -1 - return first_break, mean, magnitude, Ns + rval = np.array([first_break, mean.item(), magnitude.item(), Ns.item()]) + + return rval def get_timers(self): """ Returns runtime measurements for the diff --git a/bfast/monitor/utils.py b/bfast/monitor/utils.py index 6141a50..bde3692 100644 --- a/bfast/monitor/utils.py +++ b/bfast/monitor/utils.py @@ -384,7 +384,7 @@ __critval_h = np.array([0.25, 0.5, 1]) __critval_period = np.arange(2, 12, 2) __critval_level = np.arange(0.95, 0.999, 0.001) -__critval_mr = np.array(["max", "range"]) +__critval_mr = ["max", "range"] def _check_par(val, name, arr, fun=lambda x: x): if not val in arr: @@ -403,13 +403,12 @@ def get_critval(h, period, level, mr): index = np.zeros(4, dtype=np.int) # Get index into table from arguments - index[0] = np.where(mr == __critval_mr)[0][0] - index[1] = np.where(level == __critval_level)[0][0] - # index[2] = np.where(period == __critval_period)[0][0] - # print((np.abs(__critval_period - period)).argmin()) + index[0] = next(i for i, v in enumerate(__critval_mr) if v == mr) + index[1] = np.nonzero(level == __critval_level)[0][0] index[2] = (np.abs(__critval_period - period)).argmin() - index[3] = np.where(h == __critval_h)[0][0] - # For historical reasons, the critvals are scaled by sqrt(2) + index[3] = np.nonzero(h == __critval_h)[0][0] + + # For legacy reasons, the critvals are scaled by sqrt(2) return __critvals[tuple(index)] * np.sqrt(2) def _find_index_date(dates, t):