diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 54ff9f88..db2f5a50 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -39,6 +39,7 @@ coordinate system, by [Nils Harmening](https://github.com/harmening). ([#110](ht - Changed the names of several motion correction algorithms from `motion_correct.motion_correct_X` to `motion_correct.X`. Argument names were made PEP8 compliant. The example `22_motion_artefacts_and_correction` was improved. By [Eike Middell](https://github.com/emiddell). - The function `cedalion.vis.anatomy.plot_montage3D` now accepts a `landmarks` parameter to specify which landmarks should be highlighted. Pass `None` (default) to show all available canonical registration landmarks (e.g. Nz, Iz, LPA, RPA, Cz), a list of landmark names to show specific ones, or an empty list to show none, by [Mohammad Orabe](https://github.com/orabe). ([#84](https://github.com/ibs-lab/cedalion/issues/84)) +- Included t-stat thresholding in `cedalion.vis.misc.plot_probe_gui`, by [Shannon Kelley](https://github.com/shankell212). ([#131](https://github.com/ibs-lab/cedalion/pull/131)) ### Deprecated diff --git a/src/cedalion/vis/misc/plot_probe_gui.py b/src/cedalion/vis/misc/plot_probe_gui.py index 30f5a363..7b0bf258 100644 --- a/src/cedalion/vis/misc/plot_probe_gui.py +++ b/src/cedalion/vis/misc/plot_probe_gui.py @@ -16,7 +16,7 @@ import numpy as np from matplotlib.backends.backend_qtagg import FigureCanvas from matplotlib.backends.backend_qtagg import NavigationToolbar2QT as NavigationToolbar -from matplotlib.backends.qt_compat import QtWidgets +from matplotlib.backends.qt_compat import QtWidgets, QtGui from matplotlib.figure import Figure import cedalion @@ -26,10 +26,11 @@ class _MAIN_GUI(QtWidgets.QMainWindow): - def __init__(self, snirfData=None, geo2d=None, geo3d=None): + def __init__(self, snirfData=None, stderr=None, geo2d=None, geo3d=None): # Initialize super().__init__() self.snirfData = snirfData + self.stderr = stderr self.geo2d = geo2d self.geo3d = geo3d @@ -43,7 +44,7 @@ def __init__(self, snirfData=None, geo2d=None, geo3d=None): window_layout.setSpacing(10) # Set Minimum Size - self.setMinimumSize(1000, 850) + self.setMinimumSize(800, 600) # Set Window Title self.setWindowTitle("Plot Probe") @@ -139,6 +140,28 @@ def __init__(self, snirfData=None, geo2d=None, geo3d=None): ## Add Prune Channels control_panel_layout.addWidget(prune_channels, stretch=1) + # Create t-stat thresh control - only if standard error is provided + if self.stderr is not None: + tstat_control = QtWidgets.QGroupBox("T-Stat Threshold") + tstat_control_layout = QtWidgets.QVBoxLayout() + tstat_control_layout.setSpacing(10) + tstat_control.setLayout(tstat_control_layout) + + ## Set up T-stat threshold controller + tstat_threshold_layout = QtWidgets.QHBoxLayout() + self.tstat_threshold = QtWidgets.QDoubleSpinBox() + self.tstat_threshold.setValue(0) # Default: no threshold + self.tstat_threshold.setRange(-100, 100) + self.tstat_threshold.setSingleStep(0.5) + self.tstat_threshold.setDecimals(2) + self.tstat_threshold.valueChanged.connect(self._tstat_threshold_changed) + tstat_threshold_layout.addWidget(QtWidgets.QLabel("Threshold")) + tstat_threshold_layout.addWidget(self.tstat_threshold) + tstat_control_layout.addLayout(tstat_threshold_layout) + + ## Add T-stat control + control_panel_layout.addWidget(tstat_control, stretch=1) + # Create Probe Control probe_control = QtWidgets.QGroupBox("Probe") probe_control_layout = QtWidgets.QVBoxLayout() @@ -186,7 +209,7 @@ def __init__(self, snirfData=None, geo2d=None, geo3d=None): # control_panel_layout.addWidget(ref_point,stretch=1) # Create button action for opening file - open_btn = QtWidgets.QAction("Open...", self) + open_btn = QtGui.QAction("Open...", self) open_btn.setStatusTip("Open SNIRF file") open_btn.triggered.connect(self._open_dialog) @@ -199,9 +222,10 @@ def __init__(self, snirfData=None, geo2d=None, geo3d=None): file_menu.addAction(open_btn) if self.snirfData is not None: + time_dim = 'reltime' if 'reltime' in self.snirfData.dims else 'time' # Detect time dimension if np.shape(self.snirfData)[1] != len(self.snirfData.channel): self.snirfData = self.snirfData.transpose( - "trial_type", "channel", "chromo", "reltime" + "trial_type", "channel", "chromo", time_dim ) self.sPos = self.geo2d.sel( @@ -272,6 +296,25 @@ def _init_calc(self): self.fade_factor = 0.3 ##### Connect? self.lineWidth = 0.7 ##### Connect? + if 'reltime' in self.snirfData.dims: # handle 'time' or 'reltime' + self.time_dim = 'reltime' + elif 'time' in self.snirfData.dims: + self.time_dim = 'time' + else: + raise ValueError("Data must have either 'time' or 'reltime' dimension") + + + # T-stat calculation + self.tstat_thresh = 0 # Initialize threshold + if self.stderr is not None: + # Calculate t-statistic: mean / stderr + self.tstat = self.snirfData / self.stderr + # Get max absolute t-stat per channel across time for thresholding + self.tstat_max = np.abs(self.tstat).max(dim=self.time_dim) # Max across time + else: + self.tstat = None + self.tstat_max = None + self.conditions.clear() self.opt2circ.setChecked(False) self.measline.setChecked(False) @@ -354,14 +397,21 @@ def _init_calc(self): # Extract time information try: - self.t = self.snirfData.time.values - except Exception: - pass - - try: - self.t = self.snirfData.reltime.values + self.t = self.snirfData[self.time_dim].values # CHANGE to use self.time_dim except Exception: - pass + # Fallback to trying both + try: + self.t = self.snirfData.time.values + except Exception: + self.t = self.snirfData.reltime.values + # try: + # self.t = self.snirfData.time.values + # except Exception: + # pass + # try: + # self.t = self.snirfData.reltime.values + # except Exception: + # pass self.minT = min(self.t) self.maxT = max(self.t) @@ -424,37 +474,86 @@ def _init_calc(self): self._draw_hrf() self.conditions.setCurrentRow(0) + # def _change_hrf_vis(self): # orig + # for i_con in range(self.trial_types): + # if i_con == self.conditions.currentRow(): + # for i_ch in range(self.channels): + # if ( + # self.chan_dist[i_ch] >= self.channel_min_dist + # and self.chan_dist[i_ch] <= self.ssFadeThres + # ): + # for i_col in range(self.chromophores): + # self.hrf[ + # i_con * self.channels * self.chromophores + # + i_ch * self.chromophores + # + i_col + # ].set_color(self.chrom[i_col] + [self.fade_factor]) + # elif ( + # self.chan_dist[i_ch] >= self.ssFadeThres + # and self.chan_dist[i_ch] <= self.channel_max_dist + # ): + # for i_col in range(self.chromophores): + # self.hrf[ + # i_con * self.channels * self.chromophores + # + i_ch * self.chromophores + # + i_col + # ].set_color(self.chrom[i_col] + [1]) + # else: + # for i_col in range(self.chromophores): + # self.hrf[ + # i_con * self.channels * self.chromophores + # + i_ch * self.chromophores + # + i_col + # ].set_color(self.chrom[i_col] + [0]) + # else: + # for i_ch in range(self.channels): + # for i_col in range(self.chromophores): + # self.hrf[ + # i_con * self.channels * self.chromophores + # + i_ch * self.chromophores + # + i_col + # ].set_color(self.chrom[i_col] + [0]) + + # self._ax.figure.canvas.draw() + def _change_hrf_vis(self): for i_con in range(self.trial_types): if i_con == self.conditions.currentRow(): for i_ch in range(self.channels): + # Check if channel meets t-stat threshold + meets_tstat = True + if self.tstat_max is not None: + # Check if ANY chromophore meets threshold for this channel + meets_tstat = any( + self.tstat_max.sel(trial_type=self.snirfData.trial_type[i_con]).values[i_ch, i_col] >= self.tstat_thresh + for i_col in range(self.chromophores) + ) + + # Determine alpha based on distance and t-stat if ( self.chan_dist[i_ch] >= self.channel_min_dist and self.chan_dist[i_ch] <= self.ssFadeThres ): - for i_col in range(self.chromophores): - self.hrf[ - i_con * self.channels * self.chromophores - + i_ch * self.chromophores - + i_col - ].set_color(self.chrom[i_col] + [self.fade_factor]) + base_alpha = self.fade_factor elif ( self.chan_dist[i_ch] >= self.ssFadeThres and self.chan_dist[i_ch] <= self.channel_max_dist ): - for i_col in range(self.chromophores): - self.hrf[ - i_con * self.channels * self.chromophores - + i_ch * self.chromophores - + i_col - ].set_color(self.chrom[i_col] + [1]) + base_alpha = 1 else: - for i_col in range(self.chromophores): - self.hrf[ - i_con * self.channels * self.chromophores - + i_ch * self.chromophores - + i_col - ].set_color(self.chrom[i_col] + [0]) + base_alpha = 0 + + # Apply t-stat threshold: fade if doesn't meet threshold + if not meets_tstat and base_alpha > 0: + base_alpha = base_alpha * 0.5 # Further fade channels below threshold + + # Set color for each chromophore + for i_col in range(self.chromophores): + self.hrf[ + i_con * self.channels * self.chromophores + + i_ch * self.chromophores + + i_col + ].set_color(self.chrom[i_col] + [base_alpha]) else: for i_ch in range(self.channels): for i_col in range(self.chromophores): @@ -491,6 +590,8 @@ def _toggle_circles(self): if self.opt2circ.isChecked(): self.src_optodes.set_color([1, 0, 0]) self.det_optodes.set_color([0, 0, 1]) + self.src_optodes.set_markersize(3) # ADD THIS - Make circles smaller + self.det_optodes.set_markersize(3) # ADD THIS - Make circles smaller for idx, source in enumerate(self.sPos.label): self.src_label[idx].set_color([1, 0, 0, 0]) @@ -499,6 +600,8 @@ def _toggle_circles(self): else: self.src_optodes.set_color([1, 0, 0, 0]) self.det_optodes.set_color([0, 0, 1, 0]) + self.src_optodes.set_markersize(5) # ADD THIS - Reset to original size + self.det_optodes.set_markersize(5) # ADD THIS - Reset to original size for idx, source in enumerate(self.sPos.label): self.src_label[idx].set_color([1, 0, 0, 1]) @@ -567,6 +670,11 @@ def _ssfade_changed(self, i): self.ssFadeThres = i self._change_hrf_vis() + def _tstat_threshold_changed(self, i): + # Pass the new t-stat threshold and update HRF visibility + self.tstat_thresh = i + self._change_hrf_vis() + def _draw_hrf(self): print("Plotting Optodes!") t0 = time.time() @@ -656,6 +764,7 @@ def run_vis( blockaverage: cdt.NDTimeSeries, geo2d: cdt.LabeledPoints, geo3d: cdt.LabeledPoints, + stderr: cdt.NDTimeSeries = None, # optional standerr input ): """Opens the visualization GUI. @@ -666,6 +775,7 @@ def run_vis( """ app = QtWidgets.QApplication(sys.argv) - main_gui = _MAIN_GUI(snirfData=blockaverage, geo2d=geo2d, geo3d=geo3d) + #main_gui = _MAIN_GUI(snirfData=blockaverage, geo2d=geo2d, geo3d=geo3d) + main_gui = _MAIN_GUI(snirfData=blockaverage, stderr=stderr, geo2d=geo2d, geo3d=geo3d) main_gui.show() sys.exit(app.exec()) diff --git a/src/cedalion/vis/misc/time_series_gui.py b/src/cedalion/vis/misc/time_series_gui.py index 9e9f1f5d..a04e5405 100644 --- a/src/cedalion/vis/misc/time_series_gui.py +++ b/src/cedalion/vis/misc/time_series_gui.py @@ -12,7 +12,7 @@ from matplotlib.backends.backend_qtagg import FigureCanvas from matplotlib.backends.backend_qtagg import NavigationToolbar2QT as NavigationToolbar -from matplotlib.backends.qt_compat import QtCore, QtWidgets +from matplotlib.backends.qt_compat import QtCore, QtWidgets, QtGui from matplotlib.figure import Figure import cedalion @@ -25,12 +25,12 @@ class _MAIN_GUI(QtWidgets.QMainWindow): def __init__(self, snirfRec=None): # Initialize super().__init__() - + # Check what type of data passed in - if type(snirfRec) == cdc.recording.Recording: + if isinstance(snirfRec, cdc.recording.Recording): self.snirfRec = snirfRec self.oftype = "rec" - elif type(snirfRec) == dict: + elif isinstance(snirfRec, dict): self.cfg_dataset = snirfRec["cfg_dataset"] self.recs = snirfRec["rec"] self.i_subj = 0 @@ -39,9 +39,9 @@ def __init__(self, snirfRec=None): self.oftype = "pkl" else: raise Exception("Unexpected format passed!") - + self._UI_SETUP() - + def _UI_SETUP(self): # Set central widget self._main = QtWidgets.QWidget() @@ -97,14 +97,14 @@ def _UI_SETUP(self): control_panel_layout.setSpacing(20) control_panel.setLayout(control_panel_layout) window_layout.addWidget(control_panel, stretch=1) - + # Create File Control Layout file_layout = QtWidgets.QGridLayout() file_layout.setAlignment(QtCore.Qt.AlignTop) control_panel_layout.addLayout( file_layout, ) - + ## Subject Selector self.subj = QtWidgets.QComboBox() if self.oftype == "rec": @@ -116,7 +116,7 @@ def _UI_SETUP(self): self.subj.currentTextChanged.connect(self._subj_changed) file_layout.addWidget(QtWidgets.QLabel("Subject:"), 0, 0) file_layout.addWidget(self.subj, 0, 1) - + ## Run Selector self.run = QtWidgets.QComboBox() if self.oftype == "rec": @@ -190,7 +190,7 @@ def _UI_SETUP(self): control_panel_layout.addStretch() # Create button action for opening file - open_btn = QtWidgets.QAction("Open...", self) + open_btn = QtGui.QAction("Open...", self) open_btn.setStatusTip("Open SNIRF file") open_btn.triggered.connect(self._open_dialog) @@ -424,20 +424,20 @@ def _toggle_circles(self): def _toggle_stims(self, s): self.plot_stims = s self._draw_timeseries() - + def _subj_changed(self, s): # TODO if s == "None": return - + self.i_subj = self.cfg_dataset["subj_ids"].index(s) self.snirfRec = self.recs[self.i_subj][self.i_run] self._dataTimeSeries_ax.clear() self._init_calc() - + def _run_changed(self, s): # TODO if s == "None": return - + self.i_run = self.cfg_dataset["file_ids"].index(s) self.snirfRec = self.recs[self.i_subj][self.i_run] self._dataTimeSeries_ax.clear()