-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
149 lines (102 loc) · 4.95 KB
/
main.py
File metadata and controls
149 lines (102 loc) · 4.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
import sys
import config
from dotenv import load_dotenv
from google import genai
from google.genai import types
from functions.get_files_info import schema_get_files_info
from functions.get_file_content import schema_get_file_content
from functions.run_python_file import schema_run_python_file
from functions.write_file import schema_write_file
from functions.call_function import call_function
def main():
    """Entry point: send a command-line prompt to the Gemini model and run
    an agent loop that executes any tool (function) calls the model requests,
    feeding results back until the model produces a final text answer.

    Usage: python main.py "<prompt>" [--verbose]
    Exits with status 1 on missing arguments or missing API key.
    """
    # Require at least the prompt argument.
    if len(sys.argv) < 2:
        print("Usage: python main.py <prompt goes here>")
        sys.exit(1)

    # The prompt is the first positional argument.
    user_prompt = sys.argv[1]

    # System prompt comes from project config.
    system_prompt = config.SYSTEM_PROMPT

    # Optional --verbose flag as the third argument.
    verbose = len(sys.argv) > 2 and sys.argv[2] == "--verbose"

    # Load credentials (.env supported via python-dotenv) and build the client.
    load_dotenv()
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        # Fail fast with a clear message instead of an opaque client error.
        print("Error: GEMINI_API_KEY environment variable is not set.")
        sys.exit(1)
    client = genai.Client(api_key=api_key)

    # Conversation history; starts with the user's prompt.
    messages = [
        types.Content(role="user", parts=[types.Part(text=user_prompt)])
    ]

    # Declare the tools the model is allowed to call.
    available_functions = types.Tool(
        function_declarations=[
            schema_get_files_info,
            schema_get_file_content,
            schema_write_file,
            schema_run_python_file,
        ]
    )

    # Initialize token counters up front so the verbose summary after the
    # loop cannot raise NameError if the very first API call fails.
    prompt_tokens = 0
    response_tokens = 0

    # Agent loop: keep talking to the model until it returns a final text
    # answer or we hit the iteration cap.
    max_iterations = config.MAX_ITERATIONS
    for _ in range(max_iterations):
        try:
            response = client.models.generate_content(
                model="gemini-2.0-flash-001",
                contents=messages,
                config=types.GenerateContentConfig(
                    system_instruction=system_prompt,
                    tools=[available_functions],
                ),
            )

            # usage_metadata may be absent on some responses; guard before
            # reading the token counts.
            if response.usage_metadata is not None:
                prompt_tokens = response.usage_metadata.prompt_token_count
                response_tokens = response.usage_metadata.candidates_token_count

            # Append each candidate's content to the history so the model
            # sees its own reasoning / tool-call requests on the next turn.
            if response.candidates is not None:
                for candidate in response.candidates:
                    if candidate.content is not None:
                        messages.append(candidate.content)

            # Execute every function call the model requested this turn.
            function_calls = response.function_calls
            if function_calls is not None:
                for call in function_calls:
                    try:
                        function_call_result = call_function(call, verbose)
                        function_response = (
                            function_call_result.parts[0].function_response.response
                        )
                        # Feed the tool output back into the conversation.
                        messages.append(
                            types.Content(role="user", parts=function_call_result.parts)
                        )
                        print("=================")
                        print("=================")
                        print(f"Added function result to messages: {function_call_result.parts}")
                        if not function_response:
                            raise Exception("Fatal function call exception. See function_result for details.")
                        if verbose:
                            print(f"-> {function_response}")
                    except Exception as e:
                        # Report the failure to the model instead of crashing,
                        # so it can try to recover on the next turn.
                        function_call_error = f'Error occurred while calling function: {call.name}. {e}'
                        messages.append(
                            types.Content(role="user", parts=[types.Part(text=function_call_error)])
                        )
                        continue

            # No tool calls plus a text reply means the model is done.
            if function_calls is None and response.text:
                print("=================")
                print("FINAL RESPONSE FROM LLM:")
                print("=================")
                print(response.text)
                break
        except Exception as e:
            # Transient API failure: report it and retry on the next iteration.
            print(f'Error occurred while talking to the LLM. {e}')

    if verbose:
        print(f"User prompt: {user_prompt}")
        print(f"Prompt tokens: {prompt_tokens}")
        print(f"Response tokens: {response_tokens}")
# Run the agent only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()