Source code for climb.tool.tool_comms
import copy
import io
import os
import queue
import sys
import threading
import traceback
from functools import partial
from typing import Any, Callable, Iterable, List, NoReturn, Optional, Tuple, Union
from climb.common import ToolUserReportSeq
from climb.common.utils import filter_out_lines
# File name, relative to a tool's working directory, of the fallback progress file that
# `live_output_iterable` polls for output and that `ToolThread` removes when it is joined.
BACKUP_OUTPUT_FILE = "progress.txt"
[docs]
class ToolOutput:
    """Container for the final result of a tool run.

    Attributes:
        success: Whether the run is considered successful. Starts as ``True``.
        files_in: File list associated with the run's inputs. Starts empty.
        files_out: File list associated with the run's outputs. Starts empty.
    """

    def __init__(self) -> None:
        self.success: bool = True
        self.files_in: List[str] = []
        self.files_out: List[str] = []
        # Backing fields for the `tool_return` / `user_report_outputs` properties.
        self._tool_return: str = ""
        self._user_report_outputs: ToolUserReportSeq = []

    @property
    def tool_return(self) -> str:
        """The textual return value of the tool."""
        return self._tool_return

    @tool_return.setter
    def tool_return(self, value: str) -> None:
        self._tool_return = value

    @property
    def user_report_outputs(self) -> ToolUserReportSeq:
        """The sequence of user-facing report items."""
        return self._user_report_outputs

    @user_report_outputs.setter
    def user_report_outputs(self, value: ToolUserReportSeq) -> None:
        self._user_report_outputs = value

    def set_empty(self) -> None:
        """Clear the return text, the user report and both file lists.

        Note that ``success`` is deliberately not touched here; it keeps whatever value it had.
        """
        self.files_out = []
        self.files_in = []
        self._user_report_outputs = []
        self._tool_return = ""
# Type of the stream produced while a tool runs: `str` items and `ToolOutput` items
# (`live_output_iterable` yields a single `ToolOutput` as its last item).
ToolReturnIter = Iterable[Union[str, ToolOutput]]
[docs]
class ToolCommunicator:
    """Channel object handed to a tool function so it can report progress and its result.

    Attributes:
        comm_queue: Receives printed messages (``str``) and the final ``ToolOutput``.
        exc_queue: Queue reserved for exception reports.
        std_queue: Queue reserved for captured STDOUT/STDERR text.
        return_set: ``True`` once `set_returns` has been called.
    """

    def __init__(self) -> None:
        self.return_set = False
        self.comm_queue: queue.Queue = queue.Queue()
        self.exc_queue: queue.Queue = queue.Queue()
        self.std_queue: queue.Queue = queue.Queue()

    def print(self, *args: Any) -> None:
        """Send a progress message.

        The arguments are converted with ``str`` and concatenated *without* a separator, a newline is
        appended, and the result is put on ``comm_queue``.

        Raises:
            ValueError: If `set_returns` has already been called.
        """
        if self.return_set:
            raise ValueError("Cannot print after return value has been set")
        joined = "".join(map(str, args))
        self.comm_queue.put(f"{joined}\n")

    def set_returns(
        self,
        tool_return: str,
        user_report: Optional[ToolUserReportSeq] = None,
        files_in: Optional[List[str]] = None,
        files_out: Optional[List[str]] = None,
    ) -> None:
        """Publish the final result of the tool on ``comm_queue`` and mark the return as set.

        Args:
            tool_return: Output intended for the LLM.
            user_report: Output intended for the user. Ignored when falsy.
            files_in: Input file list. Ignored when falsy.
            files_out: Output file list. Ignored when falsy.
        """
        result = ToolOutput()
        result.tool_return = tool_return
        # Falsy optional values (None or empty) leave the ToolOutput defaults in place.
        optional_fields = (
            ("user_report_outputs", user_report),
            ("files_in", files_in),
            ("files_out", files_out),
        )
        for attr_name, value in optional_fields:
            if value:
                setattr(result, attr_name, value)
        self.comm_queue.put(result)
        self.return_set = True
