-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
253 lines (216 loc) · 10.5 KB
/
main.py
File metadata and controls
253 lines (216 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import numpy as np
from config import Config
from llm_interface import LLMInterface
from agents.project_manager import ProjectManager
from agents.specialist_agent import SpecialistAgent
from agents.integrator import IntegratorAgent
from agents.expert_programer import ExpertProgrammer
from dsl.common.io_utils import load_task, create_composite_task_image
from utils.logger import get_logger
import ast
import json
def execute_generated_code(code_str: str, input_grid: np.ndarray, expected_output: np.ndarray):
    """Execute LLM-generated solver code and verify it against the expected grid.

    Args:
        code_str: Python source that should define a solver function taking a
            single grid argument. ``transform_grid`` / ``solve_puzzle`` are
            preferred; otherwise the first non-internal callable is used.
        input_grid: Test input grid passed to the solver.
        expected_output: Grid the solver must produce.

    Returns:
        Tuple ``(success, error_context)``. ``error_context`` is ``None`` on
        success; on failure it is a dict with either an ``error`` message or
        the mismatched grids, always alongside the offending code.

    SECURITY NOTE: ``exec`` runs arbitrary generated code in-process with no
    sandboxing — only use with trusted/controlled LLM output.
    """
    try:
        import sys
        import os

        # Make the local DSL package importable for generated code that uses it.
        dsl_path = os.path.join(os.path.dirname(__file__), 'dsl')
        if dsl_path not in sys.path:
            sys.path.insert(0, dsl_path)

        local_scope = {"np": np}

        # scipy and the DSL helpers are optional conveniences: the programmer
        # agent emits numpy-only code, so a missing package must not turn
        # every verification into an ImportError failure.
        try:
            import scipy
            local_scope["scipy"] = scipy
        except ImportError:
            pass
        try:
            import dsl
            import dsl.color
            import dsl.geometry
            import dsl.common
            from dsl.geometry.move import move
            from dsl.geometry.rotate import rotate_grid
            from dsl.color.recolor import recolor
            from dsl.common.find_objects import find_objects
            local_scope.update({
                "dsl": dsl,
                "move": move,
                "rotate_grid": rotate_grid,
                "recolor": recolor,
                "find_objects": find_objects,
            })
        except ImportError:
            pass

        exec(code_str, local_scope)  # deliberate; see security note in docstring

        # Prefer explicit function names to avoid picking helper callables by accident.
        preferred_function_names = ["transform_grid", "solve_puzzle"]
        solution_func = None
        for candidate in preferred_function_names:
            if candidate in local_scope and callable(local_scope[candidate]):
                solution_func = local_scope[candidate]
                break

        if solution_func is None:
            # Fallback: first non-internal callable that is not a pre-seeded import.
            excluded_names = {'np', 'scipy', 'dsl', 'move', 'rotate_grid',
                              'recolor', 'find_objects'}
            func_candidates = [name for name in local_scope
                               if callable(local_scope[name])
                               and name not in excluded_names
                               and not name.startswith('_')]
            if not func_candidates:
                return False, {"error": "No solution function found in generated code", "code": code_str}
            solution_func = local_scope[func_candidates[0]]

        actual_output = solution_func(input_grid)

        if np.array_equal(actual_output, expected_output):
            return True, None
        # Mismatch: hand back everything needed for the retry prompt.
        error_context = dict(
            input_grid=input_grid,
            expected_output=expected_output,
            actual_output=actual_output,
            code_str=code_str
        )
        return False, error_context
    except Exception as e:
        return False, {"error": str(e), "code": code_str}
def run_workshop_flow(*args, **kwargs) -> bool:
    """Legacy Toolsmith/QA workshop entry point.

    The workshop stage was removed when the pipeline moved to the NL-only
    flow; this stub keeps existing call sites working and always reports
    that no workshop run took place.
    """
    return False
def _print_validation(header: str, success: bool, attempt: int, max_attempts: int,
                      test_input: np.ndarray, predicted, expected: np.ndarray) -> None:
    """Print a human-readable validation report for one verification run.

    ``predicted`` may be ``None`` when no output grid is available (e.g. the
    generated code raised before producing a grid), in which case "N/A" is
    shown for the predicted section.
    """
    print(f"\n--- {header} ---")
    print(f"SUCCESS: {success}")
    print(f"ATTEMPTS USED: {attempt}/{max_attempts}")
    print("\n--- INPUT ---")
    print(json.dumps(test_input.tolist(), indent=2))
    print("\n--- PREDICTED OUTPUT ---")
    if predicted is not None:
        # asarray tolerates a solver returning a plain nested list.
        print(json.dumps(np.asarray(predicted).tolist(), indent=2))
    else:
        print("N/A")
    print("\n--- EXPECTED OUTPUT ---")
    print(json.dumps(expected.tolist(), indent=2))


def main():
    """Main entry point for the agentic neural network.

    Orchestrates the pipeline end to end: load the ARC task, have the
    Project Manager decompose it, collect specialist partial plans,
    synthesize a final algorithm via the Integrator, then let the Expert
    Programmer generate code that is executed and verified with retries.
    """
    logger = get_logger("main")
    try:
        # Validate config eagerly (ensures paths and envs are set)
        Config.validate()

        # --- Task Configuration ---
        arc_task_path = Config.ARC_DATA_DIR / Config.DEFAULT_TASK

        # --- Initialization of all agents ---
        logger.info("Initializing Hive-Mind Agent Network...")
        llm_client = LLMInterface()
        project_manager = ProjectManager(llm_interface=llm_client)
        integrator = IntegratorAgent(llm_interface=llm_client)
        expert_programmer = ExpertProgrammer(llm_interface=llm_client)

        # --- 1. Load Task ---
        task = load_task(str(arc_task_path))
        task_id = arc_task_path.stem
        temp_dir = Config.TEMP_DIR / task_id
        input_grids = [np.array(pair['input']) for pair in task['train']]
        output_grids = [np.array(pair['output']) for pair in task['train']]
        temp_dir.mkdir(exist_ok=True)
        # Create a single composite image showing all training examples
        composite_image_b64 = create_composite_task_image(task['train'], temp_dir=str(temp_dir), problem_id=task_id)

        # --- Single Execution (no loop) ---
        previous_loop_info = None
        manager_plan = None

        # --- 2. Project Manager Decomposes Task ---
        logger.info(f"Loading Task: {arc_task_path.name}")
        logger.info("1. Project Manager is decomposing the task...")
        required_experts, manager_plan = project_manager.decompose_task(
            images=[composite_image_b64],
            task_data=task['train'],
            previous_loop_info=previous_loop_info
        )
        if not required_experts:
            logger.error("Project Manager failed. Aborting.")
            return
        logger.info(f"Plan received. Required specialists: {required_experts}")

        # --- 3. Specialists Generate NL Partial Plans ---
        logger.info("2. Specialists are generating NL partial plans...")
        partial_plans = []
        all_specialist_contexts = {}
        for expert_name in required_experts:
            specialist = SpecialistAgent(llm_interface=llm_client, name=expert_name)
            all_specialist_contexts[expert_name] = specialist.context_data.get("transformations", {})
            response_data = specialist.create_partial_plan(
                task_examples=task['train'],
                previous_loop_info=previous_loop_info,
                manager_plan=manager_plan
            )
            if not response_data:
                logger.error(f"Specialist '{expert_name}' failed to provide a valid response.")
                continue
            logger.debug("Specialist '%s' response: %s", expert_name, response_data)
            if "nl_steps" in response_data:
                partial_plans.append({"specialist": expert_name, "plan": response_data})

        # --- 4. Integrator Synthesizes Final Algorithm ---
        logger.info("3. Integrator is synthesizing the NL final plan...")
        curated_tool_docs = all_specialist_contexts
        final_algorithm = integrator.synthesize_plan(
            specialist_contributions=partial_plans,
            tool_documentation=curated_tool_docs,
            input_grids=input_grids,
            output_grids=output_grids,
            previous_loop_info=previous_loop_info,
            manager_plan=manager_plan
        )
        if not final_algorithm:
            logger.error("Integrator failed. Aborting.")
            return
        logger.info("Final conceptual algorithm synthesized.")

        # --- 5. Expert Programmer writes numpy-only solution code with retry loop ---
        logger.info("4. Expert Programmer is generating the numpy-only solution code...")
        test_input_grid = np.array(task['test'][0]['input'])
        expected_test_output = np.array(task['test'][0]['output'])

        MAX_RETRY_ATTEMPTS = 5
        previous_failure_info = None
        for attempt in range(1, MAX_RETRY_ATTEMPTS + 1):
            logger.info(f"  Attempt {attempt}/{MAX_RETRY_ATTEMPTS}")
            solution_code = expert_programmer.generate_code(
                conceptual_algorithm=final_algorithm,
                tool_documentation=curated_tool_docs,
                input_grid=test_input_grid,
                training_examples=task['train'],
                previous_loop_info=previous_failure_info,
                manager_plan=manager_plan
            )
            if not solution_code:
                logger.error(f"Expert Programmer failed to produce code on attempt {attempt}.")
                if attempt == MAX_RETRY_ATTEMPTS:
                    logger.error("All retry attempts exhausted. Aborting.")
                    return
                continue

            logger.info("5. Executing and verifying the generated code...")
            success, error_info = execute_generated_code(solution_code, test_input_grid, expected_test_output)
            if success:
                logger.info("🎉 Verification successful!")
                # On success the verified output equals the expected grid
                # (np.array_equal passed), so report it as the prediction.
                _print_validation("VALIDATION", success, attempt, MAX_RETRY_ATTEMPTS,
                                  test_input_grid, expected_test_output, expected_test_output)
                return

            logger.warning(f"Verification failed on attempt {attempt}.")
            # Prepare failure info so the next attempt can learn from it.
            previous_failure_info = {
                "attempt": attempt,
                "failed_code": solution_code,
                "error_details": error_info,
                "expected_output": expected_test_output.tolist(),
                "test_input": test_input_grid.tolist()
            }
            if attempt == MAX_RETRY_ATTEMPTS:
                logger.error("All retry attempts exhausted. Final failure.")
                predicted = error_info.get('actual_output') if error_info else None
                _print_validation("FINAL VALIDATION", success, attempt, MAX_RETRY_ATTEMPTS,
                                  test_input_grid, predicted, expected_test_output)
                return
            logger.info(f"Preparing retry attempt {attempt + 1} with failure information...")
    except Exception as e:
        logger.error(f"Critical error in main execution: {e}")
        raise
# Script entry point: run the full agent pipeline when executed directly
# (not when imported as a module).
if __name__ == '__main__':
    main()