194 lines
7.0 KiB
Python
194 lines
7.0 KiB
Python
import os, sys, json, tty, termios
|
|
import config
|
|
from llm_client import LLMClient
|
|
from tools import coder
|
|
from scripts import gadget
|
|
|
|
# ANSI SGR escape sequences used to colorize console status lines.
TEXT_COLOR_YELLOW = "\033[93m"  # bright yellow foreground
TEXT_COLOR_GREEN = "\033[92m"  # bright green foreground
TEXT_COLOR_RESET = "\033[0m"  # reset all attributes
|
|
|
|
# Registry of the tools exposed to the model: each entry is whatever
# gadget.tools_mapping returns for a (schema, handler) pair from the coder
# module. Order is kept stable because it is what the schema list is built from.
tools_definition = [
    gadget.tools_mapping(schema, handler)
    for schema, handler in (
        (coder.schema_read_file, coder.read_file),
        (coder.schema_write_file, coder.write_file),
        (coder.schema_edit_file, coder.edit_file),
        (coder.schema_run_bash, coder.run_bash),
        (coder.schema_search_code, coder.search_code),
        (coder.schema_git_operation, coder.git_operation),
    )
]

# Schemas handed to the LLM client as the ``tools=`` argument.
TOOLS = gadget.tool_schemas(tools_definition)

# Lookup used to resolve a tool name requested by the model to its handler.
TOOL_HANDLERS = gadget.tool_handlers(tools_definition)
|
|
|
|
|
|
def _switch_workspace(path):
    """Change the working directory to *path*, printing the outcome.

    When *path* does not resolve to an existing directory an error is printed
    and the current directory is left unchanged.
    """
    resolved = os.path.abspath(path)
    if not os.path.isdir(resolved):
        print(f"Error: '{resolved}' is not a valid directory")
        return
    os.chdir(resolved)
    print(f"\u2192 Workspace changed to {os.getcwd()}")


def _pop_last_char(buffer):
    """Remove the last UTF-8 encoded character from *buffer* in place.

    Returns True when at least one byte was removed, False for an empty buffer.
    """
    if not buffer:
        return False
    while buffer:
        # Continuation bytes look like 0b10xxxxxx; keep popping until the
        # lead byte (or a plain ASCII byte) has been removed as well.
        if buffer.pop() & 0xC0 != 0x80:
            break
    return True


def interactive_input():
    """Read a multi-line query from the terminal in raw mode.

    Keys handled while reading:
      * Ctrl+D (or end of input) submits the text typed so far.
      * Ctrl+C restores the terminal and exits the process with status 0.
      * Ctrl+W prompts for a new workspace directory, then starts over.
      * Enter inserts a newline; Backspace removes the previous character.

    A submitted query of the form ``:workspace <dir>`` switches the working
    directory and starts over instead of being returned.

    Returns:
        The submitted text, decoded as UTF-8 and stripped of surrounding
        whitespace (may be empty).
    """
    fd = sys.stdin.fileno()
    rule = "\u2500" * 50
    workspace_cmd = ':workspace '

    # Loop instead of recursing so repeated workspace changes do not grow the
    # call stack.
    while True:
        saved_attrs = termios.tcgetattr(fd)

        print()
        print(rule)
        print("Hendrik AI Agent - Interactive Mode")
        print(rule)
        print(f"Workspace: {os.getcwd()}")
        print(rule)
        print("[Ctrl+W] Change workspace | :workspace <dir> | [Ctrl+D] Submit")
        print(rule)

        buffer = bytearray()
        change_workspace = False
        try:
            tty.setraw(fd)
            while True:
                ch = os.read(fd, 1)
                if not ch or ch == b'\x04':
                    # os.read returns b'' at end of input; treat it like
                    # Ctrl+D so a closed stdin cannot spin this loop forever.
                    break
                if ch == b'\x03':  # Ctrl+C -> exit
                    # Restore cooked mode first so the message prints normally.
                    termios.tcsetattr(fd, termios.TCSADRAIN, saved_attrs)
                    print("\r\nExiting.")
                    sys.exit(0)
                if ch == b'\x17':  # Ctrl+W -> change workspace
                    change_workspace = True
                    break
                if ch in (b'\r', b'\n'):  # Enter
                    buffer.extend(b'\n')
                    # Raw mode disables output post-processing, so the
                    # carriage return has to be written explicitly.
                    sys.stdout.buffer.write(b'\r\n')
                    sys.stdout.flush()
                elif ch == b'\x7f':  # Backspace
                    # Erase one whole character, not one byte, so multi-byte
                    # UTF-8 input is never left half-deleted.
                    if _pop_last_char(buffer):
                        sys.stdout.buffer.write(b'\b \b')
                        sys.stdout.flush()
                elif ch >= b' ':  # Printable bytes (including UTF-8 bytes)
                    buffer.extend(ch)
                    sys.stdout.buffer.write(ch)
                    sys.stdout.flush()
        finally:
            termios.tcsetattr(fd, termios.TCSADRAIN, saved_attrs)

        if change_workspace:
            # Terminal is back in its saved mode here, so input() works.
            print("\r\n", end="")
            ws = input("Workspace directory: ").strip()
            if ws:
                _switch_workspace(ws)
            continue

        full_query = buffer.decode('utf-8', errors='replace').strip()

        if full_query.startswith(workspace_cmd):
            _switch_workspace(full_query[len(workspace_cmd):].strip())
            continue

        return full_query