[docs]
def except_hook(args: Any, exc_queue: queue.Queue) -> NoReturn:
    """Report an exception raised in a thread onto ``exc_queue``, then raise.

    Written to be installed as ``threading.excepthook`` with ``exc_queue`` pre-bound.

    Args:
        args: Object exposing ``exc_type``, ``exc_value``, ``exc_traceback`` and ``thread``.
        exc_queue: Queue that receives a string describing the failure.

    Raises:
        This function never returns normally. For the kill signal it re-raises the exception that is
        currently being handled; otherwise it raises ``ToolException`` with the formatted traceback.
    """
    is_kill_signal = args.exc_type is SystemExit and str(args.exc_value) == "ThreadWithTrace killed"
    if is_kill_signal:
        # The thread was killed on purpose: no traceback, just a short notice on the queue.
        exc_queue.put("ERROR: Tool terminated by system or user.")
        # Bare re-raise: relies on being called while the original exception is still being handled.
        raise

    buffer = io.StringIO()
    traceback.print_exception(args.exc_type, args.exc_value, args.exc_traceback, file=buffer)
    formatted = buffer.getvalue()
    exc_queue.put(formatted)
    # NOTE(review): `ToolException` is not defined or imported in the visible part of this module -
    # confirm where it is expected to come from.
    raise ToolException(f"\nException from thread: {args.thread}\n{formatted}")
[docs]
def process_stream_chunk(s: str) -> Optional[str]:
    """Normalise a chunk of stream text.

    Args:
        s: Raw text written to a stream.

    Returns:
        ``None`` when the chunk contains only whitespace (or is empty); otherwise the chunk itself,
        with a trailing newline appended if it did not already end with one.
    """
    if not s.strip():
        return None
    return s if s.endswith("\n") else s + "\n"
[docs]
class StreamRedirector:
    """Minimal write-only stream that forwards written text to a queue.

    Instances are installed in place of ``sys.stdout`` / ``sys.stderr``, so besides ``write`` they
    provide the ``flush`` and ``isatty`` methods that code commonly calls on those streams
    (e.g. ``print(..., flush=True)``); without them such calls raise ``AttributeError``.

    Attributes:
        q: Queue receiving the normalised text chunks.
    """

    def __init__(self, q: queue.Queue) -> None:
        self.q = q

    def write(self, text: str) -> None:
        """Normalise ``text`` with `process_stream_chunk` and enqueue it unless it is blank."""
        text_ = process_stream_chunk(text)
        if text_ is not None:
            self.q.put(text_)

    def flush(self) -> None:
        """Do nothing: every chunk is enqueued immediately by `write`, so nothing is buffered."""

    def isatty(self) -> bool:
        """Report that this stream is not an interactive terminal."""
        return False
[docs]
class ToolThread(threading.Thread):
    """Thread that runs a tool with redirected output and can be killed via a trace function.

    While the thread is set up through `start`, the process-wide ``sys.stdout``, ``sys.stderr`` and
    ``threading.excepthook`` are replaced; `join` (and therefore `kill`) puts the saved values back.

    Attributes:
        killed: Set to ``True`` by `kill`; checked by the trace function on every traced line.
    """

    # Source for the trace approach: https://www.geeksforgeeks.org/python-different-ways-to-kill-a-thread/
    SYSTEM_EXIT_MSG = "ThreadWithTrace killed"

    def __init__(self, *args: Any, **keywords: Any) -> None:
        super().__init__(*args, **keywords)
        self.killed: bool = False

    def start(self, std_q: queue.Queue, exc_q: queue.Queue, backup_output_file_path: str) -> None:
        """Redirect output and exception handling, then start the thread.

        Args:
            std_q: Queue that receives text written to STDOUT/STDERR.
            exc_q: Queue that receives exception reports from the installed except hook.
            backup_output_file_path: Path of the backup output file, removed again on `join`.
        """
        # Remember the current process-wide objects so `_reset_key_resources` can restore them.
        self.sys_stdout_bak = sys.stdout
        self.sys_stderr_bak = sys.stderr
        self.threading_excepthook_bak = copy.copy(threading.excepthook)
        self.backup_output_file_path = backup_output_file_path

        sys.stdout = StreamRedirector(std_q)  # type: ignore
        sys.stderr = StreamRedirector(std_q)  # type: ignore
        threading.excepthook = partial(except_hook, exc_queue=exc_q)

        # Swap `run` for the wrapper that installs the trace function inside the new thread.
        self.__run_backup = self.run
        self.run = self.__run
        super().start()

    def _reset_key_resources(self) -> None:
        """Restore the saved except hook and streams, and delete the backup output file if present."""
        threading.excepthook = self.threading_excepthook_bak
        sys.stdout = self.sys_stdout_bak
        sys.stderr = self.sys_stderr_bak
        backup_path = self.backup_output_file_path
        if os.path.exists(backup_path):
            os.remove(backup_path)

    def __run(self) -> None:
        """Install the trace function, run the original ``run``, then put ``run`` back."""
        sys.settrace(self.globaltrace)
        self.__run_backup()
        self.run = self.__run_backup

    def globaltrace(self, frame: Any, event: str, arg: Any) -> Optional[Callable]:
        """Trace function set via ``sys.settrace``: trace inside every call, ignore other events."""
        return self.localtrace if event == "call" else None

    def localtrace(self, frame: Any, event: str, arg: Any) -> Optional[Callable]:
        """Local trace function: raise ``SystemExit`` on the next ``line`` event once killed."""
        if self.killed and event == "line":
            raise SystemExit(self.SYSTEM_EXIT_MSG)
        return self.localtrace

    def join(self, timeout: Optional[int] = None) -> None:
        """Restore the redirected resources, then join the thread.

        Note that the resources are restored *before* waiting for the thread.
        """
        self._reset_key_resources()
        super().join(timeout=timeout)

    def kill(self, timeout: Optional[int] = None) -> None:
        """Flag the thread as killed and join it."""
        self.killed = True
        self.join(timeout=timeout)
