diff --git a/.gitignore b/.gitignore
index 441facd..95e55f1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -207,3 +207,9 @@ cython_debug/
marimo/_static/
marimo/_lsp/
__marimo__/
+
+
+#
+autobolt/
+*.db-*
+*.db
\ No newline at end of file
diff --git a/environment.yml b/environment.yml
index 8d41493..465a54d 100644
--- a/environment.yml
+++ b/environment.yml
@@ -17,4 +17,5 @@ dependencies:
- pandas
- "autobolt @ git+https://github.com/sriyanc2001/AutoBolt.git"
- black
- - pytest
\ No newline at end of file
+ - pytest
+ - sqlalchemy
\ No newline at end of file
diff --git a/src/autoboltagent/VLLMModelCustom.py b/src/autoboltagent/VLLMModelCustom.py
new file mode 100644
index 0000000..242f66a
--- /dev/null
+++ b/src/autoboltagent/VLLMModelCustom.py
@@ -0,0 +1,97 @@
+from smolagents import VLLMModel
+from smolagents.tools import Tool
+from smolagents.monitoring import TokenUsage
+from vllm.transformers_utils.tokenizer import get_tokenizer
+from smolagents.models import (
+ ChatMessage,
+ MessageRole,
+ remove_content_after_stop_sequences,
+
+)
+
+from typing import Any
+
+class VLLMModelCustom(VLLMModel):
+ def __init__(
+ self,
+ model_id,
+ model_kwargs: dict[str, Any] | None = None,
+ apply_chat_template_kwargs: dict[str, Any] | None = None,
+ sampling_params=None,
+ **kwargs,
+ ):
+ super().__init__(
+ model_id=model_id,
+ model_kwargs=model_kwargs,
+ apply_chat_template_kwargs=apply_chat_template_kwargs,
+ **kwargs
+ )
+ self.sampling_params = sampling_params
+
+ def generate(
+ self,
+ messages: list[ChatMessage | dict],
+ stop_sequences: list[str] | None = None,
+ response_format: dict[str, str] | None = None,
+ tools_to_call_from: list[Tool] | None = None,
+ **kwargs,
+ ) -> ChatMessage:
+ from vllm import SamplingParams # type: ignore
+ from vllm.sampling_params import StructuredOutputsParams # type: ignore
+
+ completion_kwargs = self._prepare_completion_kwargs(
+ messages=messages,
+ flatten_messages_as_text=(not self._is_vlm),
+ stop_sequences=stop_sequences,
+ tools_to_call_from=tools_to_call_from,
+ **kwargs,
+ )
+
+ prepared_stop_sequences = completion_kwargs.pop("stop", [])
+ messages = completion_kwargs.pop("messages")
+ tools = completion_kwargs.pop("tools", None)
+ completion_kwargs.pop("tool_choice", None)
+
+ if not self.sampling_params:
+ # Override the OpenAI schema for VLLM compatibility
+ structured_outputs = (
+ StructuredOutputsParams(json=response_format["json_schema"]["schema"]) if response_format else None
+ )
+
+
+ self.sampling_params = SamplingParams(
+ n=kwargs.get("n", 1),
+ temperature=kwargs.get("temperature", 0.0),
+ max_tokens=kwargs.get("max_tokens", 2048),
+ stop=prepared_stop_sequences,
+ structured_outputs=structured_outputs,
+ )
+
+
+ prompt = self.tokenizer.apply_chat_template(
+ messages,
+ tools=tools,
+ add_generation_prompt=True,
+ tokenize=False,
+ **self.apply_chat_template_kwargs,
+ )
+
+
+ out = self.model.generate(
+ prompt,
+ sampling_params=self.sampling_params,
+ **completion_kwargs,
+ )
+
+ output_text = out[0].outputs[0].text
+ if stop_sequences is not None and not self.supports_stop_parameter:
+ output_text = remove_content_after_stop_sequences(output_text, stop_sequences)
+ return ChatMessage(
+ role=MessageRole.ASSISTANT,
+ content=output_text,
+ raw={"out": output_text, "completion_kwargs": completion_kwargs},
+ token_usage=TokenUsage(
+ input_tokens=len(out[0].prompt_token_ids),
+ output_tokens=len(out[0].outputs[0].token_ids),
+ ),
+ )
\ No newline at end of file
diff --git a/src/autoboltagent/agents.py b/src/autoboltagent/agents.py
index 35a19fb..0fc6416 100644
--- a/src/autoboltagent/agents.py
+++ b/src/autoboltagent/agents.py
@@ -7,8 +7,10 @@
)
from .tools import AnalyticalTool, FiniteElementTool
+from .tools.logger import AgentLogger
-class GuessingAgent(smolagents.ToolCallingAgent):
+
+class GuessingAgent(smolagents.agents.ToolCallingAgent):
"""
An agent that makes guesses without using any tools.
@@ -16,7 +18,7 @@ class GuessingAgent(smolagents.ToolCallingAgent):
It is designed to provide initial estimates or solutions based on its knowledge and reasoning capabilities.
"""
- def __init__(self, model: smolagents.Model) -> None:
+ def __init__(self, model: smolagents.models.Model) -> None:
"""
Initializes a GuessingAgent that does not use any tools.
@@ -33,7 +35,7 @@ def __init__(self, model: smolagents.Model) -> None:
)
-class LowFidelityAgent(smolagents.ToolCallingAgent):
+class LowFidelityAgent(smolagents.agents.ToolCallingAgent):
"""
An agent that utilizes a low-fidelity analytical tool for bolted connection design.
@@ -41,24 +43,43 @@ class LowFidelityAgent(smolagents.ToolCallingAgent):
It is designed to provide solutions based on simplified models and assumptions, making it suitable for quick estimates and preliminary designs.
"""
- def __init__(self, model: smolagents.Model) -> None:
+ def __init__(self, model: smolagents.models.Model, agent_id: str, run_id: str, target_fos: float, agent_logger: AgentLogger|None = None, max_steps=20) -> None:
"""
Initializes a LowFidelityAgent that uses an analytical tool.
Args:
model: An instance of smolagents.Model to be used by the agent.
"""
+
+ self.agent_logger = agent_logger
+ self.agent_id = agent_id
+ self.run_id = run_id
+ self.target_fos = target_fos
+
+ callbacks = [self.log] if self.agent_logger else []
+
super().__init__(
name="LowFidelityAgent",
tools=[AnalyticalTool()],
add_base_tools=False,
model=model,
instructions=BASE_INSTRUCTIONS + TOOL_USING_INSTRUCTION,
+ step_callbacks = callbacks,
verbosity_level=2,
+ max_steps=max_steps
)
+ def log(self, step, agent):
+ if self.agent_logger and step.__class__.__name__ == "ActionStep":
+ self.agent_logger.log(
+ agent_id=self.agent_id,
+ run_id=self.run_id,
+ target_fos=self.target_fos,
+ action_step=step
+ )
+
-class HighFidelityAgent(smolagents.ToolCallingAgent):
+class HighFidelityAgent(smolagents.agents.ToolCallingAgent):
"""
An agent that utilizes a high-fidelity finite element analysis tool for bolted connection design.
@@ -66,7 +87,7 @@ class HighFidelityAgent(smolagents.ToolCallingAgent):
It is designed to provide accurate and reliable solutions based on comprehensive models, making it suitable for
"""
- def __init__(self, model: smolagents.Model) -> None:
+ def __init__(self, model: smolagents.models.Model) -> None:
"""
Initializes a HighFidelityAgent that uses a finite element tool.
@@ -83,7 +104,7 @@ def __init__(self, model: smolagents.Model) -> None:
)
-class DualFidelityAgent(smolagents.ToolCallingAgent):
+class DualFidelityAgent(smolagents.agents.ToolCallingAgent):
"""
An agent that utilizes both low-fidelity and high-fidelity tools for bolted connection design.
@@ -91,7 +112,7 @@ class DualFidelityAgent(smolagents.ToolCallingAgent):
It is designed to provide solutions that balance speed and accuracy by using the low-fidelity tool
"""
- def __init__(self, model: smolagents.Model) -> None:
+ def __init__(self, model: smolagents.models.Model) -> None:
"""
Initializes a DualFidelityAgent that uses both analytical and finite element tools.
diff --git a/src/autoboltagent/grammars.py b/src/autoboltagent/grammars.py
new file mode 100644
index 0000000..ce4dfcd
--- /dev/null
+++ b/src/autoboltagent/grammars.py
@@ -0,0 +1,74 @@
+# grammar that outputs low-fidelity tool call or final_answer
+low_fidelity_agent_grammar = r"""
+root ::= "\n" payload "\n"
+payload ::= fos | final
+
+fos ::= "{\"name\": \"analytical_fos_calculation\", \"arguments\": " fos_args "}}"
+fos_args ::= "{" ( num_field ", " dia_field ) | ( dia_field ", " num_field ) "}"
+num_field ::= "\"num_bolts\": " int
+dia_field ::= "\"bolt_diameter\": " number
+
+final ::= "{\"name\": \"final_answer\", \"arguments\": {\"answer\": " fos_args "}}}"
+
+int ::= digit+
+number ::= digit+ frac?
+frac ::= "." digit{1,4}
+
+digit ::= "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
+"""
+
+# grammar that outputs high-fidelity tool call or final_answer
+high_fidelity_agent_grammar = r"""
+root ::= "\n" payload "\n"
+payload ::= fos | final
+
+fos ::= "{\"name\": \"fea_fos_calculation\", \"arguments\": " fos_args "}}"
+fos_args ::= "{" ( num_field ", " dia_field ) | ( dia_field ", " num_field ) "}"
+num_field ::= "\"num_bolts\": " int
+dia_field ::= "\"bolt_diameter\": " number
+
+final ::= "{\"name\": \"final_answer\", \"arguments\": {\"answer\": " fos_args "}}}"
+
+int ::= digit+
+number ::= digit+ frac?
+frac ::= "." digit{1,4}
+
+digit ::= "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
+"""
+
+# grammar that outputs both types of tool call and final_answer
+dual_fidelity_agent_grammar = r"""
+root ::= "\n" payload "\n"
+payload ::= fos | final
+
+fos ::= "{\"name\": ("\"analytical_fos_calculation\"" | "\"fea_fos_calculation\""), \"arguments\": " fos_args "}}"
+fos_args ::= "{" ( num_field ", " dia_field ) | ( dia_field ", " num_field ) "}"
+num_field ::= "\"num_bolts\": " int
+dia_field ::= "\"bolt_diameter\": " number
+
+final ::= "{\"name\": \"final_answer\", \"arguments\": {\"answer\": " fos_args "}}}"
+
+int ::= digit+
+number ::= digit+ frac?
+frac ::= "." digit{1,4}
+
+digit ::= "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
+"""
+
+low_fidelity_agent_grammar_original = r"""
+root ::= "\n" payload "\n"
+payload ::= fos | final
+
+fos ::= "{\"name\": \"analytical_fos_calculation\", \"arguments\": " fos_args "}}"
+fos_args ::= "{" ( num_field ", " dia_field ) | ( dia_field ", " num_field ) "}"
+num_field ::= "\"num_bolts\": " int
+dia_field ::= "\"bolt_diameter\": " number
+
+final ::= "{\"name\": \"final_answer\", \"arguments\": {\"answer\": " fos_args "}}"
+
+int ::= digit+
+number ::= digit+ frac?
+frac ::= "." digit{1,4}
+
+digit ::= "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
+"""
\ No newline at end of file
diff --git a/src/autoboltagent/tools/high_fidelity_tool.py b/src/autoboltagent/tools/high_fidelity_tool.py
index 02a7bf9..52e9f73 100644
--- a/src/autoboltagent/tools/high_fidelity_tool.py
+++ b/src/autoboltagent/tools/high_fidelity_tool.py
@@ -1,10 +1,11 @@
import autobolt
import smolagents
+from typing import Dict, Any, Union, cast
from .inputs import INPUTS
-class FiniteElementTool(smolagents.Tool):
+class FiniteElementTool(smolagents.tools.Tool):
"""
A tool that calculates the factor of safety for a bolted connection using finite element analysis.
@@ -15,7 +16,8 @@ class FiniteElementTool(smolagents.Tool):
name = "fea_fos_calculation"
description = "Calculates the factor of safety using finite element analysis."
- inputs = INPUTS
+ input_type = dict[str, dict[str, Union[str, type, bool]]]
+ inputs: input_type = cast(input_type,INPUTS)
output_type = "number"
diff --git a/src/autoboltagent/tools/logger.py b/src/autoboltagent/tools/logger.py
new file mode 100644
index 0000000..e1e7db7
--- /dev/null
+++ b/src/autoboltagent/tools/logger.py
@@ -0,0 +1,126 @@
+from sqlalchemy import create_engine
+
+from sqlalchemy.orm import declarative_base, Session, sessionmaker, Mapped, mapped_column
+from contextlib import contextmanager
+
+from pathlib import Path
+
+from datetime import datetime, timezone
+
+Base = declarative_base()
+
+class Iteration(Base):
+ __tablename__ = "iterations"
+
+ id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+ agent_id: Mapped[str] = mapped_column()
+ run_id: Mapped[str] = mapped_column()
+ iteration_no: Mapped[int] = mapped_column()
+
+ start_time: Mapped[datetime] = mapped_column(nullable=True)
+ end_time: Mapped[datetime] = mapped_column(nullable=True)
+
+ status: Mapped[str] = mapped_column(nullable=True)
+ tool_call: Mapped[str] = mapped_column(nullable=True)
+ observations: Mapped[str] = mapped_column(nullable=True)
+ target_fos: Mapped[float] = mapped_column(nullable=True)
+
+ failure_reason: Mapped[str] = mapped_column(nullable=True)
+ llm_output: Mapped[str] = mapped_column(nullable=True)
+ error_message: Mapped[str] = mapped_column(nullable=True)
+
+class AgentLogger:
+ _instance = None
+
+ def connect_to_db(self, db_url: str):
+ self.db_url = db_url
+ self.engine = create_engine(db_url, future=True, pool_pre_ping=True)
+
+ try:
+ with self.engine.connect() as conn:
+ conn.exec_driver_sql("PRAGMA journal_mode=WAL;")
+ conn.exec_driver_sql("PRAGMA synchronous=NORMAL;")
+
+ Base.metadata.create_all(self.engine)
+
+ self.db_session = sessionmaker(bind=self.engine, expire_on_commit=False, future=True)
+ except Exception as e:
+ raise IOError("Failed to connect to DB, check if file in use", repr(e))
+
+ def __new__(cls, db_url: str):
+ if not cls._instance:
+ cls._instance = super().__new__(cls)
+ cls._instance.db_url = None
+ cls._instance.engine = None
+ cls._instance.db_session = None
+ cls._instance.connect_to_db(db_url)
+ return cls._instance
+
+ @classmethod
+ def reset(cls):
+ if not cls._instance:
+ return
+
+ inst = cls._instance
+
+ if inst.engine:
+ inst.engine.dispose()
+
+ if inst.db_url:
+ for suffix in ("", "-wal", "-shm"):
+ file_path = Path(inst.db_url.replace("sqlite:///", "") + suffix)
+ file_path.unlink(missing_ok=True)
+
+ cls._instance = None
+
+ def log(
+ self,
+ run_id,
+ agent_id,
+ target_fos,
+ action_step
+ ):
+
+ iteration_no = action_step.step_number
+ start_dt = datetime.fromtimestamp(action_step.timing.start_time, tz=timezone.utc)
+ end_dt = datetime.fromtimestamp(action_step.timing.end_time, tz=timezone.utc)
+
+ error = getattr(action_step, "error", None)
+ tool_calls = getattr(action_step, "tool_calls", None)
+ observations = getattr(action_step, "observations", None)
+ llm_message = getattr(action_step, "model_output_message", None)
+ llm_output = getattr(llm_message, "content", None)
+
+ print(tool_calls)
+ print(error)
+ print(observations)
+ print(llm_message)
+ print(action_step.token_usage)
+
+ try:
+
+ with self.db_session() as session:
+ session.add(
+ Iteration(
+ run_id=run_id,
+ agent_id=agent_id,
+ iteration_no=iteration_no,
+ start_time=start_dt,
+ end_time=end_dt,
+ tool_call=str(tool_calls[0] if tool_calls else None),
+ observations=observations,
+ target_fos=target_fos,
+ llm_output=llm_output,
+ error_message=error.message if (error and error.message) else None
+ )
+ )
+ session.flush()
+ session.commit()
+
+ except Exception as e:
+ print("\n\n")
+ print(e)
+ print("\n\n")
+
+
+
\ No newline at end of file
diff --git a/src/autoboltagent/tools/low_fidelity_tool.py b/src/autoboltagent/tools/low_fidelity_tool.py
index e250a0e..b5a62f9 100644
--- a/src/autoboltagent/tools/low_fidelity_tool.py
+++ b/src/autoboltagent/tools/low_fidelity_tool.py
@@ -1,5 +1,7 @@
import smolagents
+from typing import Dict, Any, Union, cast
+
from .fastener_toolkit import (
get_joint_constant,
get_tensile_stress_area,
@@ -8,7 +10,7 @@
from .inputs import INPUTS
-class AnalyticalTool(smolagents.Tool):
+class AnalyticalTool(smolagents.tools.Tool):
"""
A tool that calculates the factor of safety for a bolted connection using analytical expressions.
@@ -18,8 +20,9 @@ class AnalyticalTool(smolagents.Tool):
name = "analytical_fos_calculation"
description = "Calculates the factor of safety using analytical expressions."
-
- inputs = INPUTS
+
+ input_type = dict[str, dict[str, Union[str, type, bool]]]
+ inputs: input_type = cast(input_type,INPUTS)
output_type = "number"
@@ -81,3 +84,82 @@ def forward(
f"The factor of safety for bolts is {bolt_fos:.2f} ({bolt_comparison}) and "
f"the factor of safety for plates is {plate_fos:.2f} ({plate_comparison})."
)
+
+
+class VerboseAnalyticalTool(smolagents.tools.Tool):
+ """
+ A tool that calculates the factor of safety for a bolted connection using analytical expressions.
+
+ This tool uses established engineering formulas to compute the factor of safety for a bolted connection
+ based on the provided parameters.
+ """
+
+ name = "analytical_fos_calculation"
+ description = "Calculates the factor of safety using analytical expressions."
+
+ inputs = {
+ "num_bolts": {
+ "type": "number",
+ "description": "Number of bolts used in the joint",
+ },
+ "bolt_diameter": {
+ "type": "number",
+ "description": "Diameter of the bolt in mm",
+ }
+ }
+
+ output_type = "object"
+
+ def __init__(self, joint_configuration: Dict[str, Any], tolerance: float = 0.1):
+ super().__init__()
+ self.tolerance = tolerance
+ self.desired_safety_factor = joint_configuration["desired_safety_factor"]
+ self.load = joint_configuration["load"]
+ self.preload = joint_configuration["preload"]
+ self.bolt_yield_strength = joint_configuration["bolt_yield_strength"]
+ self.bolt_elastic_modulus = joint_configuration["bolt_elastic_modulus"]
+ self.plate_thickness = joint_configuration["plate_thickness"]
+ self.plate_elastic_modulus = joint_configuration["plate_elastic_modulus"]
+ self.plate_yield_strength = joint_configuration["plate_yield_strength"]
+ self.pitch = joint_configuration["pitch"]
+
+ def forward(
+ self,
+ num_bolts: int,
+ bolt_diameter: float,
+ ) -> dict:
+
+ load_per_bolt = self.load / num_bolts
+ preload_per_bolt = self.preload / num_bolts
+ tensile_area = get_tensile_stress_area(bolt_diameter, self.pitch)
+
+ c = get_joint_constant(
+ bolt_diameter,
+ self.plate_thickness * 2,
+ self.plate_elastic_modulus,
+ self.bolt_elastic_modulus,
+ )
+
+ bolt_fos = bolt_yield_safety_factor(
+ c=c,
+ load=load_per_bolt,
+ preload=preload_per_bolt,
+ a_ts=tensile_area,
+ b_ys=self.bolt_yield_strength,
+ )
+
+ bearing_area = bolt_diameter * self.plate_thickness * num_bolts
+ bearing_stress = self.load / bearing_area
+ allowable_bearing_stress = 1.5 * self.plate_yield_strength
+ plate_fos = allowable_bearing_stress / bearing_stress
+
+ bolt_diff = bolt_fos - self.desired_safety_factor
+ plate_diff = plate_fos - self.desired_safety_factor
+
+ return {
+ "ok": (abs(bolt_diff) < self.tolerance) and (abs(plate_diff) < self.tolerance),
+ "bolt_fos": bolt_fos,
+ "bolt_diff": bolt_diff,
+ "plate_fos": plate_fos,
+ "plate_diff": plate_diff
+ }
\ No newline at end of file
diff --git a/src/autoboltagent/verbose_agents.py b/src/autoboltagent/verbose_agents.py
new file mode 100644
index 0000000..d8c9286
--- /dev/null
+++ b/src/autoboltagent/verbose_agents.py
@@ -0,0 +1,49 @@
+import smolagents
+
+from .verbose_prompts import (
+ TOOL_USING_INSTRUCTION,
+ SIMPLIFIED_TOOL_USING_INSTRUCTION,
+ BASE_INSTRUCTIONS,
+ INPUT_FORMAT,
+ OUTPUT_FORMAT,
+ LOW_FIDELITY_TOOL_INSTRUCTION,
+ EXAMPLE_TASK_INSTRUCTIONS
+)
+from .tools.low_fidelity_tool import VerboseAnalyticalTool
+
+from .tools.logger import AgentLogger
+
+
+class VerboseLowFidelityAgent(smolagents.agents.ToolCallingAgent):
+ def __init__(self, model: smolagents.models.Model, joint_configuration: dict, agent_id: str, run_id: str, target_fos: float, agent_logger: AgentLogger|None = None) -> None:
+
+ self.agent_logger = agent_logger
+ self.agent_id = agent_id
+ self.run_id = run_id
+ self.target_fos = target_fos
+
+ callbacks = [self.log] if self.agent_logger else []
+ super().__init__(
+ name="VerboseLowFidelityAgent",
+ tools=[VerboseAnalyticalTool(joint_configuration)],
+ add_base_tools=False,
+ model=model,
+ instructions=(
+ BASE_INSTRUCTIONS +
+ INPUT_FORMAT +
+ OUTPUT_FORMAT +
+ SIMPLIFIED_TOOL_USING_INSTRUCTION +
+ LOW_FIDELITY_TOOL_INSTRUCTION
+ ),
+ step_callbacks=callbacks,
+ verbosity_level=2,
+ )
+
+ def log(self, step, agent):
+ if self.agent_logger and step.__class__.__name__ == "ActionStep":
+ self.agent_logger.log(
+ agent_id=self.agent_id,
+ run_id=self.run_id,
+ target_fos=self.target_fos,
+ action_step=step
+ )
\ No newline at end of file
diff --git a/src/autoboltagent/verbose_prompts.py b/src/autoboltagent/verbose_prompts.py
new file mode 100644
index 0000000..ad42c46
--- /dev/null
+++ b/src/autoboltagent/verbose_prompts.py
@@ -0,0 +1,150 @@
+# Base prompt for the agent
+BASE_INSTRUCTIONS = """
+# BASE INSTRUCTIONS
+
+You are a mechanical engineering expert specializing in the design of bolted connections.
+You will be given tasks that require you to determine the number and size of bolts to achieve a required factor of safety.
+Work iteratively to refine your solution.
+Before you complete the task, you must satisfy the following requirements:
+- The factor of safety for both the bolt and the plate is within +/-0.1 of the target value.
+- You must recommend both a bolt size (diameter) and the number of bolts.
+"""
+
+INPUT_FORMAT = """
+# INPUT FORMAT
+
+You will be given a json-like object called "joint_configuration" with some fields corresponding to a joint configuration. The fields are listed below:
+
+ - load: (number) The external load force, in Newtons (N)
+ - desired_safety_factor: (number) the desired FOS number
+ - bolt_yield_strength: (number) the yield strength of the bolt material, in MegaPascals (MPa)
+ - plate_yield_strength: (number) the yield strength of the plate material, in MegaPascals (MPa)
+ - preload: the force of preload per joint, in Newtons (N)
+ - pitch: (number) thread pitch in mm
+ - plate_thickness: (number) plate thickness in mm
+ - bolt_elastic_modulus: (number) elastic modulus of bolt, in GigaPascals (GPa)
+ - plate_elastic_modulus: (number) elastic modulus of plate material, in GigaPascals (GPa)
+
+Below is an example of a valid input:
+
+joint_configuration = {
+ "load": 60000,
+ "desired_safety_factor": 3.0,
+ "bolt_yield_strength": 940,
+ "plate_yield_strength": 250,
+ "preload": 150000,
+ "pitch": 1.5,
+ "plate_thickness": 10,
+ "bolt_elastic_modulus": 210,
+ "plate_elastic_modulus": 210
+}
+
+These inputs will also be used to call the tool, and the names correspond exactly to those in the function signature of the tool.
+"""
+
+# Instructions for using tools
+TOOL_USING_INSTRUCTION = """
+# TOOL INSTRUCTIONS
+
+You will be given some tool(s) that use different methods to calculate the FOS of a joint configuration.
+
+They will be called with the following function signature:
+
+tool_call(
+ desired_safety_factor: float,
+ load: float,
+ preload: float,
+ num_bolts: int,
+ bolt_diameter: float,
+ bolt_yield_strength: float,
+ bolt_elastic_modulus: float,
+ plate_thickness: float,
+ plate_elastic_modulus: float,
+ plate_yield_strength: float,
+ pitch: float
+)
+
+These inputs include the specifications of the joint_configuration as well as num_bolts and bolt_diameter (in mm), which you will need to supply. You MUST include every parameter in the signature. Partial calls are strictly forbidden and will result in failure.
+The output of the tool will be a python dictionary with two fields: bolt_fos and plate_fos, which refer to the factor of safety for the bolt and plate respectively.
+
+When calling the tool, you must copy all joint_configuration fields exactly as provided. Do not change, round, “correct,” or infer any value (including pitch).
+"""
+
+SIMPLIFIED_TOOL_USING_INSTRUCTION = """
+# TOOL INSTRUCTIONS
+
+You will be given one or more tools that tell you if the factor of safety (FOS) for a joint configuration is within target.
+
+The tool is called using the following EXACT format:
+
+
+{"name":"analytical_fos_calculation","arguments":{"num_bolts":,"bolt_diameter":}}
+
+
+# RULES (ABSOLUTE):
+
+0) Your entire message MUST be EXACTLY one block OR EXACTLY the final JSON. No other characters, no extra lines, no explanation.
+
+1) After ANY tool result (True OR False), your NEXT message MUST be a block OR the final JSON.
+ You are FORBIDDEN from writing explanations, analysis, narration, or repeating observations.
+
+2) If the tool result is False, you MUST output a NEW .
+ You MUST change at least one of (num_bolts, bolt_diameter).
+
+3) If the tool result IS True, output ONLY the final JSON and NOTHING ELSE.
+
+## Heuristic
+Roughly speaking, the FOS will increase if the bolt diameter or number of bolts is increased. Keep that in mind when choosing your new values.
+"""
+
+LOW_FIDELITY_TOOL_INSTRUCTION = """
+## analytical_fos_calculation
+
+This tool is the low-fidelity tool that uses analytical methods to calculate the FOS for the joint configuration and determine if it is within tolerance. It is computationally efficient but may be lacking in accuracy.
+"""
+
+HIGH_FIDELITY_TOOL_INSTRUCTION = """
+## fea_fos_calculation
+
+This tool is the high-fidelity tool that uses a computationally intensive but very accurate finite element analysis method to calculate the FOS.
+"""
+
+OUTPUT_FORMAT = """
+# OUTPUT FORMAT
+
+You will output two different things depending on the result of the tool call.
+
+## FOS not within tolerance OR error
+If the tool returns an FOS that is not within tolerance, or an error message is returned, then the previous values you suggested for the number of bolts and the bolt diameter is wrong, and you must rethink them.
+
+After you have come up with new numbers, output a tool call with the new numbers like this and NOTHING ELSE:
+
+
+{"name":"analytical_fos_calculation","arguments":{"num_bolts": , "bolt_diameter": }}
+
+
+## FOS within tolerance
+If the tool returns an acceptable FOS, then you are done. Return the number of bolts and diameter that resulted in this FOS in a json-like.
+
+{
+ "num_bolts": ,
+ "bolt_diameter":
+}
+"""
+
+# Instructions specific to dual-fidelity agent
+DUAL_FIDELITY_COORDINATION = """
+
+You should also note that you have access to a low-fidelity analytical tool and a high-fidelity finite element analysis tool.
+- Use the low-fidelity tool for quick initial estimates and to explore different design options.
+- Use the high-fidelity tool to validate and refine your designs.
+"""
+
+# Instructions for an example design task
+EXAMPLE_TASK_INSTRUCTIONS = """
+Given the following joint configuration:
+
+joint_configuration = {}
+
+Determine the optimal number of bolts and the major diameter of the bolts:
+"""
diff --git a/tests/test_agents.py b/tests/test_agents.py
index d002ba4..4fe66f3 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -10,7 +10,7 @@ def is_macos() -> bool:
return platform.system() == "Darwin"
-def get_testing_model() -> smolagents.Model:
+def get_testing_model() -> smolagents.models.Model:
if is_macos():
# Use a local model on macOS for faster testing
return smolagents.MLXModel(
@@ -18,7 +18,7 @@ def get_testing_model() -> smolagents.Model:
)
else:
# Use the smallest Instruct model available for fast CI feedback
- return smolagents.TransformersModel(
+ return smolagents.models.TransformersModel(
model_id="HuggingFaceTB/SmolLM-135M-Instruct",
max_new_tokens=200, # Keep generation short for speed
)
@@ -38,7 +38,15 @@ def test_guessing_agent():
def test_low_fidelity_agent():
# Create the LowFidelityAgent and run it
- response = autoboltagent.LowFidelityAgent(get_testing_model()).run(
+ agent = autoboltagent.LowFidelityAgent(
+ model=get_testing_model(),
+ agent_id="low fidelity agent",
+ run_id="test 1",
+ target_fos=3.0,
+ max_steps=5
+ )
+
+ response = agent.run(
autoboltagent.prompts.EXAMPLE_TASK_INSTRUCTIONS
)
diff --git a/tests/test_logger.py b/tests/test_logger.py
new file mode 100644
index 0000000..23179d2
--- /dev/null
+++ b/tests/test_logger.py
@@ -0,0 +1,103 @@
+from autoboltagent.tools.logger import AgentLogger
+from autoboltagent.tools.logger import Iteration
+from smolagents import ActionStep, Timing, ToolCall, ChatMessage, MessageRole, AgentError
+from sqlalchemy import create_engine, select
+from sqlalchemy.orm import Session
+from datetime import datetime, timezone
+import pytest
+
+db_url = "sqlite:///agent_logs_test.db"
+
+@pytest.fixture
+def logger():
+ """
+ Fixture to create a new AgentLogger with a fresh db per test
+ """
+ AgentLogger.reset()
+ logger = AgentLogger(db_url)
+ yield logger
+ AgentLogger.reset()
+
+def get_log_session(db_url):
+ """
+ Helper function to get db session from url
+ """
+ engine = create_engine(db_url)
+ return Session(engine)
+
+def test_logger_empty_when_fresh(logger):
+ """
+ Test to see if the logger db is empty when fresh
+ """
+ with get_log_session(db_url) as session:
+ iterations = session.query(Iteration).all()
+
+ assert len(iterations) == 0
+
+def test_logger_one_write_fields_persist(logger):
+ """
+ Test one log write and see if all fields persists in the db
+ """
+ start_time = datetime.now(timezone.utc).timestamp()
+ end_time = datetime.now(timezone.utc).timestamp()
+
+ step = ActionStep(
+ step_number=1,
+ timing=Timing(
+ start_time=start_time,
+ end_time=end_time
+ ),
+ observations = "observation observation",
+ tool_calls=[ToolCall(name="tool call", arguments={"asdf": 1}, id="1341fad")],
+ model_output_message=ChatMessage(role=MessageRole("assistant"), content="LLM output 1"),
+ error=None
+ )
+
+ logger.log(
+ run_id="run_1",
+ agent_id="agent_1",
+ target_fos=1,
+ action_step=step
+ )
+
+ with get_log_session(db_url) as session:
+ iteration = session.query(Iteration).one()
+
+ assert iteration.run_id == "run_1"
+ assert iteration.agent_id == "agent_1"
+ assert iteration.iteration_no == 1
+ assert type(iteration.start_time) == datetime
+ assert type(iteration.end_time) == datetime
+ assert iteration.target_fos == 1
+ assert iteration.llm_output == "LLM output 1"
+ assert iteration.error_message == None
+
+
+def test_logger_large_write(logger):
+ """
+ Test 50 log writes and see if db contains 50 rows
+ """
+ for i in range(50):
+ start_time = datetime.now(timezone.utc).timestamp()
+ end_time = datetime.now(timezone.utc).timestamp()
+ step = ActionStep(
+ step_number=1,
+ timing=Timing(
+ start_time=start_time,
+ end_time=end_time
+ ),
+ observations = "observation observation",
+ tool_calls=[ToolCall(name="tool call", arguments={"asdf": 1}, id="1341fad")],
+ model_output_message=ChatMessage(role=MessageRole("assistant"), content=f"LLM output {i}"),
+ error=None
+ )
+ logger.log(
+ run_id="run_1",
+ agent_id="agent_1",
+ target_fos=1,
+ action_step=step
+ )
+ with get_log_session(db_url) as session:
+ iterations = session.query(Iteration).all()
+
+ assert len(iterations) == 50
\ No newline at end of file