Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,9 @@ cython_debug/
marimo/_static/
marimo/_lsp/
__marimo__/


#
autobolt/
*.db-*
*.db
3 changes: 2 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ dependencies:
- pandas
- "autobolt @ git+https://github.com/sriyanc2001/AutoBolt.git"
- black
- pytest
- pytest
- sqlalchemy
97 changes: 97 additions & 0 deletions src/autoboltagent/VLLMModelCustom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from smolagents import VLLMModel
from smolagents.tools import Tool
from smolagents.monitoring import TokenUsage
from vllm.transformers_utils.tokenizer import get_tokenizer
from smolagents.models import (
ChatMessage,
MessageRole,
remove_content_after_stop_sequences,

)

from typing import Any

class VLLMModelCustom(VLLMModel):
def __init__(
self,
model_id,
model_kwargs: dict[str, Any] | None = None,
apply_chat_template_kwargs: dict[str, Any] | None = None,
sampling_params=None,
**kwargs,
):
super().__init__(
model_id=model_id,
model_kwargs=model_kwargs,
apply_chat_template_kwargs=apply_chat_template_kwargs,
**kwargs
)
self.sampling_params = sampling_params

def generate(
self,
messages: list[ChatMessage | dict],
stop_sequences: list[str] | None = None,
response_format: dict[str, str] | None = None,
tools_to_call_from: list[Tool] | None = None,
**kwargs,
) -> ChatMessage:
from vllm import SamplingParams # type: ignore
from vllm.sampling_params import StructuredOutputsParams # type: ignore

completion_kwargs = self._prepare_completion_kwargs(
messages=messages,
flatten_messages_as_text=(not self._is_vlm),
stop_sequences=stop_sequences,
tools_to_call_from=tools_to_call_from,
**kwargs,
)

prepared_stop_sequences = completion_kwargs.pop("stop", [])
messages = completion_kwargs.pop("messages")
tools = completion_kwargs.pop("tools", None)
completion_kwargs.pop("tool_choice", None)

if not self.sampling_params:
# Override the OpenAI schema for VLLM compatibility
structured_outputs = (
StructuredOutputsParams(json=response_format["json_schema"]["schema"]) if response_format else None
)


self.sampling_params = SamplingParams(
n=kwargs.get("n", 1),
temperature=kwargs.get("temperature", 0.0),
max_tokens=kwargs.get("max_tokens", 2048),
stop=prepared_stop_sequences,
structured_outputs=structured_outputs,
)


prompt = self.tokenizer.apply_chat_template(
messages,
tools=tools,
add_generation_prompt=True,
tokenize=False,
**self.apply_chat_template_kwargs,
)


out = self.model.generate(
prompt,
sampling_params=self.sampling_params,
**completion_kwargs,
)

output_text = out[0].outputs[0].text
if stop_sequences is not None and not self.supports_stop_parameter:
output_text = remove_content_after_stop_sequences(output_text, stop_sequences)
return ChatMessage(
role=MessageRole.ASSISTANT,
content=output_text,
raw={"out": output_text, "completion_kwargs": completion_kwargs},
token_usage=TokenUsage(
input_tokens=len(out[0].prompt_token_ids),
output_tokens=len(out[0].outputs[0].token_ids),
),
)
37 changes: 29 additions & 8 deletions src/autoboltagent/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@
)
from .tools import AnalyticalTool, FiniteElementTool

from .tools.logger import AgentLogger

class GuessingAgent(smolagents.ToolCallingAgent):

class GuessingAgent(smolagents.agents.ToolCallingAgent):
"""
An agent that makes guesses without using any tools.

This agent operates solely based on its internal model and does not utilize any external tools for analysis or calculations.
It is designed to provide initial estimates or solutions based on its knowledge and reasoning capabilities.
"""

def __init__(self, model: smolagents.Model) -> None:
def __init__(self, model: smolagents.models.Model) -> None:
"""
Initializes a GuessingAgent that does not use any tools.

Expand All @@ -33,40 +35,59 @@ def __init__(self, model: smolagents.Model) -> None:
)


class LowFidelityAgent(smolagents.ToolCallingAgent):
class LowFidelityAgent(smolagents.agents.ToolCallingAgent):
"""
An agent that utilizes a low-fidelity analytical tool for bolted connection design.

This agent leverages an analytical tool to perform calculations and analyses related to bolted connections.
It is designed to provide solutions based on simplified models and assumptions, making it suitable for quick estimates and preliminary designs.
"""

def __init__(self, model: smolagents.Model) -> None:
def __init__(self, model: smolagents.models.Model, agent_id: str, run_id: str, target_fos: float, agent_logger: AgentLogger|None = None, max_steps=20) -> None:
"""
Initializes a LowFidelityAgent that uses an analytical tool.

Args:
model: An instance of smolagents.Model to be used by the agent.
"""

self.agent_logger = agent_logger
self.agent_id = agent_id
self.run_id = run_id
self.target_fos = target_fos

callbacks = [self.log] if self.agent_logger else []

super().__init__(
name="LowFidelityAgent",
tools=[AnalyticalTool()],
add_base_tools=False,
model=model,
instructions=BASE_INSTRUCTIONS + TOOL_USING_INSTRUCTION,
step_callbacks = callbacks,
verbosity_level=2,
max_steps=max_steps
)

def log(self, step, agent):
if self.agent_logger and step.__class__.__name__ == "ActionStep":
self.agent_logger.log(
agent_id=self.agent_id,
run_id=self.run_id,
target_fos=self.target_fos,
action_step=step
)


class HighFidelityAgent(smolagents.ToolCallingAgent):
class HighFidelityAgent(smolagents.agents.ToolCallingAgent):
"""
An agent that utilizes a high-fidelity finite element analysis tool for bolted connection design.

