Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 95 additions & 7 deletions pipeline/the4kagent_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from utils.scorer import calculate_cos_dist, calculate_niqe
from utils.action_gain_predictor import ActionGainRecord, append_action_gain_record
from .profile_loader import load_profile_config


Expand Down Expand Up @@ -111,6 +112,7 @@ def _init_state(self) -> None:
# }
},
},
"action_gain_logs": [],
}
self.cur_node = self.work_mem["tree"]
self.image_description = ""
Expand Down Expand Up @@ -167,6 +169,9 @@ def _config(

self.fast_4k = self.profile.get("Fast4K", False)
self.fast4k_side_thres = self.profile.get("Fast4kSideThres", 1024)

# predictive action gain logging
self.action_gain_logging = self.profile.get("ActionGainLogging", True)

self.project_root = Path(__file__).resolve().parent.parent # 4kagent dir path

Expand Down Expand Up @@ -869,7 +874,9 @@ def execute_subtask(self, cache: Optional[Path]) -> bool:
break

else:
best_img_path, best_img_score = self.evaluate_tool_result_onetime(res_degra_level_dict[self.reflect_by])
best_img_path, best_img_score = self.evaluate_tool_result_onetime(
res_degra_level_dict[self.reflect_by], subtask=subtask
)

best_tool_name = self._get_name_stem(best_img_path.parents[1].name)
self.workflow_logger.info(f"Best tool: {best_tool_name}")
Expand Down Expand Up @@ -911,32 +918,48 @@ def execute_subtask(self, cache: Optional[Path]) -> bool:
return success


def evaluate_tool_result_onetime(self, candidates: list[Path]) -> tuple[Path, float]:
def evaluate_tool_result_onetime(
self, candidates: list[Path], *, subtask: Optional[str] = None
) -> tuple[Path, float]:
if not candidates:
raise ValueError("`candidates` is empty.")

task_folder = candidates[0].parents[2] if len(candidates[0].parents) > 2 else candidates[0].parent
candidates_tmp_dir = os.path.join(str(task_folder), "tmp")
os.makedirs(candidates_tmp_dir, exist_ok=True)

candidate_paths = []
candidate_paths: list[str] = []
candidate_tool_names: list[str] = []
for i, cand in enumerate(candidates):
tool_name = self._get_name_stem(cand.parents[1].name) if len(cand.parents) > 1 else f"tool_{i:02d}"
new_filename = f"image_{tool_name}.png"
new_path = os.path.join(candidates_tmp_dir, new_filename)
shutil.copy(str(cand), new_path)
candidate_paths.append(str(cand))
candidate_tool_names.append(tool_name)

if self.reflect_by == "hpsv2+metric":
metric_scores = [compute_iqa_metric_score(p) for p in candidate_paths]
# metric_scores = compute_iqa_metric_score_batch(candidate_paths)
needs_metric_scores = self.reflect_by == "hpsv2+metric" or self.action_gain_logging
metric_scores = (
compute_iqa_metric_score_batch(candidate_paths)
if needs_metric_scores
else None
)

q_before = None
if self.action_gain_logging:
try:
q_before = compute_iqa_metric_score(self.cur_node["img_path"])
except Exception as exc: # pragma: no cover - optional logging
self.workflow_logger.warning(
f"Failed to compute q_before for action-gain logging: {exc}"
)

import hpsv2
hps_scores = hpsv2.score(candidate_paths, self.image_description, hps_version="v2.1")
hps_scores = [float(s) for s in hps_scores]

if self.reflect_by == "hpsv2+metric":
result = [h + m for h, m in zip(hps_scores, metric_scores)]
result = [h + (m if m is not None else 0.0) for h, m in zip(hps_scores, metric_scores)]
out_file = os.path.join(candidates_tmp_dir, "result_scores_with_metrics.txt")
with open(out_file, "w", encoding="utf-8") as f:
for cand, h, m, o in zip(candidates, hps_scores, metric_scores, result):
Expand All @@ -952,13 +975,78 @@ def evaluate_tool_result_onetime(self, candidates: list[Path]) -> tuple[Path, fl

best_idx, best_score = max(enumerate(result), key=lambda x: x[1])
best_image = candidates[best_idx]

self._log_action_gain_records(
subtask=subtask,
tool_names=candidate_tool_names,
candidate_paths=candidate_paths,
metric_scores=metric_scores,
hps_scores=hps_scores,
overall_scores=result,
q_before=q_before,
best_idx=best_idx,
)
# Release memory
del hpsv2
del result
gc.collect()
torch.cuda.empty_cache()

return best_image, float(best_score)


def _log_action_gain_records(
    self,
    *,
    subtask: Optional[str],
    tool_names: list[str],
    candidate_paths: list[str],
    metric_scores: Optional[list[Optional[float]]],
    hps_scores: list[float],
    overall_scores: list[float],
    q_before: Optional[float],
    best_idx: int,
) -> None:
    """Persist one ΔQ sample per candidate for the predictive-gain dataset.

    Each candidate produced by a tool becomes an ``ActionGainRecord``
    appended to ``self.work_mem["action_gain_logs"]``; the record pairs
    the pre-action IQA score (``q_before``) with the candidate's metric
    score so a gain predictor can later be trained on the difference.
    No-op when ``ActionGainLogging`` is disabled in the profile.
    """

    if not self.action_gain_logging:
        return

    # Context shared by every record emitted from this evaluation round.
    shared_meta = {
        "image_description": self.image_description,
        "reflect_mode": self.reflect_by,
    }
    source_img = str(self.cur_node["img_path"])

    for rank, (name, out_path) in enumerate(zip(tool_names, candidate_paths)):
        # The NR-IQA metric score doubles as the post-action quality estimate.
        q_after = metric_scores[rank] if metric_scores else None
        gain = None
        if q_after is not None and q_before is not None:
            gain = q_after - q_before
        append_action_gain_record(
            self.work_mem,
            ActionGainRecord(
                subtask=subtask,
                tool_name=name,
                img_before=source_img,
                img_after=out_path,
                q_before=q_before,
                q_after=q_after,
                delta_q=gain,
                overall_score=overall_scores[rank] if overall_scores else None,
                hps_score=hps_scores[rank] if hps_scores else None,
                metric_score=q_after,
                is_selected=rank == best_idx,
                metadata={**shared_meta, "candidate_rank": rank},
            ),
        )

    self.workflow_logger.info(
        "Logged action-to-quality stats for %d candidates (subtask=%s)",
        len(candidate_paths),
        subtask,
    )


def evaluate_tool_result_by_gpt4v(
Expand Down
72 changes: 72 additions & 0 deletions utils/action_gain_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Utilities for logging and exporting action-to-quality gain samples.

This module does not implement the predictive model itself yet, but it
standardizes the structure of ΔQ (quality gain) records collected during
4KAgent runs so that downstream training scripts can consume a single
schema.
"""

from __future__ import annotations

from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional


@dataclass
class ActionGainRecord:
    """One ΔQ (quality-gain) sample collected during a 4KAgent run.

    A record ties a single tool invocation to the quality scores observed
    before and after it ran, so that downstream training scripts can
    consume a uniform schema when fitting a gain predictor.
    """

    # High-level restoration subtask (e.g. "denoising"); None when unknown.
    subtask: Optional[str]
    # Name of the invoked tool.
    tool_name: str
    # Path to the image before executing the tool.
    img_before: str
    # Path to the candidate image generated by the tool.
    img_after: str
    # Numeric IQA score before the action (None if unavailable).
    q_before: Optional[float]
    # Numeric IQA score after the action (None if unavailable).
    q_after: Optional[float]
    # ``q_after - q_before``; None when either side is missing.
    delta_q: Optional[float]
    # Score used by Q-MoE (e.g. HPS+metric) for candidate selection.
    overall_score: Optional[float]
    # Raw HPSv2 score for the candidate.
    hps_score: Optional[float]
    # Raw NR-IQA metric score for the candidate.
    metric_score: Optional[float]
    # Whether this candidate was chosen for the next step.
    is_selected: bool
    # Extra JSON-serialisable context (e.g. textual degradation description).
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serialisable ``dict`` view of this record."""
        return asdict(self)


def append_action_gain_record(work_mem: Dict[str, Any], record: ActionGainRecord) -> None:
    """Serialise *record* and append it to ``work_mem`` in-place.

    The work memory summary is periodically dumped to ``summary.json``, so
    downstream scripts can simply parse that file to reconstruct a dataset.
    A missing ``"action_gain_logs"`` list is created on first use.
    """

    if "action_gain_logs" not in work_mem:
        work_mem["action_gain_logs"] = []
    work_mem["action_gain_logs"].append(record.to_dict())


def export_action_gain_logs(records: List[Dict[str, Any]], output_path: Path) -> None:
    """Write raw ΔQ records to ``output_path`` as pretty-printed JSON.

    Args:
        records: JSON-serialisable record dictionaries (e.g. the
            ``to_dict()`` output of ``ActionGainRecord``).
        output_path: Destination file; parent directories are created as
            needed. A plain ``str`` path is also accepted and coerced.
    """

    import json

    # Coerce so str callers don't crash on ``.parent`` below.
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-ASCII text (e.g. image descriptions
        # stored in record metadata) human-readable in the dump.
        json.dump(records, f, indent=2, ensure_ascii=False)