from datetime import datetime
from typing import Any, Iterable, Mapping, MutableMapping, Self, Sequence

import peewee

# Type aliases used by the BaseModel signatures below.
PrimaryKey = int
SelectType = peewee.Model | peewee.Field | peewee.Function | peewee.Expression
TupleRow = tuple[Any, ...]
MapRow = MutableMapping[str, Any]
Row = TupleRow | MapRow
Rows = Sequence[TupleRow] | Sequence[MapRow]
IterRows = Iterable[TupleRow] | Iterable[MapRow]

# Timestamp columns maintained by BaseModel, in the order they are appended
# to explicit column lists.
_TIMESTAMP_FIELDS = ("date_created", "date_updated")


class BaseModel(peewee.Model):
    """Peewee base model that maintains ``date_created`` / ``date_updated``.

    The write-path classmethods (insert / replace / create and their bulk
    variants, plus update / bulk_update) stamp the timestamp columns with
    ``datetime.now()``. The remaining overrides only narrow peewee's type
    annotations and delegate unchanged.
    """

    date_created = peewee.DateTimeField(default=datetime.now)
    date_updated = peewee.DateTimeField(default=datetime.now)

    @staticmethod
    def _data_assigns(data: Any, name: str) -> bool:
        """Return True if *data* is a mapping that has a key for column *name*.

        Keys may be plain strings or field-like objects exposing ``.name``.
        Non-mapping *data* (including None) never counts as assigning.
        """
        if not isinstance(data, Mapping):
            return False
        return any(
            (key if isinstance(key, str) else getattr(key, "name", None)) == name
            for key in data
        )

    @classmethod
    def select(cls, *fields: SelectType | type[peewee.Model]) -> peewee.ModelSelect:
        """Delegate to peewee's select; each positional argument is one selectable."""
        return super().select(*fields)

    @classmethod
    def update(cls, __data=None, **update: Any) -> peewee.ModelUpdate:
        """Build an UPDATE that also refreshes ``date_updated``.

        ``date_updated`` is set to now unless the caller already assigns it,
        either as a keyword or as a key (column name or field) of *__data*.
        """
        if "date_updated" not in update and not cls._data_assigns(
            __data, "date_updated"
        ):
            update["date_updated"] = datetime.now()
        return super().update(__data, **update)

    @classmethod
    def insert(cls, __data=None, **insert: Any) -> peewee.ModelInsert:
        """Build an INSERT for one row with both timestamps set to now.

        As in :meth:`create` and :meth:`replace`, caller-supplied
        ``date_created`` / ``date_updated`` keyword values are overwritten.
        """
        now = datetime.now()
        insert["date_created"] = now
        insert["date_updated"] = now
        # Pass the data mapping positionally (as update()/replace() do). A
        # ``__data=`` keyword is subject to private-name mangling and cannot
        # bind to peewee's own ``__data`` parameter, so it ended up in the
        # keyword data as a bogus field name.
        return super().insert(__data, **insert)

    @classmethod
    def insert_many(
        cls, rows: IterRows, fields: Sequence[peewee.Field | str] | None = None
    ) -> peewee.ModelInsert:
        """Build a multi-row INSERT with both timestamps stamped to one ``now``.

        Without *fields*, *rows* are treated as mappings. Each row is copied
        (the caller's mappings are left untouched) with ``date_created`` and
        ``date_updated`` set.

        With *fields*, *rows* are treated as tuples in *fields* order. Any
        timestamp column missing from *fields* is appended to a new column
        list (the caller's sequence is not mutated), and the value ``now`` is
        appended to every tuple. Timestamp columns the caller already listed
        (as a name or a field) keep the caller's values. This is what
        bulk_create relies on, since it passes every model field.

        ``now`` is taken when the rows are first iterated.
        """
        if not fields:

            def lazy_dicts(rows_: Iterable[MapRow]) -> Iterable[MapRow]:
                now = datetime.now()
                for row in rows_:
                    yield {**row, "date_created": now, "date_updated": now}

            return super().insert_many(lazy_dicts(rows), fields)  # type: ignore

        field_list = list(fields)
        present = {
            field if isinstance(field, str) else getattr(field, "name", None)
            for field in field_list
        }
        missing = [name for name in _TIMESTAMP_FIELDS if name not in present]

        def lazy_tuples(rows_: Iterable[TupleRow]) -> Iterable[TupleRow]:
            now = datetime.now()
            stamps = (now,) * len(missing)
            for row in rows_:
                yield (*row, *stamps)

        # Safe to ignore type checking here because the presence of a fields
        # list implies that tuples are being used; peewee validates the rest.
        return super().insert_many(  # type: ignore
            lazy_tuples(rows), [*field_list, *missing]
        )

    @classmethod
    def insert_from(
        cls, query: peewee.SelectQuery, fields: Sequence[peewee.Field | str]
    ) -> peewee.ModelInsert:
        """Insert the rows produced by *query*, stamped via :meth:`insert_many`.

        *query* is wrapped so that only the requested columns are selected,
        in *fields* order. Fields are resolved to their ``column_name``. The
        rows come from a generator, so they are only fetched when the
        returned insert consumes them.

        Raises:
            ValueError: if *fields* is empty.
        """
        field_names: list[str] = [
            field.column_name if isinstance(field, peewee.Field) else field
            for field in fields
        ]
        if not field_names:
            raise ValueError("insert_from() requires at least one field.")

        def lazy(query_: peewee.SelectQuery) -> Iterable[TupleRow]:
            # select_from() takes columns as varargs and expects column
            # references on the wrapped subquery (``query_.c.<name>``). A
            # list or bare strings would be bound as literal parameters.
            columns = [getattr(query_.c, name) for name in field_names]
            yield from query_.select_from(*columns).tuples()

        return cls.insert_many(lazy(query), fields=field_names)

    @classmethod
    def replace(cls, __data=None, **insert: Any) -> peewee.ModelInsert:
        """Build a REPLACE for one row with both timestamps set to now."""
        now = datetime.now()
        insert["date_created"] = now
        insert["date_updated"] = now
        return super().replace(__data, **insert)

    @classmethod
    def replace_many(
        cls, rows: Iterable, fields: list[peewee.Field | str] | None = None
    ) -> peewee.ModelInsert:
        """Like :meth:`insert_many` (timestamps included) with ON CONFLICT REPLACE."""
        return cls.insert_many(rows=rows, fields=fields).on_conflict("REPLACE")

    @classmethod
    def raw(cls, sql: str, *params: Any) -> peewee.ModelRaw:
        """Delegate to peewee's raw query builder."""
        return super().raw(sql, *params)

    @classmethod
    def delete(cls) -> peewee.ModelDelete:
        """Delegate to peewee's DELETE builder."""
        return super().delete()

    @classmethod
    def create(cls, **query: Any) -> Self:
        """Create and save an instance with both timestamps set to now.

        Caller-supplied ``date_created`` / ``date_updated`` are overwritten.
        """
        now = datetime.now()
        query["date_created"] = now
        query["date_updated"] = now
        return super().create(**query)

    @classmethod
    def bulk_create(cls, model_list: Iterable[Self], batch_size: int | None = None):
        """Bulk-insert *model_list*, stamping both timestamps with one ``now``.

        The models are collected into a list first. peewee may iterate a
        batch more than once (e.g. to assign returned primary keys when
        *batch_size* is None), which a one-shot generator would not survive.
        """
        now = datetime.now()
        models = list(model_list)
        for model in models:
            model.date_created = now
            model.date_updated = now
        return super().bulk_create(models, batch_size=batch_size)

    @classmethod
    def bulk_update(
        cls, model_list: Iterable[Self], fields: Iterable, batch_size: int | None = None
    ):
        """Bulk-update *fields* on *model_list*, also writing ``date_updated``.

        The models are collected into a list first. When *batch_size* is None,
        peewee may iterate the single batch several times (once for the ids,
        then once per field), and a generator would be exhausted after the
        first pass.
        """
        now = datetime.now()
        models = list(model_list)
        for model in models:
            model.date_updated = now
        # Build a new field list so the caller's iterable is not mutated.
        update_fields = [*fields, "date_updated"]
        return super().bulk_update(models, update_fields, batch_size=batch_size)

    @classmethod
    def get(cls, *query: peewee.Expression, **filters) -> Self:
        """Delegate to peewee's get."""
        return super().get(*query, **filters)

    @classmethod
    def get_or_none(cls, *query: peewee.Expression, **filters) -> Self | None:
        """Delegate to peewee's get_or_none."""
        return super().get_or_none(*query, **filters)

    @classmethod
    def get_by_id(cls, pk: PrimaryKey) -> Self:
        """Delegate to peewee's get_by_id."""
        return super().get_by_id(pk)

    @classmethod
    def get_or_create(cls, **kwargs: Any) -> tuple[Self, bool]:
        """Delegate to peewee's get_or_create, which returns ``(instance, created)``."""
        return super().get_or_create(**kwargs)

    @classmethod
    def set_by_id(cls, key: PrimaryKey, value: Mapping[str, Any]):
        """Delegate to peewee's set_by_id."""
        return super().set_by_id(key, value)

    @classmethod
    def delete_by_id(cls, pk: PrimaryKey):
        """Delegate to peewee's delete_by_id."""
        return super().delete_by_id(pk)

    @classmethod
    def filter(cls, *dq_nodes: peewee.DQ, **filters):
        """Delegate to peewee's filter."""
        return super().filter(*dq_nodes, **filters)

    def get_id(self) -> int:
        """Return this instance's primary key (delegates to peewee)."""
        return super().get_id()

    def save(self, force_insert: bool = False, only: Iterable | None = None) -> int:
        """Persist this instance; *only* may be any iterable of fields.

        NOTE(review): unlike update()/bulk_update(), this does not refresh
        ``date_updated``. Confirm whether instance saves should bump it.
        """
        if only is not None and not isinstance(only, list):
            # Normalize to a list before handing it to peewee.
            only = list(only)
        return super().save(force_insert, only)

    @property
    def dirty_fields(self) -> list:
        """Delegate to peewee's dirty_fields property."""
        return super().dirty_fields

    def is_dirty(self) -> bool:
        """Delegate to peewee's is_dirty."""
        return super().is_dirty()