[docs]
def _unseen_backup_output(previous: str, current: str) -> str:
    """Return the part of ``current`` that was appended after ``previous`` was read."""
    if current.startswith(previous):
        # Only strip the already-seen prefix; repeated lines further down must be kept.
        return current[len(previous) :]
    # The file no longer extends what was seen before (e.g. it was rewritten): report all of it.
    return current


def _copy_tool_output(source: ToolOutput, target: ToolOutput) -> None:
    """Copy every result field of ``source`` onto ``target``."""
    target.tool_return = source.tool_return
    target.user_report_outputs = source.user_report_outputs
    target.success = source.success
    target.files_in = source.files_in
    target.files_out = source.files_out


def live_output_iterable(
    thread: ToolThread,
    comm_q: queue.Queue,
    exc_q: queue.Queue,
    std_q: queue.Queue,
    return_holder: ToolOutput,
    # stdout_bak: TextIO,
    # stderr_bak: TextIO,
    wd: str,
) -> ToolReturnIter:
    """Start ``thread`` and stream its output, finishing with the final ``ToolOutput``.

    This is a generator: the thread is started when iteration begins. While the thread runs, text
    from the STDOUT/STDERR queue, new content of the backup output file and messages from the
    communication queue are yielded (each passed through ``filter_out_lines``). Captured
    STDOUT/STDERR text is also echoed to the original ``sys.stdout``. The last item yielded is always
    ``return_holder``, filled in with the tool's result or with the exception report.

    Args:
        thread: The not-yet-started tool thread.
        comm_q: Queue carrying printed messages and the final ``ToolOutput``.
        exc_q: Queue carrying exception reports from the thread's except hook.
        std_q: Queue carrying captured STDOUT/STDERR text.
        return_holder: Object that is populated with the final result and yielded last.
        wd: Working directory in which the backup output file is looked for.

    Yields:
        Live output chunks, then ``return_holder``.
    """
    timeout = 1
    backup_file_path = os.path.join(wd, BACKUP_OUTPUT_FILE)
    thread.start(std_q, exc_q, backup_file_path)
    return_was_obtained = False
    exception_was_reported = False
    backup_output_file_contents = ""

    while thread.is_alive() and not thread.killed:
        # 1. Follow the STDOUT/STDERR streams.
        try:
            output = std_q.get(block=False)
        except queue.Empty:
            pass
        else:
            thread.sys_stdout_bak.write(output)
            thread.sys_stdout_bak.flush()
            yield filter_out_lines(output)

        # 2. Follow the backup output stream file BACKUP_OUTPUT_FILE, and yield any new lines.
        # Why is this here? It may be impossible to capture the output of a tool's STDOUT/STDERR streams if it uses
        # multiprocessing, e.g. Autoprognosis does this. In such cases, we can use a backup file to capture the output.
        # The tool will have to be directed to print occasional output to this file, though. In AutoPrognosis, this is
        # achievable with the heartbeat hook.
        if os.path.exists(backup_file_path):
            with open(backup_file_path, "r") as f:  # pylint: disable=unspecified-encoding
                backup_output_file_contents_current = f.read()
            if backup_output_file_contents_current != backup_output_file_contents:
                new_lines = _unseen_backup_output(backup_output_file_contents, backup_output_file_contents_current)
                backup_output_file_contents = backup_output_file_contents_current
                yield filter_out_lines(new_lines)

        # 3. Follow the explicit communication queue.
        try:
            message = comm_q.get(timeout=timeout)
        except queue.Empty:
            if thread.is_alive():
                continue
            # 4. Follow the exception queue.
            # If an exception occurred in the thread, it should have been saved in the exception queue.
            if not exc_q.empty():
                # If there is an exception in the queue, we should return its string representation.
                return_holder.success = False
                return_holder.tool_return = exc_q.get()
                exception_was_reported = True
            else:
                # If somehow exception was not caught, we return an empty output.
                return_holder.set_empty()
            break

        # 5. Handle output from the tool, if reached.
        if isinstance(message, ToolOutput):
            return_was_obtained = True
            _copy_tool_output(message, return_holder)
            break
        # Empty message edge case.
        if message is None:
            return_holder.set_empty()
            break
        yield filter_out_lines(message)

    # Join the thread if it is still alive.
    thread.join()

    # 1. Spit out any STDOUT/STDERR streams if there are such left
    if not thread.killed:
        while True:
            try:
                output = std_q.get(block=False)
            except queue.Empty:
                break
            thread.sys_stdout_bak.write(output)
            thread.sys_stdout_bak.flush()
            yield filter_out_lines(output)

    # 2. Attempt to catch the return value from the tool, if missed.
    #
    # In certain cases, likely due to a "race condition"-like scenario, we end up not capturing the return value in the
    # above loop. In order to "catch" the return value, we will try to get it from the queue here, by iterating over the
    # queue until we find it (or we get to the end without finding it). Messages that are not a ToolOutput are
    # discarded here.
    if not return_was_obtained:
        try:
            while True:
                message = comm_q.get(timeout=timeout)
                if isinstance(message, ToolOutput):
                    # Return value found, save it.
                    return_was_obtained = True
                    _copy_tool_output(message, return_holder)
                    break
        except queue.Empty:
            # If we get here, we did not find the return value.
            pass

    # 3. Attempt to catch an exception report, if missed.
    #
    # The loop above only looks at the exception queue when it times out on the communication queue. If the thread
    # died before the loop condition was (re-)evaluated, the loop ends without ever checking it, and the failure
    # would otherwise be reported as a successful, empty result.
    if not return_was_obtained and not exception_was_reported:
        try:
            late_exception = exc_q.get_nowait()
        except queue.Empty:
            pass
        else:
            return_holder.success = False
            return_holder.tool_return = late_exception

    yield return_holder
[docs]
def execute_tool(tool_func: Callable, wd: str, **kwargs: Any) -> Tuple[ToolThread, ToolReturnIter]:
    """Prepare a tool run and return its thread together with its live output stream.

    The thread is only created here, not started: `live_output_iterable` starts it when the returned
    iterable is first iterated.

    Args:
        tool_func: The tool callable. It receives a ``ToolCommunicator`` as its single positional
            argument, plus ``kwargs`` as keyword arguments.
        wd: Working directory passed on to `live_output_iterable`.
        **kwargs: Keyword arguments forwarded to ``tool_func``.

    Returns:
        A ``(thread, output_iterable)`` pair.
    """
    communicator = ToolCommunicator()
    # Start message generation in a separate thread
    worker = ToolThread(target=tool_func, args=(communicator,), kwargs=kwargs)
    output_stream = live_output_iterable(
        worker,
        communicator.comm_queue,
        communicator.exc_queue,
        communicator.std_queue,
        return_holder=ToolOutput(),
        wd=wd,
    )
    return worker, output_stream