Skip to content

Commit

Permalink
Add create, update and delete subscriptions
Browse files Browse the repository at this point in the history
Use graphene-luna for testing purposes.

Also add a generic "DjangoSignalSubscription" type that allows you to subscribe to any Django signal.

Signed-off-by: Tormod Haugland <tormod.haugland@gmail.com>
tOgg1 committed Sep 16, 2024
1 parent d8f6a7c commit b6b98e7
Showing 9 changed files with 694 additions and 2 deletions.
2 changes: 2 additions & 0 deletions graphene_django_cud/consts.py
Original file line number Diff line number Diff line change
@@ -7,3 +7,5 @@

USE_ID_SUFFIXES_FOR_FK_SETTINGS_KEY = "GRAPHENE_DJANGO_CUD_USE_ID_SUFFIXES_FOR_FK"
USE_ID_SUFFIXES_FOR_M2M_SETTINGS_KEY = "GRAPHENE_DJANGO_CUD_USE_ID_SUFFIXES_FOR_M2M"

USE_MUTATION_SIGNALS_FOR_SUBSCRIPTIONS_KEY = "GRAPHENE_DJANGO_CUD_USE_MUTATION_SIGNALS_FOR_SUBSCRIPTIONS"
58 changes: 58 additions & 0 deletions graphene_django_cud/subscriptions/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import graphene
from graphql import GraphQLError


class SubscriptionField(graphene.Field):
"""
This is an extension of the graphene.Field class that exists
to allow our DjangoCudSubscriptionBase classes to pass a subscribe
method to the Field instantiation, which we use here in the
`wrap_subscribe` method. `wrap_subscribe` is called internally in graphene
to figure out which resolver to use for a subscription field.
"""

def __init__(self, *args, subscribe=None, **kwargs):
self.subscribe = subscribe
super().__init__(*args, **kwargs)

def wrap_subscribe(self, parent_subscribe):
return self.subscribe


class DjangoCudSubscriptionBase(graphene.ObjectType):
"""Base class for DjangoCud subscriptions"""

@classmethod
def get_permissions(cls, root, info, *args, **kwargs):
return cls._meta.permissions

@classmethod
def check_permissions(cls, root, info, *args, **kwargs) -> None:
get_permissions = getattr(cls, "get_permissions", None)
if not callable(get_permissions):
raise TypeError("The `get_permissions` attribute of a subscription must be callable.")

permissions = cls.get_permissions(root, info, *args, **kwargs)

if permissions and len(permissions) > 0:
if not info.context.user.has_perms(permissions):
raise GraphQLError("Not permitted to access this subscription.")

@classmethod
def Field(cls, name=None, description=None, deprecation_reason=None, required=False):
"""Create a field for the subscription that automatically creates a subscription resolver"""
return SubscriptionField(
cls._meta.output,
resolver=cls._meta.resolver,
subscribe=cls._meta.subscribe,
name=name,
description=description or cls._meta.description,
deprecation_reason=deprecation_reason,
required=required,
)

@classmethod
async def subscribe(cls, *args, **kwargs):
"""Dummy subscribe method. Must be implemented by subclasses"""
raise NotImplementedError("`subscribe` must be implemented by the implementing subclass. "
"This is likely a bug in graphene-django-cud.")
129 changes: 129 additions & 0 deletions graphene_django_cud/subscriptions/create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import asyncio
from collections import OrderedDict
from typing import Optional

import graphene
from asgiref.sync import async_to_sync
from django.conf import settings
from django.db.models.signals import post_save
from django.dispatch import Signal
from graphene.types.objecttype import ObjectTypeOptions
from graphene_django.registry import get_global_registry

from graphene_django_cud.consts import USE_MUTATION_SIGNALS_FOR_SUBSCRIPTIONS_KEY
from graphene_django_cud.signals import post_create_mutation
from graphene_django_cud.subscriptions.core import DjangoCudSubscriptionBase
from graphene_django_cud.util import to_snake_case


class DjangoCreateSubscriptionOptions(ObjectTypeOptions):
model = None
return_field_name = None
permissions = None
signal: Optional[Signal] = None


class DjangoCreateSubscription(DjangoCudSubscriptionBase):
# All active subscriptions are stored in this centralized dictionary.
# We need to do this to keep track of which subscriptions are listening to
# which signals.
subscribers = {}

