"""climb.db.tinydb_db — TinyDB-backed implementation of the ``DB`` interface."""

import enum
from typing import List

from tinydb import Query, TinyDB
from tinydb.storages import JSONStorage
from tinydb.table import Document
from tinydb_serialization import SerializationMiddleware, Serializer
from tinydb_serialization.serializers import DateTimeSerializer

from climb.common import Session, UserSettings
from climb.common.serialization import (
    decode_enum,
    encode_enum,
    session_from_serializable_dict,
    session_to_serializable_dict,
)

from ._db import DB


class EnumSerializer(Serializer):
    """``tinydb_serialization`` serializer that stores enum members as strings.

    Encoding and decoding are delegated to ``climb.common.serialization``'s
    ``encode_enum`` / ``decode_enum`` helpers.
    """

    # Tells tinydb_serialization which class of objects this serializer handles.
    OBJ_CLASS = enum.Enum

    def encode(self, obj: enum.Enum) -> str:
        """Turn an enum member into its string form for JSON storage.

        Args:
            obj: The enum member to encode.

        Returns:
            The string produced by ``encode_enum``.
        """
        as_text: str = encode_enum(obj)
        return as_text

    def decode(self, s: str) -> enum.Enum:
        """Rebuild an enum member from a string previously made by ``encode``.

        Args:
            s: The stored string representation.

        Returns:
            The enum member produced by ``decode_enum``.
        """
        member: enum.Enum = decode_enum(s)
        return member
def _build_serialization() -> SerializationMiddleware:
    """Return a JSONStorage middleware with the datetime and enum serializers registered."""
    middleware = SerializationMiddleware(JSONStorage)
    # Registration order matches the original: datetime first, then enum.
    for serializer, tag in (
        (DateTimeSerializer(), "TinyDate"),
        (EnumSerializer(), "TinyEnum"),
    ):
        middleware.register_serializer(serializer, tag)
    return middleware


# Module-level middleware wrapping JSONStorage. Datetime values are tagged
# "TinyDate" and enum values "TinyEnum".
# NOTE(review): a TinyDB middleware instance rebinds its ``storage`` every time
# TinyDB calls it, so this single object should not back more than one open
# TinyDB database at a time.
serialization = _build_serialization()
class TinyDB_DB(DB):
    """``DB`` implementation that persists data in a TinyDB JSON file.

    Two tables are used:

    * ``"user_settings"`` holds a single document with ``doc_id`` 0.
    * ``"session"`` holds one document per session, matched on its
      ``session_key`` field.
    """

    # Table names, and the fixed doc_id of the one user-settings document.
    _SETTINGS_TABLE = "user_settings"
    _SESSION_TABLE = "session"
    _SETTINGS_DOC_ID = 0

    def __init__(self, db_path: str = "db.json") -> None:
        """Open the TinyDB database stored at *db_path*.

        Args:
            db_path: Path of the JSON file backing the database.
        """
        self.db_path = db_path
        # Build a fresh serialization middleware for every instance. A TinyDB
        # middleware is stateful: TinyDB calls it on open, which rebinds its
        # ``storage`` attribute and returns the same object. Sharing one
        # middleware between instances would make all of them read and write
        # whichever file was opened last.
        middleware = SerializationMiddleware(JSONStorage)
        middleware.register_serializer(DateTimeSerializer(), "TinyDate")
        middleware.register_serializer(EnumSerializer(), "TinyEnum")
        self.db = TinyDB(db_path, storage=middleware)

    def update_user_settings(self, settings: UserSettings) -> None:
        """Insert or overwrite the single stored user-settings document.

        Args:
            settings: Settings model; persisted as ``settings.model_dump()``.
        """
        # The table only ever holds one document, so upsert by a fixed doc_id
        # (given via a ``Document`` wrapper) rather than by a query.
        self.db.table(self._SETTINGS_TABLE).upsert(
            Document(settings.model_dump(), doc_id=self._SETTINGS_DOC_ID)
        )

    def get_user_settings(self) -> UserSettings:
        """Return the stored user settings, persisting defaults on first use.

        Returns:
            The settings rebuilt from the first (and only) stored document.
            If none was stored yet, ``UserSettings()`` is written first.
        """
        table = self.db.table(self._SETTINGS_TABLE)
        documents = table.all()
        if not documents:
            # Nothing stored yet: write the defaults, then read them back.
            self.update_user_settings(UserSettings())
            documents = table.all()
        # Retrieve the first (and only) user settings document.
        return UserSettings(**documents[0])

    def update_session(self, session: Session) -> None:
        """Insert *session*, or replace the stored one with the same key.

        Args:
            session: Session to persist; converted with
                ``session_to_serializable_dict`` and matched on ``session_key``.
        """
        serializable_session = session_to_serializable_dict(session)
        self.db.table(self._SESSION_TABLE).upsert(
            serializable_session, Query().session_key == session.session_key
        )

    def get_session(self, session_key: str) -> Session:
        """Load the session stored under *session_key*.

        Args:
            session_key: Key of the session to load.

        Returns:
            The first matching session, deserialized with
            ``session_from_serializable_dict``.

        Raises:
            IndexError: If no session with *session_key* is stored. The
                exception type matches the previous behaviour, now with a
                descriptive message.
        """
        matches = self.db.table(self._SESSION_TABLE).search(
            Query().session_key == session_key
        )
        if not matches:
            raise IndexError(f"No session found with session_key={session_key!r}.")
        return session_from_serializable_dict(matches[0])

    def get_all_sessions(self) -> List[Session]:
        """Return every stored session, deserialized, in table order."""
        return [
            session_from_serializable_dict(doc)
            for doc in self.db.table(self._SESSION_TABLE).all()
        ]

    def delete_session(self, session_key: str) -> None:
        """Remove all stored sessions whose ``session_key`` equals *session_key*.

        No error is raised when nothing matches.
        """
        self.db.table(self._SESSION_TABLE).remove(Query().session_key == session_key)