Source code for climb.common.serialization
import copy
import enum
import importlib
import os
import pickle
from typing import Any, Dict
import matplotlib.figure
import plotly.graph_objects
from . import Message, Session
from .utils import make_filename_path_safe
[docs]
def encode_enum(obj: enum.Enum) -> str:
    """Encode an enum member as ``"<module>/<EnumClass>.<MEMBER>"``.

    The module name is recorded so that the enum class can be imported again
    when the string is decoded (see ``decode_enum``).

    Args:
        obj (enum.Enum): The enum object to encode.

    Returns:
        str: The encoded string.
    """
    enum_cls = type(obj)
    # Build "<EnumClass>.<MEMBER>" explicitly instead of relying on str(obj):
    # str() follows the enum's __str__, which may be overridden and which returns
    # the bare value for IntEnum/StrEnum on Python 3.11+, producing a string that
    # decode_enum cannot split back into class name and member name.
    return f"{enum_cls.__module__}/{enum_cls.__name__}.{obj.name}"
[docs]
def decode_enum(s: str) -> enum.Enum:
    """Decode a ``"<module>/<EnumClass>.<MEMBER>"`` string back into the enum member.

    The module is imported dynamically, the enum class is looked up on it by name,
    and the member is then looked up on the class by its member name.

    Args:
        s (str): The encoded string.

    Returns:
        enum.Enum: The decoded enum object.

    Raises:
        ValueError: If ``s`` does not have the expected shape, or if the name it
            refers to is not an ``enum.Enum`` subclass.
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with the enum class name.
        KeyError: If the enum class has no member with the given name.
    """
    module_and_member = s.split("/")
    if len(module_and_member) != 2:
        raise ValueError(f"Malformed enum encoding, expected '<module>/<EnumClass>.<MEMBER>': {s!r}")
    module_str, enum_part = module_and_member

    class_and_member = enum_part.split(".")
    if len(class_and_member) != 2:
        raise ValueError(f"Malformed enum encoding, expected '<module>/<EnumClass>.<MEMBER>': {s!r}")
    enum_name, enum_value = class_and_member

    # Note: use the module name to dynamically import the module that has the enum.
    module = importlib.import_module(module_str)
    enum_cls = getattr(module, enum_name)

    # Only index into genuine Enum classes; otherwise an arbitrary subscriptable
    # module attribute (e.g. a dict) would be returned as if it were an enum member.
    if not (isinstance(enum_cls, type) and issubclass(enum_cls, enum.Enum)):
        raise ValueError(f"{module_str}.{enum_name} is not an enum.Enum subclass (decoding {s!r})")
    return enum_cls[enum_value]
# TODO: Make this properly modular; the current implementation is just a quick hack.
[docs]
def message_to_serializable_dict(message: Message, session_path: str) -> Dict[str, Any]:
    """Convert a message into a dictionary of serializable values.

    The message is dumped with ``model_dump(by_alias=True)`` and two fields are
    then rewritten:

    * ``engine_state.response_kind`` is replaced by its ``encode_enum`` string.
    * ``tool_call_user_report`` is replaced by a list of descriptor dicts. Strings
      are stored inline under ``"content"``; plotly and matplotlib figures are
      pickled to ``<session_path>/session_pickles/<path-safe message key>/`` and
      referenced by ``"path"``.

    Args:
        message (Message): The message to serialize.
        session_path (str): Directory under which figure pickles are written.

    Returns:
        Dict[str, Any]: The serializable representation of the message.

    Raises:
        ValueError: If a report item is neither a string nor a supported figure.
    """
    figure_dir = os.path.join(session_path, "session_pickles", make_filename_path_safe(message.key))
    dumped = copy.deepcopy(message.model_dump(by_alias=True))

    # The enum (ResponseKind) isn't directly serializable, so store its string encoding.
    if message.engine_state is not None:
        dumped["engine_state"]["response_kind"] = encode_enum(message.engine_state.response_kind_value)

    report = message.tool_call_user_report
    if report is None:
        return dumped

    # Figure classes paired with the tag used both in the descriptor and the file name.
    # Order matters: plotly is checked before matplotlib.
    figure_kinds = (
        (plotly.graph_objects.Figure, "plotly_figure"),
        (matplotlib.figure.Figure, "matplotlib_figure"),
    )

    entries = []
    for position, item in enumerate(report):
        kind = next((tag for figure_cls, tag in figure_kinds if isinstance(item, figure_cls)), None)
        if kind is not None:
            # The pickle directory is only created once a figure actually needs it.
            os.makedirs(figure_dir, exist_ok=True)
            target = os.path.join(figure_dir, f"{position}__{kind}.pickle")
            with open(target, "wb") as handle:
                pickle.dump(item, handle)
            entries.append({"type": kind, "report_item_idx": position, "path": target})
        elif isinstance(item, str):
            entries.append({"type": "str", "report_item_idx": position, "content": item})
        else:
            raise ValueError(f"Message serialization failed. Unsupported report item type: {type(item)}")

    dumped["tool_call_user_report"] = entries
    return dumped