@classmethod
def __init_subclass_with_meta__(
cls,
_meta=None,
model=None,
permissions=None,
return_field_name=None,
signal=post_create_mutation if getattr(
settings,
USE_MUTATION_SIGNALS_FOR_SUBSCRIPTIONS_KEY,
False
) else post_save,
**kwargs,
):
registry = get_global_registry()
model_type = registry.get_type_for_model(model)

if not _meta:
_meta = DjangoCreateSubscriptionOptions(cls)

if not return_field_name:
return_field_name = to_snake_case(model.__name__)

output_fields = OrderedDict()
output_fields[return_field_name] = graphene.Field(model_type)

_meta.model = model
_meta.model_type = model_type
_meta.fields = output_fields
_meta.output = cls
_meta.permissions = permissions

# Importantly, this needs to be set to either nothing or the identity.
# Internally in graphene it will be defaulted to the identity function. If it
# isn't, graphene will try to pass the value resolve from the "subscribe" method
# through this resolver. If it is also set to "subscribe", we will get an issue with
# graphene trying to return an AsyncIterator.
_meta.resolver = None

# This is set to be the subscription resolver in the SubscriptionField class.
_meta.subscribe = cls.subscribe
_meta.return_field_name = return_field_name

# Connect to the model's post_save (or your custom) signal
signal.connect(cls._model_created_handler, sender=model)

super().__init_subclass_with_meta__(_meta=_meta, **kwargs)

@classmethod
def _model_created_handler(cls, sender, instance, created=None, **kwargs):
"""Handle model creation and notify subscribers"""
if created or created is None:
print(sender, instance, created, kwargs)
new_instance = cls.handle_object_created(sender, instance, **kwargs)

assert new_instance is None or isinstance(new_instance, cls._meta.model)

if new_instance:
instance = new_instance

# Notify all subscribers for the model
for subscriber in cls.subscribers.get(sender, []):
async_to_sync(subscriber)(instance)

@classmethod
def handle_object_created(cls, sender, instance, **kwargs):
"""Handle and modify any instance created"""
pass

@classmethod
def check_permissions(cls, root, info, *args, **kwargs) -> None:
return super().check_permissions(root, info, *args, **kwargs)

@classmethod
async def subscribe(cls, root, info, *args, **kwargs):
"""Subscribe to the model creation events asynchronously"""

cls.check_permissions(root, info, *args, **kwargs)

model = cls._meta.model
queue = asyncio.Queue()

# Ensure there's a list of subscribers for the model
if model not in cls.subscribers:
cls.subscribers[model] = []

# Add the queue's put method to the subscribers for this model
cls.subscribers[model].append(queue.put)

try:
while True:
# Wait for the next model instance to be created
instance = await queue.get()
data = {cls._meta.return_field_name: instance}
yield cls(**data)
finally:
# Clean up the subscriber when the subscription ends
cls.subscribers[model].remove(queue.put)
152 changes: 152 additions & 0 deletions graphene_django_cud/subscriptions/delete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import asyncio
from collections import OrderedDict
from typing import Optional

import graphene
from asgiref.sync import async_to_sync
from django.db.models.signals import post_save, post_delete
from graphene.types.objecttype import ObjectTypeOptions
from graphene.types.utils import yank_fields_from_attrs
from graphene_django.registry import get_global_registry
from requests import delete

from graphene_django_cud.subscriptions.core import DjangoCudSubscriptionBase
from graphene_django_cud.util import to_snake_case

from graphene_django_cud.util.dict import get_any_of
import logging

logger = logging.getLogger(__name__)


class DjangoDeleteSubscriptionOptions(ObjectTypeOptions):
model = None
return_field_name = None
permissions = None
signal = None


class DjangoDeleteSubscription(DjangoCudSubscriptionBase):
# All active subscriptions are stored in this centralized dictionary.
# We need to do this to keep track of which subscriptions are listening to
# which signals.
subscribers = {}

@classmethod
def __init_subclass_with_meta__(
cls,
_meta=None,
model=None,
permissions=None,
return_field_name=None,
signal=post_delete,
**kwargs,
):
registry = get_global_registry()
model_type = registry.get_type_for_model(model)

if not _meta:
_meta = DjangoDeleteSubscriptionOptions(cls)

if not return_field_name:
return_field_name = to_snake_case(model.__name__)

output_fields = OrderedDict()
output_fields["id"] = graphene.String()

_meta.model = model
_meta.model_type = model_type
_meta.fields = yank_fields_from_attrs(output_fields, _as=graphene.Field)
_meta.output = cls
_meta.permissions = permissions

