Skip to content

Add type annotations #76

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,5 @@ def read_relative_file(filename):
"Programming Language :: Python :: 3.12",
"Topic :: Utilities",
],
package_data={"typedmodels": ["py.typed"]},
)
30 changes: 25 additions & 5 deletions typedmodels/admin.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,35 @@
from django.contrib import admin
from .models import TypedModel
from typing import TYPE_CHECKING, Optional, Sequence, Callable, Any, Generic, Type

class TypedModelAdmin(admin.ModelAdmin):
def get_fields(self, request, obj=None):
from django.contrib.admin import ModelAdmin

from .models import TypedModel, TypedModelT

if TYPE_CHECKING:
from django.http import HttpRequest
from django.forms.forms import BaseForm


class TypedModelAdmin(ModelAdmin, Generic[TypedModelT]):
model: "Type[TypedModelT]"

def get_fields(
self,
request: "HttpRequest",
obj: "Optional[TypedModelT]" = None,
) -> Sequence[Callable[..., Any] | str]:
fields = super().get_fields(request, obj)
# we remove the type field from the admin of subclasses.
if TypedModel not in self.model.__bases__:
fields.remove(self.model._meta.get_field('type').name)
return fields

def save_model(self, request, obj, form, change):
def save_model(
self,
request: "HttpRequest",
obj: "TypedModelT",
form: "BaseForm",
change,
) -> None:
if getattr(obj, '_typedmodels_type', None) is None:
# new instances don't have the type attribute
obj._typedmodels_type = form.cleaned_data['type']
Expand Down
42 changes: 29 additions & 13 deletions typedmodels/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TYPE_CHECKING, Optional, Type, TypeVar, Generic

from functools import partial
import types

Expand All @@ -11,13 +13,21 @@
from django.db.models.options import make_immutable_fields_list
from django.utils.encoding import smart_str

if TYPE_CHECKING:
from django.db.models import Model, QuerySet


TypedModelT = TypeVar("TypedModelT", bound="TypedModel")


class TypedModelManager(models.Manager):
def get_queryset(self):
class TypedModelManager(models.Manager, Generic[TypedModelT]):
model: "Type[TypedModelT]"

def get_queryset(self) -> "QuerySet[TypedModelT]":
qs = super(TypedModelManager, self).get_queryset()
return self._filter_by_type(qs)

def _filter_by_type(self, qs):
def _filter_by_type(self, qs: "QuerySet[TypedModelT]"):
if hasattr(self.model, "_typedmodels_type"):
if len(self.model._typedmodels_subtypes) > 1:
qs = qs.filter(type__in=self.model._typedmodels_subtypes)
Expand Down Expand Up @@ -201,7 +211,7 @@ def do_related_class(self, other, cls):

# add a get_type_classes classmethod to allow fetching of all the subclasses (useful for admin)

def get_type_classes(subcls):
def get_type_classes(subcls: "Type[TypedModelT]"):
if subcls is cls:
return list(cls._typedmodels_registry.values())
else:
Expand All @@ -212,7 +222,7 @@ def get_type_classes(subcls):

cls.get_type_classes = classmethod(get_type_classes)

def get_types(subcls):
def get_types(subcls: "Type[TypedModelT]"):
if subcls is cls:
return list(cls._typedmodels_registry.keys())
else:
Expand All @@ -223,7 +233,7 @@ def get_types(subcls):
return cls

@staticmethod
def _model_has_field(cls, base_class, field_name):
def _model_has_field(cls, base_class: "Type[TypedModelT]", field_name: str):
if field_name in base_class._meta._typedmodels_original_many_to_many:
return True
if field_name in base_class._meta._typedmodels_original_fields:
Expand All @@ -242,7 +252,7 @@ def _model_has_field(cls, base_class, field_name):
return False

@staticmethod
def _patch_fields_cache(cls, base_class):
def _patch_fields_cache(cls, base_class: "Type[TypedModelT]"):
orig_get_fields = cls._meta._get_fields

if django.VERSION >= (5, 0):
Expand Down Expand Up @@ -358,6 +368,8 @@ class Feline(Animal):
def say_something(self):
return "meoww"
'''
_typedmodels_type: Optional[str]
_typedmodels_subtypes: Optional[list[str]]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list[str] requires from __future__ import annotations until Python 3.9

This project still states support for python 3.6+, although mostly because I haven't gotten around to changing that...


objects = TypedModelManager()

Expand Down Expand Up @@ -406,7 +418,7 @@ def from_db(cls, db, field_names, values):
new._state.db = db
return new

def __init__(self, *args, _typedmodels_do_recast=None, **kwargs):
def __init__(self, *args, _typedmodels_do_recast=None, **kwargs) -> None:
# Calling __init__ on base class because some functions (e.g. save()) need access to field values from base
# class.

Expand All @@ -433,7 +445,7 @@ def __init__(self, *args, _typedmodels_do_recast=None, **kwargs):
if _typedmodels_do_recast:
self.recast()

def recast(self, typ=None):
def recast(self, typ: "Optional[Type[TypedModel]]" = None) -> None:
for base in reversed(self.__class__.mro()):
if issubclass(base, TypedModel) and hasattr(base, "_typedmodels_registry"):
break
Expand Down Expand Up @@ -471,10 +483,14 @@ def recast(self, typ=None):
if current_cls != correct_cls:
self.__class__ = correct_cls

def save(self, *args, **kwargs):
def save(self, *args, **kwargs) -> None:
self.presave(*args, **kwargs)
return super(TypedModel, self).save(*args, **kwargs)

def presave(self, *args, **kwargs) -> None:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why a separate method for this? I'd prefer to leave this in save() or underscore-prefix it, to avoid adding to the public API

"""Perform checks before saving the model."""
if not getattr(self, "_typedmodels_type", None):
raise RuntimeError("Untyped %s cannot be saved." % self.__class__.__name__)
return super(TypedModel, self).save(*args, **kwargs)

def _get_unique_checks(self, exclude=None, **kwargs):
unique_checks, date_checks = super(TypedModel, self)._get_unique_checks(
Expand All @@ -498,7 +514,7 @@ def _get_unique_checks(self, exclude=None, **kwargs):
_python_serializer_get_dump_object = _PythonSerializer.get_dump_object


def _get_dump_object(self, obj):
def _get_dump_object(self, obj: "Model") -> dict:
if isinstance(obj, TypedModel):
return {
"pk": smart_str(obj._get_pk_val(), strings_only=True),
Expand All @@ -514,7 +530,7 @@ def _get_dump_object(self, obj):
_xml_serializer_start_object = _XmlSerializer.start_object


def _start_object(self, obj):
def _start_object(self, obj: "Model") -> None:
if isinstance(obj, TypedModel):
self.indent(1)
obj_pk = obj._get_pk_val()
Expand Down
Empty file added typedmodels/py.typed
Empty file.