|
|
|
|
|
|
def _execute_tool_call(tool_call):
    """Run a single tool call requested by the model and return its result.

    Every failure (unknown tool, unparsable arguments, handler exception) is
    turned into a descriptive string so the caller can always answer the tool
    call instead of aborting the conversation.
    """
    tool_name = tool_call['function']['name']
    handler = TOOL_HANDLERS.get(tool_name)
    if not handler:
        return f"Tool {tool_name} not found"

    raw_args = tool_call['function']['arguments']
    try:
        # An empty arguments payload is treated as "no arguments".
        tool_args = json.loads(raw_args) if raw_args else {}
    except (TypeError, ValueError) as e:
        # json.JSONDecodeError is a ValueError subclass.
        return f"Error executing tool: invalid JSON arguments: {e}"
    if not isinstance(tool_args, dict):
        return "Error executing tool: arguments must be a JSON object"

    args_display = ", ".join(f"{k}={v!r}" for k, v in tool_args.items())
    print(f"  \u2192 {tool_name}({args_display})")

    try:
        if tool_name == "search_code":
            # Only the known parameters are forwarded; path defaults to ".".
            return handler(
                pattern=tool_args["pattern"],
                search_type=tool_args["search_type"],
                path=tool_args.get("path", "."),
            )
        if tool_name == "git_operation":
            return handler(args=tool_args["args"])
        return handler(**tool_args)
    except Exception as e:
        # Boundary with model-driven tool code: report the failure back to the
        # model rather than crashing the agent.
        return f"Error executing tool: {e}"


def agent_loop(user_query, messages, llm_client):
    """Drive the chat/tool-call cycle until the model gives a final answer.

    Appends *user_query* to *messages*, then repeatedly calls
    ``llm_client.chat(messages, tools=TOOLS)``. While the response carries
    tool calls, each one is executed and its result appended as a ``tool``
    message; the first response without tool calls is the final answer.

    Args:
        user_query: Text of the user's request.
        messages: Conversation history (list of message dicts); mutated in
            place and also returned.
        llm_client: Object whose ``chat`` result exposes ``tool_calls`` and
            ``content`` attributes.

    Returns:
        A ``(final_answer, messages)`` tuple. When no final answer arrives
        within ``config.AGENT_MAX_ITERATIONS`` rounds, the answer is a fixed
        notice saying so.
    """
    messages.append({"role": "user", "content": user_query})

    for _ in range(config.AGENT_MAX_ITERATIONS):
        response = llm_client.chat(messages, tools=TOOLS)

        if not response.tool_calls:
            messages.append({"role": "assistant", "content": response.content})
            return response.content, messages

        messages.append({
            "role": "assistant",
            "content": response.content,
            "tool_calls": response.tool_calls,
        })
        for tool_call in response.tool_calls:
            # Each tool call always gets a reply, even on failure, so the
            # history never holds an unanswered tool call.
            messages.append({
                "role": "tool",
                "tool_call_id": tool_call['id'],
                "content": str(_execute_tool_call(tool_call)),
            })

    msg = "Max iterations reached without final answer."
    messages.append({"role": "assistant", "content": msg})
    return msg, messages
|
|
|
|
|
|
def main():
    """Entry point: run a one-shot query from argv or start interactive mode.

    ``-w``/``--workspace <dir>`` (when followed by a value) selects the working
    directory; all remaining arguments are joined with spaces into the query.
    Without a query the function loops on interactive input until an empty
    submission or ``/exit`` / ``/quit``.
    """
    cli_args = sys.argv[1:]
    workspace = None
    query_parts = []

    pos = 0
    while pos < len(cli_args):
        token = cli_args[pos]
        value_follows = pos + 1 < len(cli_args)
        if value_follows and token in ('-w', '--workspace'):
            workspace = cli_args[pos + 1]
            pos += 2
            continue
        # A trailing flag with no value falls through and joins the query.
        query_parts.append(token)
        pos += 1

    if workspace:
        target = os.path.abspath(workspace)
        if os.path.isdir(target):
            os.chdir(target)
        else:
            print(f"Error: '{target}' is not a valid directory")
            sys.exit(1)

    llm_client = LLMClient(
        base_url=config.LLM_BASE_URL,
        model=config.LLM_MODEL,
        api_key=config.LLM_API_KEY,
    )

    def new_history():
        # Fresh conversation seeded with the system prompt.
        return [{"role": "system", "content": gadget.build_system_prompt(tools_definition)}]

    def answer(query, history):
        # Run one query through the agent and print its final answer.
        print(f"{TEXT_COLOR_YELLOW}Thinking...{TEXT_COLOR_RESET}")
        final_answer, history = agent_loop(query, history, llm_client)
        print(f"\n{TEXT_COLOR_GREEN}Final Answer:{TEXT_COLOR_RESET}")
        print(final_answer)
        return history

    # One-shot mode: the query came from the command line.
    if query_parts:
        user_query = ' '.join(query_parts)
        if not user_query:
            print("No query provided.")
            return
        answer(user_query, new_history())
        return

    # Interactive mode: history is created lazily and kept across queries.
    messages = None
    while True:
        user_query = interactive_input()
        if not user_query or user_query.lower() in ('/exit', '/quit'):
            break
        if messages is None:
            messages = new_history()
        messages = answer(user_query, messages)
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|