[docs]
def message_from_serializable_dict(message_dict: Dict[str, Any]) -> Message:
    """Rebuild a ``Message`` from the dictionary made by ``message_to_serializable_dict``.

    ``engine_state.response_kind`` is decoded back into its enum member, and each
    ``tool_call_user_report`` descriptor is turned back into a report item: strings
    are taken from ``"content"`` and figures are unpickled from ``"path"``. A figure
    that cannot be loaded is replaced by a placeholder string and the failure is
    printed, so one missing or corrupt pickle does not prevent the message from
    loading.

    The input dictionary is not modified.

    Args:
        message_dict (Dict[str, Any]): The serialized message.

    Returns:
        Message: The reconstructed message.

    Raises:
        KeyError: If ``"engine_state"`` or ``"tool_call_user_report"`` is missing.
        ValueError: If a report descriptor has an unsupported ``"type"``.
    """
    message_dict_new = copy.deepcopy(message_dict)

    # Handle enum (ResponseKind), which isn't directly serializable.
    if message_dict["engine_state"] is not None:
        message_dict_new["engine_state"]["response_kind"] = decode_enum(message_dict["engine_state"]["response_kind"])

    # Handle the figure objects.
    if message_dict["tool_call_user_report"]:
        # Descriptor type -> human-readable label used in the failure messages.
        figure_labels = {"plotly_figure": "plotly figure", "matplotlib_figure": "matplotlib figure"}
        deserialized = []
        for report_item in message_dict["tool_call_user_report"]:
            item_type = report_item["type"]
            if item_type in figure_labels:
                label = figure_labels[item_type]
                try:
                    # SECURITY NOTE: pickle.load can execute arbitrary code. Only load
                    # session files that come from a trusted source.
                    with open(report_item["path"], "rb") as f:
                        restored = pickle.load(f)
                except Exception as e:
                    # Deliberately broad and best-effort: unpickling can fail in many ways
                    # (missing file, corrupt data, incompatible library versions).
                    print(f"Failed to deserialize {label} from {report_item['path']}: {e}")
                    restored = f"< Failed to deserialize {label} >"
                deserialized.append(restored)
            elif item_type == "str":
                deserialized.append(report_item["content"])
            else:
                raise ValueError(f"Message deserialization failed. Unsupported report item type: {item_type}")
        message_dict_new["tool_call_user_report"] = deserialized

    return Message(**message_dict_new)
[docs]
def session_to_serializable_dict(session: Session) -> Dict[str, Any]:
    """Convert a session into a dictionary of serializable values.

    The session is dumped with ``model_dump()`` and its ``"messages"`` entry is
    replaced by the per-message serializable dictionaries. Figure pickles for the
    messages are written relative to the session's working directory.

    Args:
        session (Session): The session to serialize.

    Returns:
        Dict[str, Any]: The serializable representation of the session.
    """
    payload = session.model_dump()
    # A missing or empty message list serializes as an empty list.
    payload["messages"] = [
        message_to_serializable_dict(msg, session.working_directory) for msg in (session.messages or [])
    ]
    return payload
[docs]
def session_from_serializable_dict(session_dict: Dict[str, Any]) -> Session:
    """Rebuild a ``Session`` from the dictionary made by ``session_to_serializable_dict``.

    Each entry of ``session_dict["messages"]`` is converted back into a message;
    every other field is passed to ``Session`` as a deep copy of the input.

    Args:
        session_dict (Dict[str, Any]): The serialized session.

    Returns:
        Session: The reconstructed session.
    """
    restored = copy.deepcopy(session_dict)
    # A missing-value or empty message list deserializes as an empty list.
    restored["messages"] = [message_from_serializable_dict(raw) for raw in (session_dict["messages"] or [])]
    return Session(**restored)