From 1f546db6120af8ab3a04ca0baad7c347097e17ed Mon Sep 17 00:00:00 2001 From: daichitakahashi Date: Mon, 16 Mar 2026 09:42:49 +0900 Subject: [PATCH] =?UTF-8?q?feat(analysisrun):=20preprocess=E3=81=AE?= =?UTF-8?q?=E5=BD=B9=E5=89=B2=E3=82=92=E6=95=B4=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/analysisrun/__init__.py | 2 + src/analysisrun/pipeable.py | 248 +++++++++++++++++++++--------------- tests/test_pipeable.py | 46 ++++--- 3 files changed, 178 insertions(+), 118 deletions(-) diff --git a/src/analysisrun/__init__.py b/src/analysisrun/__init__.py index 73c8b14..859efa6 100644 --- a/src/analysisrun/__init__.py +++ b/src/analysisrun/__init__.py @@ -5,6 +5,7 @@ Fields, ManualInput, PreprocessArgs, + ProcessedInputs, PostprocessArgs, PostprocessArgsWithPreprocess, VirtualFile, @@ -20,6 +21,7 @@ "Fields", "ManualInput", "PreprocessArgs", + "ProcessedInputs", "PostprocessArgs", "PostprocessArgsWithPreprocess", "VirtualFile", diff --git a/src/analysisrun/pipeable.py b/src/analysisrun/pipeable.py index c5238ae..110f01f 100644 --- a/src/analysisrun/pipeable.py +++ b/src/analysisrun/pipeable.py @@ -10,6 +10,7 @@ from io import BytesIO from pathlib import Path from threading import Lock, Thread +from types import MappingProxyType from typing import ( IO, Any, @@ -17,6 +18,7 @@ Iterable, Literal, LiteralString, + Mapping, Optional, Protocol, Type, @@ -39,7 +41,7 @@ exit_with_error_streaming, redirect_stdout_to_stderr, ) -from analysisrun.scanner import Fields, Lanes +from analysisrun.scanner import Fields, scan from analysisrun.tar import FileIO, create_tar_from_dict, read_tar_as_dict @@ -279,21 +281,33 @@ class PreprocessArgs[Params: BaseModel]: """ 解析全体に関わるパラメータ """ - image_analysis_results: dict[str, pd.DataFrame] + image_analysis_results: Mapping[str, pd.DataFrame] """ - 画像解析結果(DataFrame) + cleansing済みの画像解析結果(DataFrame) """ - targets: dict[str, str] + targets: Mapping[str, str] """ 解析対象データ。keyはdata_name、valueはsample_name """ +@dataclass +class ProcessedInputs[Extra]: + image_analysis_results: Mapping[str, pd.DataFrame] + """ + preprocess後にanalyzeへ渡されるcleansing済み画像解析結果(DataFrame) + """ + extra: Extra + """ + preprocessで生成された追加データ + """ + + @dataclass class AnalyzeArgsWithPreprocess[ Params: BaseModel, ImageAnalysisResults: NamedTupleLike[Fields], - PreprocessedData, + Extra, ]: params: Params """ @@ -315,9 +329,9 @@ class AnalyzeArgsWithPreprocess[ """ 画像を保存するためのOutput実装 """ - preprocessed_data: PreprocessedData + extra: Extra """ - preprocessで生成された前処理済みデータ + preprocessで生成された追加データ """ @@ -334,7 +348,7 @@ class PostprocessArgs[Params: BaseModel]: @dataclass -class PostprocessArgsWithPreprocess[Params: BaseModel, PreprocessedData]: +class PostprocessArgsWithPreprocess[Params: BaseModel, Extra]: params: Params """ 解析全体に関わるパラメータ @@ -343,9 +357,9 @@ class PostprocessArgsWithPreprocess[Params: BaseModel, PreprocessedData]: """ 各レーンの解析結果を格納したDataFrame """ - preprocessed_data: PreprocessedData + extra: Extra """ - preprocessの結果 + preprocessで生成された追加データ """ @@ -370,7 +384,7 @@ class _AnalyzeSeqState[ParamsT: BaseModel, ImageInputModelT: BaseModel](_BaseSta @dataclass class _SequentialState(_BaseState): raw_data: dict[str, pd.DataFrame] - cleansed_data: dict[str, CleansedData] + cleansed_data: dict[str, pd.DataFrame] sample_pairs: list[tuple[str, str]] field_numbers: list[int] @@ -378,6 +392,7 @@ class _SequentialState(_BaseState): @dataclass class _ParallelState(_BaseState): raw_data: dict[str, pd.DataFrame] + cleansed_data: dict[str, pd.DataFrame] sample_pairs: list[tuple[str, str]] output_dir: Path entrypoint: Path @@ -387,6 +402,7 @@ class _ParallelState(_BaseState): @dataclass class _ParallelStreamingState(_BaseState): raw_data: dict[str, pd.DataFrame] + cleansed_data: dict[str, pd.DataFrame] sample_pairs: list[tuple[str, str]] entrypoint: Path field_numbers: list[int] @@ -474,26 +490,26 @@ def run_analysis( ) def run_analysis_with_preprocess[ - PreprocessedData, + Extra, ]( self, preprocess: Callable[ [PreprocessArgs[Params]], - PreprocessedData, + ProcessedInputs[Extra], ], analyze: Callable[ [ AnalyzeArgsWithPreprocess[ Params, ImageAnalysisResults, - PreprocessedData, + Extra, ] ], pd.Series, ], postprocess: Optional[ Callable[ - [PostprocessArgsWithPreprocess[Params, PreprocessedData]], + [PostprocessArgsWithPreprocess[Params, Extra]], pd.DataFrame, ] ] = None, @@ -599,7 +615,7 @@ def _run_analyze_seq( with redirect_stdout_to_stderr(stderr): series = analyze( AnalyzeArgs( - self.params, + _copy_params(self.params), data_name, sample_name, lanes, @@ -631,19 +647,19 @@ def _run_analyze_seq( sys.exit(0) def _run_analyze_seq_with_preprocess[ - PreprocessedData, + Extra, ]( self, preprocess: Callable[ [PreprocessArgs[Params]], - PreprocessedData, + ProcessedInputs[Extra], ], analyze: Callable[ [ AnalyzeArgsWithPreprocess[ Params, ImageAnalysisResults, - PreprocessedData, + Extra, ] ], pd.Series, @@ -656,30 +672,15 @@ def _run_analyze_seq_with_preprocess[ parsed: AnalyzeSeqInputModel = state.parsed_input field_numbers = state.field_numbers - specs = _get_image_analysis_specs(self.image_analysis_results) - raw_data = _load_image_results_raw( - parsed.image_analysis_results, - serialization="pickle", - ) - cleansed_data = _load_and_cleanse_image_results( - parsed.image_analysis_results, - specs, - serialization="pickle", - ) - with redirect_stdout_to_stderr(stderr): - if parsed.preprocessed_data: - preprocessed_data = _deserialize_preprocessed_data( - parsed.preprocessed_data.unwrap() - ) - else: - preprocessed_data = preprocess( - PreprocessArgs( - self.params, - raw_data, - dict(parsed.targets.items()), - ) - ) + if parsed.preprocessed_data is None: + raise ValueError("preprocessed_data is required in analyzeseq mode") + extra = _deserialize_preprocessed_data(parsed.preprocessed_data.unwrap()) + preprocessed_data = _load_image_results_raw( + parsed.image_analysis_results, + serialization="pickle", + ) + cleansed_data = _wrap_processed_image_results(preprocessed_data) with tarfile.open(fileobj=stdout, mode="w|") as tar: for data_name, sample_name in parsed.targets.items(): @@ -695,12 +696,12 @@ def _run_analyze_seq_with_preprocess[ with redirect_stdout_to_stderr(stderr): series = analyze( AnalyzeArgsWithPreprocess( - self.params, + _copy_params(self.params), data_name, sample_name, lanes, output, - preprocessed_data, + extra, ) ) _ensure_result_annotations(series, data_name, sample_name) @@ -722,14 +723,6 @@ def _run_analyze_seq_with_preprocess[ tar_info.pax_headers = {"is_file": "true"} tar.addfile(tar_info, BytesIO(result_data)) - preprocessed = _serialize_preprocessed_data(preprocessed_data) - preprocessed.seek(0) - preprocessed_bytes = preprocessed.read() - preprocessed_info = tarfile.TarInfo(name="preprocessed_data") - preprocessed_info.size = len(preprocessed_bytes) - preprocessed_info.pax_headers = {"is_file": "true"} - tar.addfile(preprocessed_info, BytesIO(preprocessed_bytes)) - stdout.flush() sys.exit(0) @@ -753,57 +746,67 @@ def _run_sequential( field_numbers, ) series = analyze( - AnalyzeArgs(self.params, data_name, sample_name, lanes, self.output) + AnalyzeArgs( + _copy_params(self.params), + data_name, + sample_name, + lanes, + self.output, + ) ) _ensure_result_annotations(series, data_name, sample_name) results.append(series) result_df = pd.DataFrame(results) if postprocess: - postprocessed = postprocess(PostprocessArgs(self.params, result_df)) + postprocessed = postprocess( + PostprocessArgs(_copy_params(self.params), result_df) + ) if postprocessed is not None: result_df = postprocessed return result_df def _run_sequential_with_preprocess[ - PreprocessedData, + Extra, ]( self, preprocess: Callable[ [PreprocessArgs[Params]], - PreprocessedData, + ProcessedInputs[Extra], ], analyze: Callable[ [ AnalyzeArgsWithPreprocess[ Params, ImageAnalysisResults, - PreprocessedData, + Extra, ] ], pd.Series, ], postprocess: Optional[ Callable[ - [PostprocessArgsWithPreprocess[Params, PreprocessedData]], + [PostprocessArgsWithPreprocess[Params, Extra]], pd.DataFrame, ] ], ) -> pd.DataFrame: assert isinstance(self.state, _SequentialState) state = self.state - raw_data = state.raw_data cleansed_data = state.cleansed_data sample_pairs = state.sample_pairs field_numbers = state.field_numbers - preprocessed_data = preprocess( - PreprocessArgs( + processed_inputs = preprocess( + _build_preprocess_args( self.params, - raw_data, + cleansed_data, {data_name: sample_name for data_name, sample_name in sample_pairs}, ) ) + cleansed_data = _wrap_processed_image_results( + processed_inputs.image_analysis_results, + ) results: list[pd.Series] = [] for data_name, sample_name in sample_pairs: @@ -815,12 +818,12 @@ def _run_sequential_with_preprocess[ ) series = analyze( AnalyzeArgsWithPreprocess( - self.params, + _copy_params(self.params), data_name, sample_name, lanes, self.output, - preprocessed_data, + processed_inputs.extra, ) ) _ensure_result_annotations(series, data_name, sample_name) @@ -829,7 +832,11 @@ def _run_sequential_with_preprocess[ result_df = pd.DataFrame(results) if postprocess: postprocessed = postprocess( - PostprocessArgsWithPreprocess(self.params, result_df, preprocessed_data) + PostprocessArgsWithPreprocess( + _copy_params(self.params), + result_df, + processed_inputs.extra, + ) ) if postprocessed is not None: result_df = postprocessed @@ -869,22 +876,24 @@ def _save_streamed_image( _save_images_to_dir(flat_images, output_dir) if postprocess: - postprocessed = postprocess(PostprocessArgs(self.params, result_df)) + postprocessed = postprocess( + PostprocessArgs(_copy_params(self.params), result_df) + ) if postprocessed is not None: result_df = postprocessed return result_df def _run_parallel_with_preprocess[ - PreprocessedData, + Extra, ]( self, preprocess: Callable[ [PreprocessArgs[Params]], - PreprocessedData, + ProcessedInputs[Extra], ], postprocess: Optional[ Callable[ - [PostprocessArgsWithPreprocess[Params, PreprocessedData]], + [PostprocessArgsWithPreprocess[Params, Extra]], pd.DataFrame, ] ], @@ -905,10 +914,10 @@ def _save_streamed_image( with image_write_lock: _save_image_bytes_to_dir(image_name, image_bytes, output_dir) - preprocessed_data = preprocess( - PreprocessArgs( + processed_inputs = preprocess( + _build_preprocess_args( self.params, - state.raw_data, + state.cleansed_data, { data_name: sample_name for data_name, sample_name in state.sample_pairs @@ -916,12 +925,13 @@ def _save_streamed_image( ) ) result_df, images_by_data = self._run_parallel_entrypoint( - raw_data=state.raw_data, + raw_data=dict(processed_inputs.image_analysis_results.items()), sample_pairs=state.sample_pairs, entrypoint=state.entrypoint, stdout=state.stdout, stderr=state.stderr, - preprocessed_data=preprocessed_data, + preprocessed_data=processed_inputs.extra, + include_preprocessed_data=True, output_dir=output_dir, on_image=_save_streamed_image, ) @@ -932,9 +942,9 @@ def _save_streamed_image( if postprocess: postprocessed = postprocess( PostprocessArgsWithPreprocess( - self.params, + _copy_params(self.params), result_df, - preprocessed_data, + processed_inputs.extra, ) ) if postprocessed is not None: @@ -973,6 +983,7 @@ def _stream_image( sample_pairs=state.sample_pairs, entrypoint=state.entrypoint, preprocessed_data=None, + include_preprocessed_data=False, on_image=_stream_image, ) if errors: @@ -991,7 +1002,7 @@ def _stream_image( if postprocess: with redirect_stdout_to_stderr(stderr): postprocessed = postprocess( - PostprocessArgs(self.params, result_df) + PostprocessArgs(_copy_params(self.params), result_df) ) if postprocessed is not None: result_df = postprocessed @@ -1012,16 +1023,16 @@ def _stream_image( sys.exit(0) def _run_parallel_streaming_with_preprocess[ - PreprocessedData, + Extra, ]( self, preprocess: Callable[ [PreprocessArgs[Params]], - PreprocessedData, + ProcessedInputs[Extra], ], postprocess: Optional[ Callable[ - [PostprocessArgsWithPreprocess[Params, PreprocessedData]], + [PostprocessArgsWithPreprocess[Params, Extra]], pd.DataFrame, ] ], @@ -1050,10 +1061,10 @@ def _stream_image( try: with redirect_stdout_to_stderr(stderr): - preprocessed_data = preprocess( - PreprocessArgs( + processed_inputs = preprocess( + _build_preprocess_args( self.params, - state.raw_data, + state.cleansed_data, { data_name: sample_name for data_name, sample_name in state.sample_pairs @@ -1062,10 +1073,11 @@ def _stream_image( ) result_df, errors = self._run_parallel_entrypoint_streaming( - raw_data=state.raw_data, + raw_data=dict(processed_inputs.image_analysis_results.items()), sample_pairs=state.sample_pairs, entrypoint=state.entrypoint, - preprocessed_data=preprocessed_data, + preprocessed_data=processed_inputs.extra, + include_preprocessed_data=True, on_image=_stream_image, ) if errors: @@ -1085,9 +1097,9 @@ def _stream_image( with redirect_stdout_to_stderr(stderr): postprocessed = postprocess( PostprocessArgsWithPreprocess( - self.params, + _copy_params(self.params), result_df, - preprocessed_data, + processed_inputs.extra, ) ) if postprocessed is not None: @@ -1114,6 +1126,7 @@ def _run_parallel_entrypoint_streaming( sample_pairs: list[tuple[str, str]], entrypoint: Path, preprocessed_data: Any | None, + include_preprocessed_data: bool, on_image: Callable[[str, str, bytes, Optional[str]], None], ) -> tuple[pd.DataFrame, list[tuple[str, str, str, str]]]: if not sample_pairs: @@ -1132,6 +1145,7 @@ def _run_chunk( targets, raw_data, preprocessed_data=preprocessed_data, + include_preprocessed_data=include_preprocessed_data, ) result = _run_entrypoint_with_tar_streaming( entrypoint, @@ -1202,6 +1216,7 @@ def _run_parallel_entrypoint( stdout: IO[bytes], stderr: IO[bytes], preprocessed_data: Any | None = None, + include_preprocessed_data: bool = False, output_dir: Path | None = None, on_image: Optional[Callable[[str, str, bytes, Optional[str]], None]] = None, ) -> tuple[pd.DataFrame, dict[str, dict[str, BytesIO]]]: @@ -1225,6 +1240,7 @@ def _run_chunk( targets, raw_data, preprocessed_data=preprocessed_data, + include_preprocessed_data=include_preprocessed_data, ) if on_image is not None: result = _run_entrypoint_with_tar_streaming( @@ -1562,15 +1578,16 @@ def read_context[ _stderr, ) + cleansed_data = { + name: _apply_cleansing_pipeline(df, specs[name]) + for name, df in raw_data.items() + } + if mode == "sequential": if output_dir_path is None: output_dir_path = _derive_output_dir(runtime_input.image_analysis_results) output_dir_path.mkdir(parents=True, exist_ok=True) output_impl = _FileOutput(output_dir_path) - cleansed_data = { - name: _apply_cleansing_pipeline(df, specs[name]) - for name, df in raw_data.items() - } return AnalysisContext[Params, ImageAnalysisResults]( params=runtime_input.params, image_analysis_results=image_analysis_results, @@ -1607,6 +1624,7 @@ def read_context[ stdout=_stdout, stderr=_stderr, raw_data=raw_data, + cleansed_data=cleansed_data, sample_pairs=sample_pairs, output_dir=output_dir_path, entrypoint=entrypoint, @@ -1622,6 +1640,7 @@ def read_context[ stdout=_stdout, stderr=_stderr, raw_data=raw_data, + cleansed_data=cleansed_data, sample_pairs=sample_pairs, entrypoint=entrypoint, field_numbers=field_numbers, @@ -1699,17 +1718,34 @@ def _load_image_results_raw( return raw +def _build_preprocess_args[Params: BaseModel]( + params: Params, + cleansed_data: Mapping[str, pd.DataFrame], + targets: dict[str, str], +) -> PreprocessArgs[Params]: + copied = {name: df.copy() for name, df in cleansed_data.items()} + return PreprocessArgs( + params=_copy_params(params), + image_analysis_results=MappingProxyType(copied), + targets=targets, + ) + + +def _copy_params[Params: BaseModel](params: Params) -> Params: + return params.model_copy(deep=True) + + def _apply_cleansing_pipeline( - data: pd.DataFrame | CleansedData, spec: _ImageAnalysisResultSpec -) -> CleansedData: + data: pd.DataFrame, spec: _ImageAnalysisResultSpec +) -> pd.DataFrame: cleansed: pd.DataFrame | CleansedData = data for fn in spec.cleansing: cleansed = fn(cleansed) if isinstance(cleansed, CleansedData): - return cleansed + return cleansed._data if isinstance(cleansed, pd.DataFrame): - return CleansedData(_data=cleansed) + return cleansed def _load_and_cleanse_image_results( @@ -1717,8 +1753,8 @@ def _load_and_cleanse_image_results( specs: dict[str, _ImageAnalysisResultSpec], *, serialization: _ImageResultsSerialization, -) -> dict[str, CleansedData]: - cleansed_data: dict[str, CleansedData] = {} +) -> dict[str, pd.DataFrame]: + cleansed_data: dict[str, pd.DataFrame] = {} raw_data = _load_image_results_raw( image_analysis_results_model, serialization=serialization, @@ -1729,18 +1765,24 @@ def _load_and_cleanse_image_results( return cleansed_data +def _wrap_processed_image_results( + image_analysis_results: Mapping[str, pd.DataFrame], +) -> dict[str, pd.DataFrame]: + return dict(image_analysis_results.items()) + + def _build_fields_namedtuple[ ImageAnalysisResults: NamedTupleLike[Fields], ]( image_analysis_results_type: Type[ImageAnalysisResults], - cleansed_data: dict[str, CleansedData], + cleansed_data: dict[str, pd.DataFrame], data_name: str, field_numbers: list[int], ) -> ImageAnalysisResults: lanes: dict[str, Fields] = {} - for name, cleansed in cleansed_data.items(): - lane_iter = Lanes( - whole_data=cleansed, + for name, df in cleansed_data.items(): + lane_iter = scan( + whole_data=df, target_data=[data_name], field_numbers=field_numbers, ) @@ -1900,6 +1942,8 @@ def _build_analyze_seq_tar_buffer( targets: list[tuple[str, str]], image_results: dict[str, pd.DataFrame], preprocessed_data: Any | None = None, + *, + include_preprocessed_data: bool = False, ) -> BytesIO: targets_payload = json.dumps( {data_name: sample_name for data_name, sample_name in targets} @@ -1908,7 +1952,7 @@ def _build_analyze_seq_tar_buffer( "targets": targets_payload, "params": params_payload, } - if preprocessed_data is not None: + if include_preprocessed_data: payload["preprocessed_data"] = _serialize_preprocessed_data(preprocessed_data) for name, df in image_results.items(): payload[f"image_analysis_results/{name}"] = _serialize_dataframe_pickle(df) diff --git a/tests/test_pipeable.py b/tests/test_pipeable.py index fe2d027..14c0d7a 100644 --- a/tests/test_pipeable.py +++ b/tests/test_pipeable.py @@ -19,6 +19,7 @@ from analysisrun.pipeable import ( ManualInput, + ProcessedInputs, create_image_analysis_results_input_model, entity_filter, image_analysis_result_spec, @@ -1133,20 +1134,26 @@ def test_run_analysis_with_preprocess_sequential_with_manual_input(monkeypatch): def preprocess(args): calls["count"] += 1 df = args.image_analysis_results["activity_spots"] - return {"row_count": int(len(df)), "threshold": int(args.params.threshold)} + with pytest.raises(TypeError): + args.image_analysis_results["other"] = df + df["DoubleValue"] = df["Value"] * 2 + return ProcessedInputs( + image_analysis_results={"activity_spots": df}, + extra={"row_count": int(len(df)), "threshold": int(args.params.threshold)}, + ) def analyze(args): df = args.image_analysis_results.activity_spots.data return pd.Series( { - "total_value": int(df["Value"].sum()), - "row_count": args.preprocessed_data["row_count"], + "total_value": int(df["DoubleValue"].sum()), + "row_count": args.extra["row_count"], } ) def postprocess(args): df = args.analysis_results.copy() - df["pre_threshold"] = args.preprocessed_data["threshold"] + df["pre_threshold"] = args.extra["threshold"] return df result_df = ctx.run_analysis_with_preprocess( @@ -1222,12 +1229,17 @@ def fake_run_stream(entrypoint_path, tar_buf, mode, on_image): stdout=stdout_buf, ) - def preprocess(_): - return {"multiplier": 3} + def preprocess(args): + return ProcessedInputs( + image_analysis_results={ + name: df for name, df in args.image_analysis_results.items() + }, + extra={"multiplier": 3}, + ) def postprocess(args): df = args.analysis_results.copy() - df["scaled"] = df["total_value"] * args.preprocessed_data["multiplier"] + df["scaled"] = df["total_value"] * args.extra["multiplier"] return df with pytest.raises(SystemExit) as excinfo: @@ -1312,21 +1324,23 @@ def fake_run_stream(entrypoint_path, tar_buf, mode, on_image): def postprocess(args): df = args.analysis_results.copy() df["scaled"] = df.apply( - lambda row: ( - row["total_value"] - * args.preprocessed_data["multipliers"][row["data"]] - ), + lambda row: row["total_value"] * args.extra["multipliers"][row["data"]], axis=1, ) return df def preprocess(args): calls["count"] += 1 - return { - "multipliers": { - data_name: 5 + i * 2 for i, data_name in enumerate(args.targets) - } - } + return ProcessedInputs( + image_analysis_results={ + name: df for name, df in args.image_analysis_results.items() + }, + extra={ + "multipliers": { + data_name: 5 + i * 2 for i, data_name in enumerate(args.targets) + } + }, + ) result_df = ctx.run_analysis_with_preprocess( preprocess=preprocess,