Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions researchclaw/experiment/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,12 +554,16 @@ def check_class_quality(all_files: dict[str, str]) -> list[str]:

non_dunder = [m for m in methods if not m.startswith("__")]

has_explicit_bases = bool(node.bases)

class_info[f"{fname}:{cls_name}"] = {
"methods": methods,
"non_dunder": non_dunder,
"body_lines": body_lines,
"file": fname,
"has_forward_new_module": has_forward_new_module,
"class_name": cls_name,
"has_explicit_bases": has_explicit_bases,
}

# --- Check 1: Empty or trivial class ---
Expand All @@ -570,7 +574,11 @@ def check_class_quality(all_files: dict[str, str]) -> list[str]:
)

# --- Check 2: Too few methods for an algorithm class ---
if body_lines > 5 and len(non_dunder) < 2:
if (
body_lines > 5
and len(non_dunder) < 2
and not has_explicit_bases
):
warnings.append(
f"[{fname}] Class '{cls_name}' has only {len(non_dunder)} "
f"non-dunder method(s) — algorithm classes should have at "
Expand All @@ -585,13 +593,36 @@ def check_class_quality(all_files: dict[str, str]) -> list[str]:
f"Move to __init__() and register as submodules."
)

# --- Check 4: Duplicate class implementations ---
# --- Check 4: Duplicate class names across files ---
duplicated_class_names: set[str] = set()
classes_by_name: dict[str, list[dict[str, Any]]] = {}
for info in class_info.values():
classes_by_name.setdefault(str(info["class_name"]), []).append(info)

for cls_name, entries in classes_by_name.items():
non_trivial = [entry for entry in entries if int(entry["body_lines"]) > 5]
files = sorted({str(entry["file"]) for entry in non_trivial})
if len(files) >= 2:
duplicated_class_names.add(cls_name)
warnings.append(
f"Class '{cls_name}' is defined in multiple files "
f"({', '.join(files)}). Keep each algorithm/helper class in one "
f"canonical module and import it elsewhere instead of duplicating "
f"the definition."
)

# --- Check 5: Duplicate class implementations ---
# Compare class body hashes to find copy-paste variants
class_names = list(class_info.keys())
for i, name_a in enumerate(class_names):
info_a = class_info[name_a]
for name_b in class_names[i + 1:]:
info_b = class_info[name_b]
if (
str(info_a["class_name"]) == str(info_b["class_name"])
and str(info_a["class_name"]) in duplicated_class_names
):
continue
if (
info_a["body_lines"] > 5
and info_b["body_lines"] > 5
Expand All @@ -606,7 +637,7 @@ def check_class_quality(all_files: dict[str, str]) -> list[str]:
f"may be copy-paste variants with no real algorithmic difference"
)

# --- Check 5: Ablation subclasses must override with different logic ---
# --- Check 6: Ablation subclasses must override with different logic ---
# Parse inheritance relationships and compare method ASTs
for fname_code, code in all_files.items():
if not fname_code.endswith(".py"):
Expand Down Expand Up @@ -673,7 +704,7 @@ def check_class_quality(all_files: dict[str, str]) -> list[str]:
# (new methods that parent doesn't have)
pass

# --- Check 6: Ablation subclass must override >=1 parent method ---
# --- Check 7: Ablation subclass must override >=1 parent method ---
_lname = cls_name.lower()
if ("ablation" in _lname or "no_" in _lname or "without" in _lname):
parent_non_dunder = {
Expand Down Expand Up @@ -1111,4 +1142,5 @@ def deep_validate_files(
continue
warnings.extend(check_variable_scoping(code, fname))
warnings.extend(check_api_correctness(code, fname))
warnings.extend(check_undefined_calls(code, fname))
return warnings
86 changes: 86 additions & 0 deletions tests/test_rc_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3360,3 +3360,89 @@ def test_writes_json_to_stage_dir(self, tmp_path: Path) -> None:
assert "section_analysis" in data
assert "overall_warnings" in data
assert "revision_directives" in data


class TestExperimentValidatorPrecision:
def test_deep_validation_detects_undefined_helper_calls(self) -> None:
from researchclaw.experiment.validator import deep_validate_files

issues = deep_validate_files(
{
"main.py": (
"def main():\n"
" create_empty_csv('tmp.csv', ['a'])\n\n"
"if __name__ == '__main__':\n"
" main()\n"
)
}
)

assert any(
"Call to undefined function 'create_empty_csv()'" in issue
for issue in issues
)

def test_deep_validation_allows_inherited_single_core_method_subclass(
self,
) -> None:
from researchclaw.experiment.validator import deep_validate_files

issues = deep_validate_files(
{
"main.py": (
"class BaseVerifier:\n"
" def __init__(self, scale=1.0):\n"
" self.scale = float(scale)\n\n"
"class ChildVerifier(BaseVerifier):\n"
" def predict(self, value):\n"
" total = value * self.scale\n"
" shifted = total + 1.0\n"
" centered = shifted - 0.5\n"
" bounded = max(centered, 0.0)\n"
" return {'score': bounded}\n"
)
}
)

assert not any(
"Class 'ChildVerifier' has only 1 non-dunder method" in issue
for issue in issues
)

def test_deep_validation_detects_duplicate_algorithm_classes_across_files(
self,
) -> None:
from researchclaw.experiment.validator import deep_validate_files

issues = deep_validate_files(
{
"main.py": (
"class DuplicateVerifier:\n"
" def __init__(self, bias=0.0):\n"
" self.bias = float(bias)\n\n"
" def predict(self, value):\n"
" shifted = value + self.bias\n"
" bounded = max(shifted, 0.0)\n"
" return {'score': bounded}\n"
),
"models.py": (
"class DuplicateVerifier:\n"
" def __init__(self, bias=0.0):\n"
" self.bias = float(bias)\n\n"
" def predict(self, value):\n"
" shifted = value + self.bias\n"
" bounded = max(shifted, 0.0)\n"
" return {'score': bounded}\n"
),
}
)

assert any(
"Class 'DuplicateVerifier' is defined in multiple files" in issue
for issue in issues
)
assert not any(
"Classes 'DuplicateVerifier' and 'DuplicateVerifier' have identical"
in issue
for issue in issues
)