# Importantly, this needs to be set to either nothing or the identity.
# Internally in graphene it will be defaulted to the identity function.
_meta.resolver = None

# This is set to be the subscription resolver in the SubscriptionField class.
_meta.subscribe = cls.subscribe
_meta.return_field_name = return_field_name

# Connect to the model's post_save signal
signal.connect(cls._model_deleted_handler, sender=model)

super().__init_subclass_with_meta__(_meta=_meta, **kwargs)

@classmethod
def _model_deleted_handler(cls, sender, *args, **kwargs):
"""Handle model updating and notify subscribers"""

Model = cls._meta.model

instance: Optional[Model] = kwargs.get("instance", None) or next(filter(
lambda x: isinstance(x, Model), args
), None)

deleted_id = get_any_of(
kwargs,
[
"pk",
"raw_id",
"input_id",
"id"
]
) if not instance else get_any_of(
instance,
[
"pk",
"id",
]
)

print(kwargs, args, deleted_id)

if deleted_id is None:
logger.warning("Received a delete signal for a model without an instance or an id being passed to the "
"signal handler. Are you using a compatible signal? Read the documentation for "
"graphene-django-cud for more information.")
return

new_deleted_id = cls.handle_object_deleted(sender, deleted_id, **kwargs)

if new_deleted_id is not None:
deleted_id = new_deleted_id

# Notify all subscribers for the model
for subscriber in cls.subscribers.get(sender, []):
async_to_sync(subscriber)(deleted_id)

@classmethod
def handle_object_deleted(cls, sender, deleted_id, **kwargs):
"""Handle and modify any instance created"""
pass

@classmethod
def check_permissions(cls, root, info, *args, **kwargs) -> None:
return super().check_permissions(root, info, *args, **kwargs)

@classmethod
async def subscribe(cls, root, info, *args, **kwargs):
"""Subscribe to the model creation events asynchronously"""

cls.check_permissions(root, info, *args, **kwargs)

model = cls._meta.model
queue = asyncio.Queue()

# Ensure there's a list of subscribers for the model
if model not in cls.subscribers:
cls.subscribers[model] = []

# Add the queue's put method to the subscribers for this model
cls.subscribers[model].append(queue.put)

try:
while True:
# Wait for the next model instance to be deleted
_id = await queue.get()

yield cls(id=_id)
finally:
# Clean up the subscriber when the subscription ends
cls.subscribers[model].remove(queue.put)
98 changes: 98 additions & 0 deletions graphene_django_cud/subscriptions/signal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import asyncio
from typing import Optional

from asgiref.sync import async_to_sync
from django.dispatch import Signal
from graphene import Field
from graphene.types.objecttype import ObjectTypeOptions
from graphene.types.utils import yank_fields_from_attrs

from graphene_django_cud.subscriptions.core import DjangoCudSubscriptionBase


class DjangoSignalSubscriptionOptions(ObjectTypeOptions):
permissions = None
signal: Optional[Signal] = None
sender = None


class DjangoSignalSubscription(DjangoCudSubscriptionBase):
subscribers = set()

@classmethod
def __init_subclass_with_meta__(
cls,
_meta=None,
permissions=None,
signal=None,
sender=None,
output=None,
**kwargs,
):
if not _meta:
_meta = DjangoSignalSubscriptionOptions(cls)

if not signal:
raise ValueError("You must specify a signal to subscribe to")

output = output or getattr(cls, "Output", None)

if not output:
# If output is defined, we don't need to get the fields
fields = {}
for base in reversed(cls.__mro__):
fields.update(yank_fields_from_attrs(base.__dict__, _as=Field))
output = cls

_meta.permissions = permissions
_meta.signal = signal
_meta.output = output

# Importantly, this needs to be set to either nothing or the identity.
# Internally in graphene it will be defaulted to the identity function. If it
# isn't, graphene will try to pass the value resolve from the "subscribe" method
# through this resolver. If it is also set to "subscribe", we will get an issue with
# graphene trying to return an AsyncIterator.
_meta.resolver = None

# This is set to be the subscription resolver in the SubscriptionField class.
_meta.subscribe = cls.subscribe

signal.connect(cls.handle_signal, sender=sender)

super().__init_subclass_with_meta__(_meta=_meta, **kwargs)

@classmethod
def handle_signal(cls, *args, **kwargs):
data_item = {
**kwargs,
"args": args,
}
for subscriber in cls.subscribers:
async_to_sync(subscriber)(data_item)

