bot: db: add base model and a simple db manager

The base model basically just adds typing to peewee (lmao) and also
handles the date_created, date_updated fields all methods, except for
raw SQL.

This should make it a little easier to write a tiny async layer on top
of it all.

Signed-off-by: Max R. Carrara <max@aequito.sh>
This commit is contained in:
Max R. Carrara 2025-03-10 23:34:34 +01:00
parent c2bd8de5fb
commit 1d571fc726
4 changed files with 243 additions and 0 deletions

View file

@ -7,6 +7,7 @@ readme = "README.md"
license = { text = "MIT" } license = { text = "MIT" }
dependencies = [ dependencies = [
"peewee>=3.17.9",
"pydantic>=2.10.0", "pydantic>=2.10.0",
] ]

7
src/bot/db/__init__.py Normal file
View file

@ -0,0 +1,7 @@
__all__ = [
"BaseModel",
"DatabaseManager",
]
from .db import DatabaseManager
from .model import BaseModel

48
src/bot/db/db.py Normal file
View file

@ -0,0 +1,48 @@
import logging
from typing import Optional
from playhouse.sqlite_ext import SqliteExtDatabase
from bot.env import Environment
_log = logging.getLogger(__name__)
class DatabaseManager:
__db_instance: Optional[SqliteExtDatabase] = None
def __init__(self) -> None:
pass
@classmethod
def load_once(cls) -> SqliteExtDatabase:
if cls.__db_instance is not None:
return cls.__db_instance
try:
db_name = ":memory:" if Environment.bot_memory_db() else "bot.db"
cls.__db_instance = SqliteExtDatabase(
db_name,
autoconnect=True,
pragmas={
"journal_mode": "wal",
"cache_size": -1024 * 64,
"ignore_check_constraints": 0,
"synchronous": 0,
},
)
_log.debug(f'Prepared database instance "{db_name}"')
return cls.__db_instance
except Exception as e:
_log.error(
"An unexpected error occurred while attempting to set up"
" an SQLite database:",
exc_info=e,
)
raise

187
src/bot/db/model.py Normal file
View file

@ -0,0 +1,187 @@
from datetime import datetime
from typing import Any, Iterable, Mapping, MutableMapping, Self, Sequence
import peewee
PrimaryKey = int
SelectType = peewee.Model | peewee.Field | peewee.Function | peewee.Expression
TupleRow = tuple[Any, ...]
MapRow = MutableMapping[str, Any]
Row = TupleRow | MapRow
Rows = Sequence[TupleRow] | Sequence[MapRow]
IterRows = Iterable[TupleRow] | Iterable[MapRow]
class BaseModel(peewee.Model):
date_created = peewee.DateTimeField(default=datetime.now)
date_updated = peewee.DateTimeField(default=datetime.now)
@classmethod
def select(cls, *fields: list[Self | SelectType]) -> peewee.ModelSelect:
return super().select(*fields)
@classmethod
def update(cls, __data=None, **update: Any) -> peewee.ModelUpdate:
if "date_updated" not in update:
update["date_updated"] = datetime.now()
return super().update(__data, **update)
@classmethod
def insert(cls, __data=None, **insert: Any) -> peewee.ModelInsert:
insert["date_updated"] = datetime.now()
return super().insert(__data=__data, **insert)
@classmethod
def insert_many(
cls, rows: IterRows, fields: Sequence[peewee.Field | str] | None = None
) -> peewee.ModelInsert:
def lazy_tuples(rows_: Iterable[TupleRow]) -> Iterable[TupleRow]:
now = datetime.now()
for row in rows_:
yield tuple([*row, now, now])
def lazy_dicts(rows_: Iterable[MapRow]) -> Iterable[MapRow]:
now = datetime.now()
for row in rows_:
row["date_created"] = now
row["date_updated"] = now
yield row
if fields:
lazy = lazy_tuples
if not isinstance(fields, list):
fields = list(fields)
fields.append("date_created")
fields.append("date_updated")
else:
lazy = lazy_dicts
# Safe to ignore type checking here because the presence of a fields list
# implies that tuples are being used. Also, it's peewee's problem to
# handle that anyways.
return super().insert_many(lazy(rows), fields) # type: ignore
@classmethod
def insert_from(
cls, query: peewee.SelectQuery, fields: Sequence[peewee.Field | str]
) -> peewee.ModelInsert:
field_names: list[str] = [
field.column_name if isinstance(field, peewee.Field) else field
for field in fields
]
def lazy(query_: peewee.SelectQuery):
select = query_.select_from(field_names)
for row in select:
values = []
for field_name in field_names:
values.append(getattr(row, field_name))
yield tuple(values)
return cls.insert_many(lazy(query), fields=field_names)
@classmethod
def replace(cls, __data=None, **insert: Any) -> peewee.ModelInsert:
now = datetime.now()
insert["date_created"] = now
insert["date_updated"] = now
return super().replace(__data, **insert)
@classmethod
def replace_many(
cls, rows: Iterable, fields: list[peewee.Field | str] | None = None
) -> peewee.ModelInsert:
return cls.insert_many(rows=rows, fields=fields).on_conflict("REPLACE")
@classmethod
def raw(cls, sql: str, *params: Iterable[Any]) -> Self:
return super().raw(sql, *params)
@classmethod
def delete(cls) -> peewee.ModelDelete:
return super().delete()
@classmethod
def create(cls, **query: Any):
now = datetime.now()
query["date_created"] = now
query["date_updated"] = now
return super().create(**query)
@classmethod
def bulk_create(cls, model_list: Iterable[Self], batch_size: int | None = None):
def lazy(models: Iterable[Self]) -> Iterable[Self]:
now = datetime.now()
for model in models:
model.date_created = now
model.date_updated = now
yield model
return super().bulk_create(lazy(model_list), batch_size=batch_size)
@classmethod
def bulk_update(
cls, model_list: Iterable[Self], fields: Iterable, batch_size: int | None = None
):
def lazy(models: Iterable[Self]) -> Iterable[Self]:
now = datetime.now()
for model in models:
model.date_updated = now
yield model
fields = list(fields)
fields.append("date_updated")
return super().bulk_update(lazy(model_list), fields, batch_size=batch_size)
@classmethod
def get(cls, *query: peewee.Expression, **filters) -> Self:
return super().get(*query, **filters)
@classmethod
def get_or_none(cls, *query: peewee.Expression, **filters) -> Self | None:
return super().get_or_none(*query, **filters)
@classmethod
def get_by_id(cls, pk: PrimaryKey) -> Self:
return super().get_by_id(pk)
@classmethod
def get_or_create(cls, **kwargs: Any) -> Self:
return super().get_or_create(**kwargs)
@classmethod
def set_by_id(cls, key: PrimaryKey, value: Mapping[str, Any]):
return super().set_by_id(key, value)
@classmethod
def delete_by_id(cls, pk: PrimaryKey):
return super().delete_by_id(pk)
@classmethod
def filter(cls, *dq_nodes: peewee.DQ, **filters):
return super().filter(*dq_nodes, **filters)
def get_id(self) -> int:
return super().get_id()
def save(self, force_insert: bool = False, only: Iterable | None = None) -> int:
if only is not None and not isinstance(only, list):
only = list(only)
return super().save(force_insert, only)
@property
def dirty_fields(self) -> list:
return super().dirty_fields
def is_dirty(self) -> bool:
return super().is_dirty()