This agent leverages a finite element tool to perform detailed calculations and analyses related to bolted connections.
It is designed to provide accurate and reliable solutions based on comprehensive models, making it suitable for
"""

def __init__(self, model: smolagents.Model) -> None:
def __init__(self, model: smolagents.models.Model) -> None:
"""
Initializes a HighFidelityAgent that uses a finite element tool.

Expand All @@ -83,15 +104,15 @@ def __init__(self, model: smolagents.Model) -> None:
)


class DualFidelityAgent(smolagents.ToolCallingAgent):
class DualFidelityAgent(smolagents.agents.ToolCallingAgent):
"""
An agent that utilizes both low-fidelity and high-fidelity tools for bolted connection design.

This agent leverages both an analytical tool and a finite element tool to perform calculations and analyses related to bolted connections.
It is designed to provide solutions that balance speed and accuracy by using the low-fidelity tool
"""

def __init__(self, model: smolagents.Model) -> None:
def __init__(self, model: smolagents.models.Model) -> None:
"""
Initializes a DualFidelityAgent that uses both analytical and finite element tools.

Expand Down
74 changes: 74 additions & 0 deletions src/autoboltagent/grammars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# grammar that outputs low-fidelity tool call or final_answer
low_fidelity_agent_grammar = r"""
root ::= "<tool_call>\n" payload "\n</tool_call>"
payload ::= fos | final

fos ::= "{\"name\": \"analytical_fos_calculation\", \"arguments\": " fos_args "}}"
fos_args ::= "{" ( num_field ", " dia_field ) | ( dia_field ", " num_field ) "}"
num_field ::= "\"num_bolts\": " int
dia_field ::= "\"bolt_diameter\": " number

final ::= "{\"name\": \"final_answer\", \"arguments\": {\"answer\": " fos_args "}}}"

int ::= digit+
number ::= digit+ frac?
frac ::= "." digit{1,4}

digit ::= "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
"""

# grammar that outputs high-fidelity tool call or final_answer
high_fidelity_agent_grammar = r"""
root ::= "<tool_call>\n" payload "\n</tool_call>"
payload ::= fos | final

fos ::= "{\"name\": \"fea_fos_calculation\", \"arguments\": " fos_args "}}"
fos_args ::= "{" ( num_field ", " dia_field ) | ( dia_field ", " num_field ) "}"
num_field ::= "\"num_bolts\": " int
dia_field ::= "\"bolt_diameter\": " number

final ::= "{\"name\": \"final_answer\", \"arguments\": {\"answer\": " fos_args "}}}"

int ::= digit+
number ::= digit+ frac?
frac ::= "." digit{1,4}

digit ::= "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
"""

# grammar that outputs both types of tool call and final_answer
dual_fidelity_agent_grammar = r"""
root ::= "<tool_call>\n" payload "\n</tool_call>"
payload ::= fos | final

fos ::= "{\"name\": ("\"analytical_fos_calculation\"" | "\"fea_fos_calculation\""), \"arguments\": " fos_args "}}"
fos_args ::= "{" ( num_field ", " dia_field ) | ( dia_field ", " num_field ) "}"
num_field ::= "\"num_bolts\": " int
dia_field ::= "\"bolt_diameter\": " number

final ::= "{\"name\": \"final_answer\", \"arguments\": {\"answer\": " fos_args "}}}"

int ::= digit+
number ::= digit+ frac?
frac ::= "." digit{1,4}

digit ::= "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
"""

low_fidelity_agent_grammar_original = r"""
root ::= "<tool_call>\n" payload "\n</tool_call>"
payload ::= fos | final

fos ::= "{\"name\": \"analytical_fos_calculation\", \"arguments\": " fos_args "}}"
fos_args ::= "{" ( num_field ", " dia_field ) | ( dia_field ", " num_field ) "}"
num_field ::= "\"num_bolts\": " int
dia_field ::= "\"bolt_diameter\": " number

final ::= "{\"name\": \"final_answer\", \"arguments\": {\"answer\": " fos_args "}}"

int ::= digit+
number ::= digit+ frac?
frac ::= "." digit{1,4}

digit ::= "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
"""
6 changes: 4 additions & 2 deletions src/autoboltagent/tools/high_fidelity_tool.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import autobolt
import smolagents
from typing import Dict, Any, Union, cast

from .inputs import INPUTS


class FiniteElementTool(smolagents.Tool):
class FiniteElementTool(smolagents.tools.Tool):
"""
A tool that calculates the factor of safety for a bolted connection using finite element analysis.

Expand All @@ -15,7 +16,8 @@ class FiniteElementTool(smolagents.Tool):
name = "fea_fos_calculation"
description = "Calculates the factor of safety using finite element analysis."

inputs = INPUTS
input_type = dict[str, dict[str, Union[str, type, bool]]]
inputs: input_type = cast(input_type,INPUTS)

output_type = "number"

Expand Down
Loading
Loading