@classmethod
def transform_signal_data(cls, data):
"""Transform data into the appropriate dictionary for the fields associated
with this subscription"""
raise NotImplementedError("`transform_signal_data` must be implemented by the implementing subclass.")

@classmethod
async def subscribe(cls, root, info, *args, **kwargs):
"""Subscribe to the model creation events asynchronously"""
cls.check_permissions(root, info, *args, **kwargs)

queue = asyncio.Queue()

# Add the queue's put method to the subscribers for this model
cls.subscribers.add(queue.put)

try:
while True:
# Wait for the next signal to be fired.
signal_data = await queue.get()
data = cls.transform_signal_data(signal_data)
yield cls(**data)
finally:
# Clean up the subscriber when the subscription ends
cls.subscribers.remove(queue.put)
129 changes: 129 additions & 0 deletions graphene_django_cud/subscriptions/update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import asyncio
from collections import OrderedDict

import graphene
from asgiref.sync import async_to_sync
from django.conf import settings
from django.db.models.signals import post_save
from graphene.types.objecttype import ObjectTypeOptions
from graphene_django.registry import get_global_registry

from graphene_django_cud.consts import USE_MUTATION_SIGNALS_FOR_SUBSCRIPTIONS_KEY
from graphene_django_cud.signals import post_update_mutation
from graphene_django_cud.subscriptions.core import DjangoCudSubscriptionBase
from graphene_django_cud.util import to_snake_case


class DjangoUpdateSubscriptionOptions(ObjectTypeOptions):
model = None
return_field_name = None
permissions = None
signal = None


class DjangoUpdateSubscription(DjangoCudSubscriptionBase):
# All active subscriptions are stored in this centralized dictionary.
# We need to do this to keep track of which subscriptions are listening to
# which signals.
subscribers = {}

@classmethod
def __init_subclass_with_meta__(
cls,
_meta=None,
model=None,
permissions=None,
return_field_name=None,
signal=post_update_mutation if getattr(
settings,
USE_MUTATION_SIGNALS_FOR_SUBSCRIPTIONS_KEY,
False
) else post_save,
**kwargs,
):
registry = get_global_registry()
model_type = registry.get_type_for_model(model)

if not _meta:
_meta = DjangoUpdateSubscriptionOptions(cls)

if not return_field_name:
return_field_name = to_snake_case(model.__name__)

output_fields = OrderedDict()
output_fields[return_field_name] = graphene.Field(model_type)

_meta.model = model
_meta.model_type = model_type
_meta.fields = output_fields
_meta.output = cls
_meta.permissions = permissions

# Importantly, this needs to be set to either nothing or the identity.
# Internally in graphene it will be defaulted to the identity function. If it
# isn't, graphene will try to pass the value resolve from the "subscribe" method
# through this resolver. If it is also set to "subscribe", we will get an issue with
# graphene trying to return an AsyncIterator.
_meta.resolver = None

# This is set to be the subscription resolver in the SubscriptionField class.
_meta.subscribe = cls.subscribe
_meta.return_field_name = return_field_name

# Connect to the model's post_save (or your custom) signal
signal.connect(cls._model_updated_handler, sender=model)

super().__init_subclass_with_meta__(_meta=_meta, **kwargs)

@classmethod
def _model_updated_handler(cls, sender, instance, created=None, **kwargs):
"""Handle model updating and notify subscribers"""

if created is not None and not created:
return

new_instance = cls.handle_object_updated(sender, instance, **kwargs)

assert new_instance is None or isinstance(new_instance, cls._meta.model)

if new_instance:
instance = new_instance

# Notify all subscribers for the model
for subscriber in cls.subscribers.get(sender, []):
async_to_sync(subscriber)(instance)

@classmethod
def handle_object_updated(cls, sender, instance, **kwargs):
"""Handle and modify any instance created"""
pass

@classmethod
def check_permissions(cls, root, info, *args, **kwargs) -> None:
return super().check_permissions(root, info, *args, **kwargs)

@classmethod
async def subscribe(cls, root, info, *args, **kwargs):
"""Subscribe to the model update events asynchronously"""

cls.check_permissions(root, info, *args, **kwargs)

model = cls._meta.model
queue = asyncio.Queue()

# Ensure there's a list of subscribers for the model
if model not in cls.subscribers:
cls.subscribers[model] = []

# Add the queue's put method to the subscribers for this model
cls.subscribers[model].append(queue.put)

