From 9e73f69191ff058f54e33f1b7fd985450b3f8d25 Mon Sep 17 00:00:00 2001 From: CJ <63952042+Ahermit01@users.noreply.github.com> Date: Wed, 19 Nov 2025 21:23:21 +0800 Subject: [PATCH] Add action gain logging for tool evaluation --- pipeline/the4kagent_pipeline.py | 102 +++++++++++++++++++++++++++++--- utils/action_gain_predictor.py | 72 ++++++++++++++++++++++ 2 files changed, 167 insertions(+), 7 deletions(-) create mode 100644 utils/action_gain_predictor.py diff --git a/pipeline/the4kagent_pipeline.py b/pipeline/the4kagent_pipeline.py index edb7279..f446aa0 100644 --- a/pipeline/the4kagent_pipeline.py +++ b/pipeline/the4kagent_pipeline.py @@ -31,6 +31,7 @@ ) from facexlib.utils.face_restoration_helper import FaceRestoreHelper from utils.scorer import calculate_cos_dist, calculate_niqe +from utils.action_gain_predictor import ActionGainRecord, append_action_gain_record from .profile_loader import load_profile_config @@ -111,6 +112,7 @@ def _init_state(self) -> None: # } }, }, + "action_gain_logs": [], } self.cur_node = self.work_mem["tree"] self.image_description = "" @@ -167,6 +169,9 @@ def _config( self.fast_4k = self.profile.get("Fast4K", False) self.fast4k_side_thres = self.profile.get("Fast4kSideThres", 1024) + + # predictive action gain logging + self.action_gain_logging = self.profile.get("ActionGainLogging", True) self.project_root = Path(__file__).resolve().parent.parent # 4kagent dir path @@ -869,7 +874,9 @@ def execute_subtask(self, cache: Optional[Path]) -> bool: break else: - best_img_path, best_img_score = self.evaluate_tool_result_onetime(res_degra_level_dict[self.reflect_by]) + best_img_path, best_img_score = self.evaluate_tool_result_onetime( + res_degra_level_dict[self.reflect_by], subtask=subtask + ) best_tool_name = self._get_name_stem(best_img_path.parents[1].name) self.workflow_logger.info(f"Best tool: {best_tool_name}") @@ -911,7 +918,9 @@ def execute_subtask(self, cache: Optional[Path]) -> bool: return success - def evaluate_tool_result_onetime(self, candidates: list[Path]) -> tuple[Path, float]: + def evaluate_tool_result_onetime( + self, candidates: list[Path], *, subtask: Optional[str] = None + ) -> tuple[Path, float]: if not candidates: raise ValueError("`candidates` is empty.") @@ -919,24 +928,38 @@ def evaluate_tool_result_onetime(self, candidates: list[Path]) -> tuple[Path, fl candidates_tmp_dir = os.path.join(str(task_folder), "tmp") os.makedirs(candidates_tmp_dir, exist_ok=True) - candidate_paths = [] + candidate_paths: list[str] = [] + candidate_tool_names: list[str] = [] for i, cand in enumerate(candidates): tool_name = self._get_name_stem(cand.parents[1].name) if len(cand.parents) > 1 else f"tool_{i:02d}" new_filename = f"image_{tool_name}.png" new_path = os.path.join(candidates_tmp_dir, new_filename) shutil.copy(str(cand), new_path) candidate_paths.append(str(cand)) + candidate_tool_names.append(tool_name) - if self.reflect_by == "hpsv2+metric": - metric_scores = [compute_iqa_metric_score(p) for p in candidate_paths] - # metric_scores = compute_iqa_metric_score_batch(candidate_paths) + needs_metric_scores = self.reflect_by == "hpsv2+metric" or self.action_gain_logging + metric_scores = ( + compute_iqa_metric_score_batch(candidate_paths) + if needs_metric_scores + else None + ) + + q_before = None + if self.action_gain_logging: + try: + q_before = compute_iqa_metric_score(self.cur_node["img_path"]) + except Exception as exc: # pragma: no cover - optional logging + self.workflow_logger.warning( + f"Failed to compute q_before for action-gain logging: {exc}" + ) import hpsv2 hps_scores = hpsv2.score(candidate_paths, self.image_description, hps_version="v2.1") hps_scores = [float(s) for s in hps_scores] if self.reflect_by == "hpsv2+metric": - result = [h + m for h, m in zip(hps_scores, metric_scores)] + result = [h + (m if m is not None else 0.0) for h, m in zip(hps_scores, metric_scores)] out_file = os.path.join(candidates_tmp_dir, "result_scores_with_metrics.txt") with open(out_file, "w", encoding="utf-8") as f: for cand, h, m, o in zip(candidates, hps_scores, metric_scores, result): @@ -952,6 +975,17 @@ def evaluate_tool_result_onetime(self, candidates: list[Path]) -> tuple[Path, fl best_idx, best_score = max(enumerate(result), key=lambda x: x[1]) best_image = candidates[best_idx] + + self._log_action_gain_records( + subtask=subtask, + tool_names=candidate_tool_names, + candidate_paths=candidate_paths, + metric_scores=metric_scores, + hps_scores=hps_scores, + overall_scores=result, + q_before=q_before, + best_idx=best_idx, + ) # Release memory del hpsv2 del result @@ -959,6 +993,60 @@ def evaluate_tool_result_onetime(self, candidates: list[Path]) -> tuple[Path, fl torch.cuda.empty_cache() return best_image, float(best_score) + + + def _log_action_gain_records( + self, + *, + subtask: Optional[str], + tool_names: list[str], + candidate_paths: list[str], + metric_scores: Optional[list[Optional[float]]], + hps_scores: list[float], + overall_scores: list[float], + q_before: Optional[float], + best_idx: int, + ) -> None: + """Persist ΔQ samples for downstream predictive modelling.""" + + if not self.action_gain_logging: + return + + metadata_base = { + "image_description": self.image_description, + "reflect_mode": self.reflect_by, + } + + img_before = str(self.cur_node["img_path"]) + for idx, (tool_name, cand_path) in enumerate(zip(tool_names, candidate_paths)): + metric_score = metric_scores[idx] if metric_scores else None + q_after = metric_score + delta_q = ( + (q_after - q_before) + if (q_after is not None and q_before is not None) + else None + ) + record = ActionGainRecord( + subtask=subtask, + tool_name=tool_name, + img_before=img_before, + img_after=cand_path, + q_before=q_before, + q_after=q_after, + delta_q=delta_q, + overall_score=overall_scores[idx] if overall_scores else None, + hps_score=hps_scores[idx] if hps_scores else None, + metric_score=metric_score, + is_selected=idx == best_idx, + metadata={**metadata_base, "candidate_rank": idx}, + ) + append_action_gain_record(self.work_mem, record) + + self.workflow_logger.info( + "Logged action-to-quality stats for %d candidates (subtask=%s)", + len(candidate_paths), + subtask, + ) def evaluate_tool_result_by_gpt4v( diff --git a/utils/action_gain_predictor.py b/utils/action_gain_predictor.py new file mode 100644 index 0000000..e0a8af7 --- /dev/null +++ b/utils/action_gain_predictor.py @@ -0,0 +1,72 @@ +"""Utilities for logging and exporting action-to-quality gain samples. + +This module does not implement the predictive model itself yet, but it +standardizes the structure of ΔQ (quality gain) records collected during +4KAgent runs so that downstream training scripts can consume a single +schema. +""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + + +@dataclass +class ActionGainRecord: + """Single training/evaluation sample for the ΔQ predictor. + + Attributes: + subtask: High-level restoration subtask (e.g. "denoising"). + tool_name: Name of the invoked tool. + img_before: Path to the image before executing the tool. + img_after: Path to the candidate image generated by the tool. + q_before: Numeric IQA score before the action. + q_after: Numeric IQA score after the action. + delta_q: Difference between ``q_after`` and ``q_before``. + overall_score: Score used by Q-MoE (e.g. HPS+metric) for selection. + hps_score: Raw HPSv2 score for the candidate. + metric_score: Raw NR-IQA metric score for the candidate. + is_selected: Whether this candidate was chosen for the next step. + metadata: Additional contextual information (e.g. textual + description of degradations) stored as a JSON-serialisable + dictionary. + """ + + subtask: Optional[str] + tool_name: str + img_before: str + img_after: str + q_before: Optional[float] + q_after: Optional[float] + delta_q: Optional[float] + overall_score: Optional[float] + hps_score: Optional[float] + metric_score: Optional[float] + is_selected: bool + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +def append_action_gain_record(work_mem: Dict[str, Any], record: ActionGainRecord) -> None: + """Append a record to ``work_mem`` in-place. + + The work memory summary is periodically dumped to ``summary.json``, so + downstream scripts can simply parse that file to reconstruct a dataset. + """ + + logs: List[Dict[str, Any]] = work_mem.setdefault("action_gain_logs", []) + logs.append(record.to_dict()) + + +def export_action_gain_logs(records: List[Dict[str, Any]], output_path: Path) -> None: + """Write raw records to disk.""" + + import json + + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="utf-8") as f: + json.dump(records, f, indent=2)