From 1d571fc72693fa0500983854df829429e061b2f4 Mon Sep 17 00:00:00 2001 From: "Max R. Carrara" Date: Mon, 10 Mar 2025 23:34:34 +0100 Subject: [PATCH] bot: db: add base model and a simple db manager The base model basically just adds typing to peewee (lmao) and also handles the date_created, date_updated fields all methods, except for raw SQL. This should make it a little easier to write a tiny async layer on top of it all. Signed-off-by: Max R. Carrara --- pyproject.toml | 1 + src/bot/db/__init__.py | 7 ++ src/bot/db/db.py | 48 +++++++++++ src/bot/db/model.py | 187 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 243 insertions(+) create mode 100644 src/bot/db/__init__.py create mode 100644 src/bot/db/db.py create mode 100644 src/bot/db/model.py diff --git a/pyproject.toml b/pyproject.toml index 3b24669..ef9ecb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ readme = "README.md" license = { text = "MIT" } dependencies = [ + "peewee>=3.17.9", "pydantic>=2.10.0", ] diff --git a/src/bot/db/__init__.py b/src/bot/db/__init__.py new file mode 100644 index 0000000..975e99d --- /dev/null +++ b/src/bot/db/__init__.py @@ -0,0 +1,7 @@ +__all__ = [ + "BaseModel", + "DatabaseManager", +] + +from .db import DatabaseManager +from .model import BaseModel diff --git a/src/bot/db/db.py b/src/bot/db/db.py new file mode 100644 index 0000000..dfcd6be --- /dev/null +++ b/src/bot/db/db.py @@ -0,0 +1,48 @@ +import logging + +from typing import Optional + +from playhouse.sqlite_ext import SqliteExtDatabase + +from bot.env import Environment + + +_log = logging.getLogger(__name__) + + +class DatabaseManager: + __db_instance: Optional[SqliteExtDatabase] = None + + def __init__(self) -> None: + pass + + @classmethod + def load_once(cls) -> SqliteExtDatabase: + if cls.__db_instance is not None: + return cls.__db_instance + + try: + db_name = ":memory:" if Environment.bot_memory_db() else "bot.db" + + cls.__db_instance = SqliteExtDatabase( + db_name, + autoconnect=True, + pragmas={ + "journal_mode": "wal", + "cache_size": -1024 * 64, + "ignore_check_constraints": 0, + "synchronous": 0, + }, + ) + + _log.debug(f'Prepared database instance "{db_name}"') + + return cls.__db_instance + except Exception as e: + _log.error( + "An unexpected error occurred while attempting to set up" + " an SQLite database:", + exc_info=e, + ) + + raise diff --git a/src/bot/db/model.py b/src/bot/db/model.py new file mode 100644 index 0000000..6e5d7e7 --- /dev/null +++ b/src/bot/db/model.py @@ -0,0 +1,187 @@ +from datetime import datetime +from typing import Any, Iterable, Mapping, MutableMapping, Self, Sequence + +import peewee + + +PrimaryKey = int +SelectType = peewee.Model | peewee.Field | peewee.Function | peewee.Expression + +TupleRow = tuple[Any, ...] +MapRow = MutableMapping[str, Any] +Row = TupleRow | MapRow +Rows = Sequence[TupleRow] | Sequence[MapRow] +IterRows = Iterable[TupleRow] | Iterable[MapRow] + + +class BaseModel(peewee.Model): + date_created = peewee.DateTimeField(default=datetime.now) + date_updated = peewee.DateTimeField(default=datetime.now) + + @classmethod + def select(cls, *fields: list[Self | SelectType]) -> peewee.ModelSelect: + return super().select(*fields) + + @classmethod + def update(cls, __data=None, **update: Any) -> peewee.ModelUpdate: + if "date_updated" not in update: + update["date_updated"] = datetime.now() + return super().update(__data, **update) + + @classmethod + def insert(cls, __data=None, **insert: Any) -> peewee.ModelInsert: + insert["date_updated"] = datetime.now() + return super().insert(__data=__data, **insert) + + @classmethod + def insert_many( + cls, rows: IterRows, fields: Sequence[peewee.Field | str] | None = None + ) -> peewee.ModelInsert: + def lazy_tuples(rows_: Iterable[TupleRow]) -> Iterable[TupleRow]: + now = datetime.now() + for row in rows_: + yield tuple([*row, now, now]) + + def lazy_dicts(rows_: Iterable[MapRow]) -> Iterable[MapRow]: + now = datetime.now() + for row in rows_: + row["date_created"] = now + row["date_updated"] = now + yield row + + if fields: + lazy = lazy_tuples + + if not isinstance(fields, list): + fields = list(fields) + + fields.append("date_created") + fields.append("date_updated") + + else: + lazy = lazy_dicts + + # Safe to ignore type checking here because the presence of a fields list + # implies that tuples are being used. Also, it's peewee's problem to + # handle that anyways. + return super().insert_many(lazy(rows), fields) # type: ignore + + @classmethod + def insert_from( + cls, query: peewee.SelectQuery, fields: Sequence[peewee.Field | str] + ) -> peewee.ModelInsert: + field_names: list[str] = [ + field.column_name if isinstance(field, peewee.Field) else field + for field in fields + ] + + def lazy(query_: peewee.SelectQuery): + select = query_.select_from(field_names) + for row in select: + values = [] + for field_name in field_names: + values.append(getattr(row, field_name)) + yield tuple(values) + + return cls.insert_many(lazy(query), fields=field_names) + + @classmethod + def replace(cls, __data=None, **insert: Any) -> peewee.ModelInsert: + now = datetime.now() + insert["date_created"] = now + insert["date_updated"] = now + return super().replace(__data, **insert) + + @classmethod + def replace_many( + cls, rows: Iterable, fields: list[peewee.Field | str] | None = None + ) -> peewee.ModelInsert: + return cls.insert_many(rows=rows, fields=fields).on_conflict("REPLACE") + + @classmethod + def raw(cls, sql: str, *params: Iterable[Any]) -> Self: + return super().raw(sql, *params) + + @classmethod + def delete(cls) -> peewee.ModelDelete: + return super().delete() + + @classmethod + def create(cls, **query: Any): + now = datetime.now() + query["date_created"] = now + query["date_updated"] = now + + return super().create(**query) + + @classmethod + def bulk_create(cls, model_list: Iterable[Self], batch_size: int | None = None): + def lazy(models: Iterable[Self]) -> Iterable[Self]: + now = datetime.now() + for model in models: + model.date_created = now + model.date_updated = now + + yield model + + return super().bulk_create(lazy(model_list), batch_size=batch_size) + + @classmethod + def bulk_update( + cls, model_list: Iterable[Self], fields: Iterable, batch_size: int | None = None + ): + def lazy(models: Iterable[Self]) -> Iterable[Self]: + now = datetime.now() + for model in models: + model.date_updated = now + + yield model + + fields = list(fields) + fields.append("date_updated") + + return super().bulk_update(lazy(model_list), fields, batch_size=batch_size) + + @classmethod + def get(cls, *query: peewee.Expression, **filters) -> Self: + return super().get(*query, **filters) + + @classmethod + def get_or_none(cls, *query: peewee.Expression, **filters) -> Self | None: + return super().get_or_none(*query, **filters) + + @classmethod + def get_by_id(cls, pk: PrimaryKey) -> Self: + return super().get_by_id(pk) + + @classmethod + def get_or_create(cls, **kwargs: Any) -> Self: + return super().get_or_create(**kwargs) + + @classmethod + def set_by_id(cls, key: PrimaryKey, value: Mapping[str, Any]): + return super().set_by_id(key, value) + + @classmethod + def delete_by_id(cls, pk: PrimaryKey): + return super().delete_by_id(pk) + + @classmethod + def filter(cls, *dq_nodes: peewee.DQ, **filters): + return super().filter(*dq_nodes, **filters) + + def get_id(self) -> int: + return super().get_id() + + def save(self, force_insert: bool = False, only: Iterable | None = None) -> int: + if only is not None and not isinstance(only, list): + only = list(only) + + return super().save(force_insert, only) + + @property + def dirty_fields(self) -> list: + return super().dirty_fields + + def is_dirty(self) -> bool: + return super().is_dirty()