try:
while True:
# Wait for the next model instance to be updated
instance = await queue.get()
data = {cls._meta.return_field_name: instance}
yield cls(**data)
finally:
# Clean up the subscriber when the subscription ends
cls.subscribers[model].remove(queue.put)
94 changes: 92 additions & 2 deletions graphene_django_cud/tests/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import asyncio
import random

import graphene
from asgiref.sync import sync_to_async
from django.dispatch import Signal
from graphene import Node, Schema
from graphene_django import DjangoObjectType, DjangoConnectionField

@@ -11,6 +16,11 @@
DjangoBatchCreateMutation,
)
from graphene_django_cud.mutations.filter_update import DjangoFilterUpdateMutation
from graphene_django_cud.signals import post_create_mutation, post_update_mutation, post_delete_mutation
from graphene_django_cud.subscriptions.create import DjangoCreateSubscription
from graphene_django_cud.subscriptions.delete import DjangoDeleteSubscription
from graphene_django_cud.subscriptions.signal import DjangoSignalSubscription
from graphene_django_cud.subscriptions.update import DjangoUpdateSubscription
from graphene_django_cud.tests.models import (
User,
Cat,
@@ -275,8 +285,20 @@ class Meta:
model = Fish


class Mutations(graphene.ObjectType):
test_signal = Signal()


class FireRandomSignal(graphene.Mutation):
ok = graphene.Boolean()

@classmethod
def mutate(cls, root, info):
test_signal.send(sender=cls, value=random.randint(0, 100))

return cls(ok=True)


class Mutations(graphene.ObjectType):
create_user = CreateUserMutation.Field()

batch_create_user = BatchCreateUserMutation.Field()
@@ -310,5 +332,73 @@ class Mutations(graphene.ObjectType):
update_fish = UpdateFishMutation.Field()
delete_fish = DeleteFishMutation.Field(0)

fire_random_signal = FireRandomSignal.Field()


class FishCreatedSubscription(DjangoCreateSubscription):
class Meta:
model = Fish


class CatCreatedSubscription(DjangoCreateSubscription):
class Meta:
model = Cat
signal = post_create_mutation

# noinspection PyStatementEffect
@classmethod
def handle_object_created(cls, sender, instance: Cat, *args, **kwargs):
cat = Cat.objects.select_related("owner").prefetch_related("enemies").get(pk=instance.pk)

return cat


class CatUpdatedSubscription(DjangoUpdateSubscription):
class Meta:
model = Cat
signal = post_update_mutation

# noinspection PyStatementEffect
@classmethod
def handle_object_updated(cls, sender, instance: Cat, *args, **kwargs):
cat = Cat.objects.select_related("owner").prefetch_related("enemies").get(pk=instance.pk)

return cat


class CatDeletedSubscription(DjangoDeleteSubscription):
class Meta:
model = Cat
signal = post_delete_mutation


class RandomSignalFiredSubscription(DjangoSignalSubscription):
lets_go = graphene.String()

@classmethod
def transform_signal_data(cls, data):
return {"lets_go": f"go {data.get('value', 0)}"}

class Meta:
signal = test_signal


def get_random_fish():
return random.choice(list(Fish.objects.all()))


def get_random_cat():
return random.choice(list(Cat.objects.all()))


class Subscription(graphene.ObjectType):
fish_created = FishCreatedSubscription.Field()
cat_created = CatCreatedSubscription.Field()

cat_updated = CatUpdatedSubscription.Field()

cat_deleted = CatDeletedSubscription.Field()
test_signal_fired = RandomSignalFiredSubscription.Field()


schema = Schema(query=Query, mutation=Mutations)
schema = Schema(query=Query, mutation=Mutations, subscription=Subscription)
12 changes: 12 additions & 0 deletions graphene_django_cud/util/dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
def get_any_of(dict_or_obj_like, keys, default=None):
"""Get the first key in a dict-like object that is not None"""

is_dict = isinstance(dict_or_obj_like, dict)

for key in keys:
value = dict_or_obj_like.get(key) if is_dict else getattr(dict_or_obj_like, key, None)

if value is not None:
return value

return default
22 changes: 22 additions & 0 deletions graphene_django_cud/ws_urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""graphene_django_cud URL Configuration
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/2.1/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
import luna_ws
from django.contrib.staticfiles.urls import staticfiles_urlpatterns
from django.urls import path

urlpatterns = [
path("graphql", luna_ws.GraphQLSubscriptionHandler)
]

0 comments on commit b6b98e7

Please sign in to comment.