diff --git a/pyproject.toml b/pyproject.toml index 99901fd..2d93893 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,8 +33,8 @@ training = [ # CASA tools for Measurement Set I/O casa = [ - "python-casacore; sys_platform != 'darwin'", - "casatools; sys_platform == 'darwin'", + "casatools>=6.5.0", + "casatasks>=6.5.0", ] # Development tools diff --git a/rfi_toolbox/__init__.py b/rfi_toolbox/__init__.py index 8b18e7a..af2557f 100644 --- a/rfi_toolbox/__init__.py +++ b/rfi_toolbox/__init__.py @@ -26,12 +26,14 @@ __email__ = "pjaganna@nrao.edu" # Eager imports (required for multiprocessing pickle compatibility) -from . import datasets, preprocessing +import datasets as datasets +import preprocessing as preprocessing # Lazy imports for other modules to avoid circular dependencies def __getattr__(name): """Lazy import for optional modules.""" import importlib + print(f"[DEBUG __getattr__] Lazy loading: {name}") # List of valid lazy-loaded modules valid_modules = { @@ -40,10 +42,12 @@ def __getattr__(name): } if name in valid_modules: + print(f"[DEBUG __getattr__] Calling importlib.import_module for: {name}") # Use importlib to avoid triggering __getattr__ recursion mod = importlib.import_module(f".{name}", __name__) + print(f"[DEBUG __getattr__] Successfully imported: {name}") # Cache in globals to avoid repeated imports globals()[name] = mod return mod - raise AttributeError(f"module '{__name__}' has no attribute '{name}'") \ No newline at end of file + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/rfi_toolbox/datasets/rfi_mask_dataset.py b/rfi_toolbox/datasets/rfi_mask_dataset.py index 1d6c27c..e4d7b7e 100644 --- a/rfi_toolbox/datasets/rfi_mask_dataset.py +++ b/rfi_toolbox/datasets/rfi_mask_dataset.py @@ -15,22 +15,12 @@ from tqdm import tqdm try: - import casacore.tables as ct - - use_casacore = True + from casatools import table CASA_AVAILABLE = True except ImportError: - try: - from casatools import table - - use_casacore = False - CASA_AVAILABLE = True - except ImportError: - # CASA not available - class will fail at instantiation if use_ms=True - ct = None - table = None - use_casacore = False - CASA_AVAILABLE = False + # CASA not available - class will fail at instantiation if use_ms=True + table = None + CASA_AVAILABLE = False class RFIMaskDataset(Dataset): @@ -73,13 +63,11 @@ def __init__( if not CASA_AVAILABLE: raise ImportError( "CASA is required for use_ms=True but is not installed.\n" - "Install with: pip install python-casacore OR pip install casatools" + "Install with: pip install rfi-toolbox[casa]" ) - if use_casacore: - self.tb = ct.table(ms_name, readonly=True) - else: - tb = table() - self.tb = tb.open(ms_name, readonly=True) + tb = table() + tb.open(ms_name) + self.tb = tb self.num_antennas = self.tb.getcol("ANTENNA1").max() + 1 self.spw_array = np.unique(self.tb.getcol("DATA_DESC_ID")) diff --git a/rfi_toolbox/io/__init__.py b/rfi_toolbox/io/__init__.py index 483af1e..ff058c0 100644 --- a/rfi_toolbox/io/__init__.py +++ b/rfi_toolbox/io/__init__.py @@ -4,16 +4,28 @@ Handles CASA measurement set I/O and synthetic data injection. """ +print("[DEBUG io/__init__] Starting io module import") + +print("[DEBUG io/__init__] Attempting to import MSLoader") try: from .ms_loader import MSLoader + + print("[DEBUG io/__init__] MSLoader imported successfully") _HAS_CASA = True -except ImportError: +except ImportError as e: + print(f"[DEBUG io/__init__] MSLoader import failed: {e}") _HAS_CASA = False MSLoader = None +print("[DEBUG io/__init__] Attempting to import inject_synthetic_data") try: from .ms_injection import inject_synthetic_data -except ImportError: + + print("[DEBUG io/__init__] inject_synthetic_data imported successfully") +except ImportError as e: + print(f"[DEBUG io/__init__] inject_synthetic_data import failed: {e}") inject_synthetic_data = None +print("[DEBUG io/__init__] io module import complete") + __all__ = ["MSLoader", "inject_synthetic_data"] diff --git a/rfi_toolbox/io/ms_injection.py b/rfi_toolbox/io/ms_injection.py index 178a018..507a7cc 100644 --- a/rfi_toolbox/io/ms_injection.py +++ b/rfi_toolbox/io/ms_injection.py @@ -11,13 +11,20 @@ import numpy as np from tqdm import tqdm +print("[DEBUG ms_injection] Basic imports complete") + +print("[DEBUG ms_injection] Attempting casatools import") try: from casatools import table + print("[DEBUG ms_injection] casatools.table imported successfully") CASA_AVAILABLE = True -except ImportError: +except ImportError as e: + print(f"[DEBUG ms_injection] casatools import failed: {e}") CASA_AVAILABLE = False +print("[DEBUG ms_injection] ms_injection module import complete") + def inject_synthetic_data( template_ms_path, @@ -44,14 +51,17 @@ def inject_synthetic_data( """ if not CASA_AVAILABLE: raise ImportError( - "casatools is required for MS injection. " "Install with: pip install rfi-toolbox[casa]" + "casatools is required for MS injection. " + "Install with: pip install rfi-toolbox[casa]" ) template_ms_path = Path(template_ms_path) # Default output path if output_ms_path is None: - output_ms_path = template_ms_path.parent / f"{template_ms_path.stem}.synthetic.ms" + output_ms_path = ( + template_ms_path.parent / f"{template_ms_path.stem}.synthetic.ms" + ) else: output_ms_path = Path(output_ms_path) @@ -104,7 +114,10 @@ def inject_synthetic_data( # For simplicity, assume all SPWs have same channel count # and we're filling all SPWs with the same data if len(set(channels_per_spw)) > 1: - print(" WARNING: MS has SPWs with different channel counts. " "Using first SPW only.") + print( + " WARNING: MS has SPWs with different channel counts. " + "Using first SPW only." + ) channels_in_spw = channels_per_spw[0] @@ -130,7 +143,9 @@ def inject_synthetic_data( for spw_idx in range(num_spw): # Query this baseline + SPW - subtable = tb.query(f"DATA_DESC_ID=={spw_idx} && ANTENNA1=={ant1} && ANTENNA2=={ant2}") + subtable = tb.query( + f"DATA_DESC_ID=={spw_idx} && ANTENNA1=={ant1} && ANTENNA2=={ant2}" + ) if subtable.nrows() == 0: print(f" WARNING: No rows for baseline ({ant1},{ant2}), SPW {spw_idx}") @@ -251,7 +266,9 @@ def inject_synthetic_data( subtable.putcell("DATA", row_idx, cell_val) except Exception as e: subtable.close() - raise RuntimeError(f"Failed to write DATA row {row_idx}: {e}") from e + raise RuntimeError( + f"Failed to write DATA row {row_idx}: {e}" + ) from e subtable.close() diff --git a/rfi_toolbox/io/ms_loader.py b/rfi_toolbox/io/ms_loader.py index a8b76dd..8e020bf 100644 --- a/rfi_toolbox/io/ms_loader.py +++ b/rfi_toolbox/io/ms_loader.py @@ -6,10 +6,15 @@ import numpy as np from tqdm import tqdm +import gc +print("[DEBUG ms_loader] Attempting casatools import") try: from casatools import table + + print("[DEBUG ms_loader] casatools.table imported successfully") except Exception as e: + print(f"[DEBUG ms_loader] casatools import failed: {e}") raise ImportError( "MSLoader requires CASA to be properly installed and configured.\n" "Install with: pip install rfi-toolbox[casa]\n" @@ -17,26 +22,34 @@ f"Original error: {e}" ) from e +print("[DEBUG ms_loader] About to define MSLoader class") + class MSLoader: """ Load complex visibilities from CASA measurement sets. Simplified interface: - >>> loader = MSLoader('observation.ms') + >>> loader = MSLoader('observation.ms', field_id=0) >>> loader.load(num_antennas=5, mode='DATA') >>> data = loader.data # Shape: (baselines, pols, channels, times) >>> flags = loader.load_flags() # Load existing flags + + Field handling: + >>> fields = loader.get_available_fields() # Get list of all field IDs + >>> loader.load(field_id=1) # Load specific field """ - def __init__(self, ms_path): + def __init__(self, ms_path, field_id=None): """ Initialize MS loader. Args: ms_path: Path to measurement set + field_id: Optional FIELD_ID to load. If None, loads all fields. """ self.ms_path = str(ms_path) + self.field_id = field_id # Open MS and read metadata tb = table() @@ -57,7 +70,12 @@ def __init__(self, ms_path): self.tb.open(self.ms_path, nomodify=False) # Get number of time samples - subtable = self.tb.query("DATA_DESC_ID==0 && ANTENNA1==0 && ANTENNA2==1") + field_filter = ( + f" && FIELD_ID=={self.field_id}" if self.field_id is not None else "" + ) + subtable = self.tb.query( + f"DATA_DESC_ID==0 && ANTENNA1==0 && ANTENNA2==1{field_filter}" + ) self.num_times = len(subtable.getcol("TIME")) subtable.close() @@ -67,13 +85,68 @@ def __init__(self, ms_path): self.antenna_baseline_map = None self.spw_list = None - def load(self, num_antennas=None, mode="DATA"): + def get_metadata(self, num_antennas=None, mode="DATA"): + """ + Get MS metadata without loading data (fast). + + Args: + num_antennas: Number of antennas (default: all) + mode: Column to get metadata for + + Returns: + dict with keys: num_baselines, num_pols, num_channels, num_times, baseline_map + """ + if num_antennas is None: + num_antennas = self.num_antennas + + # Get shape from dminfo (no data loading) + dminfo = self.tb.getdminfo() + + # Find the storage manager for the DATA column + data_sm = None + for key, info in dminfo.items(): + if mode in info.get("COLUMNS", []): + data_sm = info + break + + if data_sm is None: + raise ValueError(f"Column {mode} not found in MS") + + # Extract shape from first hypercube + hypercubes = data_sm["SPEC"]["HYPERCUBES"] + if hypercubes: + first_cube = list(hypercubes.values())[0] + cell_shape = first_cube["CellShape"] + num_pols, num_channels = cell_shape[0], cell_shape[1] + else: + raise ValueError(f"No hypercube info for {mode}") + + # Build baseline map + baseline_map = [] + for i in range(num_antennas): + for j in range(i + 1, num_antennas): + baseline_map.append((i, j)) + + num_baselines = len(baseline_map) + num_times = self.num_times + + return { + "num_baselines": num_baselines, + "num_pols": num_pols, + "num_channels": num_channels, + "num_times": num_times, + "baseline_map": baseline_map, + "shape": (num_baselines, num_pols, num_channels, num_times), + } + + def load(self, num_antennas=None, mode="DATA", field_id=None): """ Load complex visibilities from MS. Args: num_antennas: Number of antennas to load (default: all) mode: Column to load ('DATA', 'CORRECTED_DATA', etc.) + field_id: Optional FIELD_ID to load. If provided, overrides field_id from __init__. Returns: Loaded data shape: (num_baselines, num_pols, num_channels, num_times) @@ -81,6 +154,10 @@ def load(self, num_antennas=None, mode="DATA"): if num_antennas is None: num_antennas = self.num_antennas + # Allow field_id parameter to override instance field_id + if field_id is not None: + self.field_id = field_id + # Filter to SPWs with same number of channels same_spw_list = [] same_channels_list = [] @@ -100,13 +177,24 @@ def load(self, num_antennas=None, mode="DATA"): print(f"\nLoading {mode} from {self.ms_path}...") print(f" Antennas: {num_antennas}/{self.num_antennas}") - print(f" SPWs: {num_spw} ({num_channels} channels each = {total_channels} total)") + print( + f" SPWs: {num_spw} ({num_channels} channels each = {total_channels} total)" + ) print(f" Times: {self.num_times}") + if self.field_id is not None: + print(f" Field ID: {self.field_id}") + + # Build field filter string for queries + field_filter = ( + f" && FIELD_ID=={self.field_id}" if self.field_id is not None else "" + ) for i in tqdm(range(num_antennas), desc="Antenna 1"): for j in range(i + 1, self.num_antennas): # Allocate array for this baseline - baseline_data = np.zeros([4, total_channels, self.num_times], dtype="complex128") + baseline_data = np.zeros( + [4, total_channels, self.num_times], dtype="complex128" + ) # Check if this baseline has any data has_data = False @@ -114,7 +202,7 @@ def load(self, num_antennas=None, mode="DATA"): # Load all SPWs for this baseline for spw_idx, spw in enumerate(same_spw_list): subtable = self.tb.query( - f"DATA_DESC_ID=={spw} && ANTENNA1=={i} && ANTENNA2=={j}" + f"DATA_DESC_ID=={spw} && ANTENNA1=={i} && ANTENNA2=={j}{field_filter}" ) # Skip if no data for this baseline/SPW @@ -149,7 +237,9 @@ def load(self, num_antennas=None, mode="DATA"): return self.data - def load_single_baseline(self, ant1=0, ant2=1, pol_idx=0, mode="DATA"): + def load_single_baseline( + self, ant1=0, ant2=1, pol_idx=0, mode="DATA", field_id=None + ): """ Load single baseline, single polarization. @@ -158,10 +248,14 @@ def load_single_baseline(self, ant1=0, ant2=1, pol_idx=0, mode="DATA"): ant2: Second antenna pol_idx: Polarization index (0=XX, 1=XY, 2=YX, 3=YY) mode: Column to load ('DATA', 'CORRECTED_DATA', etc.) + field_id: Optional FIELD_ID to load. If provided, overrides field_id from __init__. Returns: Complex array shape: (total_channels, num_times) """ + # Allow field_id parameter to override instance field_id + if field_id is not None: + self.field_id = field_id # Filter to SPWs with same number of channels same_spw_list = [] same_channels_list = [] @@ -177,15 +271,26 @@ def load_single_baseline(self, ant1=0, ant2=1, pol_idx=0, mode="DATA"): print(f"\nLoading single baseline from {self.ms_path}...") print(f" Baseline: {ant1}-{ant2}, Pol: {pol_idx}") - print(f" SPWs: {num_spw} ({num_channels} channels each = {total_channels} total)") + print( + f" SPWs: {num_spw} ({num_channels} channels each = {total_channels} total)" + ) print(f" Times: {self.num_times}") + if self.field_id is not None: + print(f" Field ID: {self.field_id}") + + # Build field filter string for queries + field_filter = ( + f" && FIELD_ID=={self.field_id}" if self.field_id is not None else "" + ) # Allocate array for this baseline baseline_data = np.zeros([total_channels, self.num_times], dtype="complex128") # Load all SPWs for this baseline for spw_idx, spw in enumerate(same_spw_list): - subtable = self.tb.query(f"DATA_DESC_ID=={spw} && ANTENNA1=={ant1} && ANTENNA2=={ant2}") + subtable = self.tb.query( + f"DATA_DESC_ID=={spw} && ANTENNA1=={ant1} && ANTENNA2=={ant2}{field_filter}" + ) if subtable.nrows() == 0: subtable.close() @@ -206,6 +311,129 @@ def load_single_baseline(self, ant1=0, ant2=1, pol_idx=0, mode="DATA"): return baseline_data + def load_baseline(self, ant1, ant2, mode="DATA", field_id=None): + """ + Load one baseline, all pols. Opens/closes table per call. + + Args: + ant1, ant2: Antenna pair + mode: Column ('DATA', 'CORRECTED_DATA', etc.) + field_id: Optional FIELD_ID + + Returns: + Complex array (pols, channels, times) + """ + tb = table() + tb.open(self.ms_path, nomodify=False) + + # Get SPW info + tb_spw = table() + tb_spw.open(self.ms_path + "/SPECTRAL_WINDOW") + channels_per_spw = tb_spw.getcol("NUM_CHAN") + tb_spw.close() + + # Use SPWs with same channel count + same_spw_list = [] + for spw, num_chan in enumerate(channels_per_spw): + if num_chan == channels_per_spw[0]: + same_spw_list.append(spw) + + num_channels = channels_per_spw[0] + total_channels = len(same_spw_list) * num_channels + + # Get num times (query first SPW to get shape) + field_filter = f" && FIELD_ID=={field_id}" if field_id is not None else "" + test_sub = tb.query( + f"DATA_DESC_ID=={same_spw_list[0]} && ANTENNA1=={ant1} && ANTENNA2=={ant2}{field_filter}" + ) + num_times = test_sub.nrows() + test_sub.close() + + # Allocate + baseline_data = np.zeros([4, total_channels, num_times], dtype="complex128") + + # Load each SPW + for spw_idx, spw in enumerate(same_spw_list): + subtable = tb.query( + f"DATA_DESC_ID=={spw} && ANTENNA1=={ant1} && ANTENNA2=={ant2}{field_filter}" + ) + + if subtable.nrows() == 0: + subtable.close() + continue + + spw_data = subtable.getcol(mode) # (pols, channels, times) + + start_ch = spw_idx * num_channels + end_ch = (spw_idx + 1) * num_channels + baseline_data[:, start_ch:end_ch, :] = spw_data + + subtable.close() + + tb.close() + return baseline_data + + def save_baseline_flags(self, ant1, ant2, flags, field_id=None): + """ + Write flags for one baseline. Opens/closes table per call. + + Args: + ant1, ant2: Antenna pair + flags: Boolean array (pols, channels, times) + field_id: Optional FIELD_ID + """ + tb = table() + tb.open(self.ms_path, nomodify=False) + + # Get SPW info + tb_spw = table() + tb_spw.open(self.ms_path + "/SPECTRAL_WINDOW") + channels_per_spw = tb_spw.getcol("NUM_CHAN") + tb_spw.close() + + # Use SPWs with same channel count + same_spw_list = [] + for spw, num_chan in enumerate(channels_per_spw): + if num_chan == channels_per_spw[0]: + same_spw_list.append(spw) + + num_channels = channels_per_spw[0] + + field_filter = f" && FIELD_ID=={field_id}" if field_id is not None else "" + + # Write each SPW + for spw_idx, spw in enumerate(same_spw_list): + start_ch = spw_idx * num_channels + end_ch = (spw_idx + 1) * num_channels + spw_flags = flags[:, start_ch:end_ch, :] + + subtable = tb.query( + f"DATA_DESC_ID=={spw} && ANTENNA1=={ant1} && ANTENNA2=={ant2}{field_filter}" + ) + + if subtable.nrows() > 0: + subtable.putcol("FLAG", spw_flags) + + subtable.close() + + tb.close() + + def get_baseline_pairs(self, num_antennas=None): + """ + Get list of baseline pairs. + + Returns: + List of (ant1, ant2) tuples + """ + if num_antennas is None: + num_antennas = self.num_antennas + + pairs = [] + for i in range(num_antennas): + for j in range(i + 1, num_antennas): + pairs.append((i, j)) + return pairs + def load_flags(self): """ Load existing flags from MS. @@ -217,6 +445,13 @@ def load_flags(self): raise ValueError("Must call load() first to establish baseline map") print("\nLoading flags from MS...") + if self.field_id is not None: + print(f" Field ID: {self.field_id}") + + # Build field filter string for queries + field_filter = ( + f" && FIELD_ID=={self.field_id}" if self.field_id is not None else "" + ) flags_list = [] num_channels = self.channels_per_spw_list[0] @@ -228,7 +463,7 @@ def load_flags(self): for spw_idx, spw in enumerate(self.spw_list): subtable = self.tb.query( - f"DATA_DESC_ID=={spw} && ANTENNA1=={ant1} && ANTENNA2=={ant2}" + f"DATA_DESC_ID=={spw} && ANTENNA1=={ant1} && ANTENNA2=={ant2}{field_filter}" ) spw_flags = subtable.getcol("FLAG") @@ -257,6 +492,13 @@ def save_flags(self, flags): raise ValueError("Must call load() first to establish baseline map") print("\nSaving flags to MS...") + if self.field_id is not None: + print(f" Field ID: {self.field_id}") + + # Build field filter string for queries + field_filter = ( + f" && FIELD_ID=={self.field_id}" if self.field_id is not None else "" + ) num_channels = self.channels_per_spw_list[0] @@ -273,17 +515,32 @@ def save_flags(self, flags): # Write to MS subtable = self.tb.query( - f"DATA_DESC_ID=={spw} && ANTENNA1=={ant1} && ANTENNA2=={ant2}" + f"DATA_DESC_ID=={spw} && ANTENNA1=={ant1} && ANTENNA2=={ant2}{field_filter}" ) subtable.putcol("FLAG", spw_flags) subtable.close() print(" Flags saved successfully") + def get_available_fields(self): + """ + Get list of unique FIELD_IDs present in this measurement set. + + Returns: + list: Sorted list of field IDs + """ + field_ids = np.unique(self.tb.getcol("FIELD_ID")) + return sorted(field_ids.tolist()) + def close(self): """Close the measurement set.""" if hasattr(self, "tb"): self.tb.close() + if hasattr(self, "data"): + del self.data + if hasattr(self, "flags"): + del self.flags + gc.collect() def __del__(self): """Ensure MS is closed on deletion.""" @@ -295,3 +552,6 @@ def magnitude(self): if self.data is None: raise ValueError("Must call load() first") return np.abs(self.data) + + +print("[DEBUG ms_loader] MSLoader class definition complete") diff --git a/rfi_toolbox/scripts/normalize_rfi_data.py b/rfi_toolbox/scripts/normalize_rfi_data.py index 5bf7abb..de2298c 100644 --- a/rfi_toolbox/scripts/normalize_rfi_data.py +++ b/rfi_toolbox/scripts/normalize_rfi_data.py @@ -2,7 +2,6 @@ import argparse import os import numpy as np -from tqdm import tqdm from sklearn.preprocessing import StandardScaler, RobustScaler import shutil diff --git a/rfi_toolbox/scripts/train_model.py b/rfi_toolbox/scripts/train_model.py index 9522daa..a1ca929 100644 --- a/rfi_toolbox/scripts/train_model.py +++ b/rfi_toolbox/scripts/train_model.py @@ -194,5 +194,4 @@ def dice_loss(pred, target, smooth=1.): print(f"Final model saved to {final_model_path}") if __name__ == "__main__": - from datetime import datetime main() diff --git a/rfi_toolbox/visualization/visualize.py b/rfi_toolbox/visualization/visualize.py index a856787..eb9579b 100644 --- a/rfi_toolbox/visualization/visualize.py +++ b/rfi_toolbox/visualization/visualize.py @@ -1,12 +1,11 @@ # rfi_toolbox/visualization/visualize.py import numpy as np import torch -from torch.utils.data import DataLoader from rfi_toolbox.datasets import RFIMaskDataset from rfi_toolbox.models.unet import UNet from bokeh.plotting import figure, show from bokeh.layouts import column, row -from bokeh.models import Slider, ColumnDataSource, Button +from bokeh.models import Slider, ColumnDataSource from bokeh.palettes import Viridis256, Gray256 import random import argparse diff --git a/tests/test_ms_loader_fields.py b/tests/test_ms_loader_fields.py new file mode 100644 index 0000000..ab42f7a --- /dev/null +++ b/tests/test_ms_loader_fields.py @@ -0,0 +1,62 @@ +""" +Unit tests for MSLoader field-based selection functionality. + +Verifies that field_id parameter and get_available_fields() method are properly +integrated into MSLoader without requiring actual MS data. +""" + +import pytest +import inspect + +# Direct import to test actual implementation +try: + from rfi_toolbox.io.ms_loader import MSLoader + CASA_AVAILABLE = True +except ImportError: + CASA_AVAILABLE = False + +pytestmark = pytest.mark.skipif( + not CASA_AVAILABLE, + reason="CASA not available - skipping MSLoader tests" +) + + +class TestMSLoaderFieldFunctionality: + """Verify field-related API exists and is correct.""" + + def test_field_id_in_init(self): + """__init__ should accept optional field_id parameter.""" + sig = inspect.signature(MSLoader.__init__) + assert 'field_id' in sig.parameters + assert sig.parameters['field_id'].default is None + + def test_field_id_in_load(self): + """load() should accept optional field_id parameter.""" + sig = inspect.signature(MSLoader.load) + assert 'field_id' in sig.parameters + assert sig.parameters['field_id'].default is None + + def test_field_id_in_load_single_baseline(self): + """load_single_baseline() should accept optional field_id parameter.""" + sig = inspect.signature(MSLoader.load_single_baseline) + assert 'field_id' in sig.parameters + assert sig.parameters['field_id'].default is None + + def test_get_available_fields_exists(self): + """get_available_fields() method should exist.""" + assert hasattr(MSLoader, 'get_available_fields') + assert callable(MSLoader.get_available_fields) + + def test_docstrings_updated(self): + """Docstrings should mention field functionality.""" + assert MSLoader.__doc__ is not None + assert 'field' in MSLoader.__doc__.lower() + + assert MSLoader.__init__.__doc__ is not None + assert 'field_id' in MSLoader.__init__.__doc__.lower() + + assert MSLoader.load.__doc__ is not None + assert 'field_id' in MSLoader.load.__doc__.lower() + + assert MSLoader.get_available_fields.__doc__ is not None + assert 'field' in MSLoader.get_available_fields.__doc__.lower()