diff --git a/contentcuration/contentcuration/constants/locking.py b/contentcuration/contentcuration/constants/locking.py
new file mode 100644
index 0000000000..6b53fbd081
--- /dev/null
+++ b/contentcuration/contentcuration/constants/locking.py
@@ -0,0 +1,5 @@
+"""
+Constants for locking behaviors, like advisory locking in Postgres, and mutexes
+"""
+TREE_LOCK = 1001
+TASK_LOCK = 1002
diff --git a/contentcuration/contentcuration/db/advisory_lock.py b/contentcuration/contentcuration/db/advisory_lock.py
index 61d53a379f..f1d71995ed 100644
--- a/contentcuration/contentcuration/db/advisory_lock.py
+++ b/contentcuration/contentcuration/db/advisory_lock.py
@@ -6,11 +6,36 @@
 
 logging = logger.getLogger(__name__)
 
+# signed limits are 2**32 or 2**64, so one less power of 2
+# to become unsigned limits (half above 0, half below 0)
+INT_32BIT = 2**31
+INT_64BIT = 2**63
+
 
 class AdvisoryLockBusy(RuntimeError):
     pass
 
 
+def _prepare_keys(keys):
+    """
+    Ensures that integers do not exceed postgres constraints:
+    - signed 64bit allowed with single key
+    - signed 32bit allowed with two keys
+    :param keys: A list of unsigned integers
+    :return: A list of signed integers
+    """
+    limit = INT_64BIT if len(keys) == 1 else INT_32BIT
+    new_keys = []
+    for key in keys:
+        # if key is over the limit, convert to negative int since key should be unsigned int
+        if key >= limit:
+            key = limit - key
+        if key < -limit or key >= limit:
+            raise OverflowError(f"Advisory lock key '{key}' is too large")
+        new_keys.append(key)
+    return new_keys
+
+
 @contextmanager
 def execute_lock(key1, key2=None, unlock=False, session=False, shared=False, wait=True):
     """
@@ -32,6 +57,7 @@ def execute_lock(key1, key2=None, unlock=False, session=False, shared=False, wait=True):
     keys = [key1]
     if key2 is not None:
         keys.append(key2)
+    keys = _prepare_keys(keys)
 
     query = "SELECT pg{_try}_advisory_{xact_}{lock}{_shared}({keys}) AS lock;".format(
         _try="" if wait else "_try",
@@ -41,11 +67,11 @@ def execute_lock(key1, key2=None, unlock=False, session=False, shared=False, wait=True):
         keys=", ".join(["%s" for i in range(0, 2 if key2 is not None else 1)])
     )
 
-    log_query = "'{}' with params {}".format(query, keys)
-    logging.debug("Acquiring advisory lock: {}".format(query, log_query))
+    log_query = f"'{query}' with params {keys}"
+    logging.debug(f"Acquiring advisory lock: {log_query}")
     with connection.cursor() as c:
         c.execute(query, keys)
-        logging.debug("Acquired advisory lock: {}".format(query, log_query))
+        logging.debug(f"Acquired advisory lock: {log_query}")
         yield c
diff --git a/contentcuration/contentcuration/db/models/manager.py b/contentcuration/contentcuration/db/models/manager.py
index 3556fe8e70..72e15186a7 100644
--- a/contentcuration/contentcuration/db/models/manager.py
+++ b/contentcuration/contentcuration/db/models/manager.py
@@ -12,6 +12,7 @@
 from mptt.managers import TreeManager
 from mptt.signals import node_moved
 
+from contentcuration.constants.locking import TREE_LOCK
 from contentcuration.db.advisory_lock import advisory_lock
 from contentcuration.db.models.query import CustomTreeQuerySet
 from contentcuration.utils.cache import ResourceSizeCache
@@ -32,7 +33,6 @@
 # The exact optimum batch size is probably highly dependent on tree
 # topology also, so these rudimentary tests are likely insufficient
 BATCH_SIZE = 100
-TREE_LOCK = 1001
 
 
 class CustomManager(Manager.from_queryset(CTEQuerySet)):
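As a rough illustration of what `_prepare_keys` does (an aside, not part of the patch): Postgres exposes `pg_advisory_lock(bigint)` and `pg_advisory_lock(int, int)`, so a single key gets the signed 64-bit budget while a key pair gets signed 32-bit each, and an unsigned value above the signed limit is folded into the negative half of the range:

```python
# Walkthrough of the folding rule above, using the module's own
# limits (INT_32BIT = 2**31 for the two-key case).
INT_32BIT = 2**31

def fold(key, limit=INT_32BIT):
    # mirrors _prepare_keys for a single 32-bit key
    if key >= limit:
        key = limit - key  # e.g. 2**31 + 5 -> -5
    if key < -limit or key >= limit:
        raise OverflowError(f"Advisory lock key '{key}' is too large")
    return key

assert fold(1001) == 1001      # small keys pass through unchanged
assert fold(2**31 + 5) == -5   # unsigned overflow folds into the negative half
```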
diff --git a/contentcuration/contentcuration/frontend/channelEdit/vuex/clipboard/actions.js b/contentcuration/contentcuration/frontend/channelEdit/vuex/clipboard/actions.js
index 2358d4ded4..5dd2c62763 100644
--- a/contentcuration/contentcuration/frontend/channelEdit/vuex/clipboard/actions.js
+++ b/contentcuration/contentcuration/frontend/channelEdit/vuex/clipboard/actions.js
@@ -1,5 +1,6 @@
 import get from 'lodash/get';
 import partition from 'lodash/partition';
+import chunk from 'lodash/chunk';
 import uniq from 'lodash/uniq';
 import uniqBy from 'lodash/uniqBy';
 import defer from 'lodash/defer';
@@ -83,12 +84,20 @@ export function loadClipboardNodes(context, { parent }) {
       const legacyNodeIds = legacyNodes.map(n => n.id);
 
       return Promise.all([
-        context.dispatch(
-          'contentNode/loadContentNodes',
-          { '[node_id+channel_id]__in': nodeIdChannelIdPairs },
-          { root }
-        ),
+        // To avoid error code 414 URI Too Long errors, we chunk the pairs
+        // Given URI limit is 2000 chars:
+        // base URL at 100 chars + each pair at 70 chars = max 27 pairs
+        ...chunk(nodeIdChannelIdPairs, 25).map(chunkPairs =>
+          context.dispatch(
+            'contentNode/loadContentNodes',
+            { '[node_id+channel_id]__in': chunkPairs },
+            { root }
+          )
+        ),
+        // Chunk legacy nodes, double the size since not pairs
+        ...chunk(legacyNodeIds, 50).map(legacyChunk =>
+          context.dispatch('contentNode/loadContentNodes', { id__in: legacyChunk }, { root })
+        ),
-        context.dispatch('contentNode/loadContentNodes', { id__in: legacyNodeIds }, { root }),
       ]).then(() => {
         return context.dispatch('addClipboardNodes', {
           nodes: clipboardNodes,
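The chunk sizes follow from simple budget arithmetic; here is the same calculation spelled out (the 2000/100/70 figures are the comment's stated assumptions, not measured values):

```python
# Assumed budget from the comment above: ~2000-char URI limit,
# ~100 chars of base URL, ~70 chars per [node_id+channel_id] pair.
URI_LIMIT = 2000
BASE_URL = 100
PAIR_SIZE = 70

max_pairs = (URI_LIMIT - BASE_URL) // PAIR_SIZE  # 27
chunk_size = 25                                  # a little headroom under 27
legacy_chunk_size = chunk_size * 2               # single ids are roughly half the size of a pair

print(max_pairs, chunk_size, legacy_chunk_size)  # 27 25 50
```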
diff --git a/contentcuration/contentcuration/frontend/shared/client.js b/contentcuration/contentcuration/frontend/shared/client.js
index e962e46dd4..53d5bbb309 100644
--- a/contentcuration/contentcuration/frontend/shared/client.js
+++ b/contentcuration/contentcuration/frontend/shared/client.js
@@ -58,7 +58,7 @@ client.interceptors.response.use(
     }
   }
 
-  message = message ? `${message}: ${url}` : `Network Error: ${url}`;
+  message = message ? `${message}: [${status}] ${url}` : `Network Error: [${status}] ${url}`;
 
   if (process.env.NODE_ENV !== 'production') {
     // In dev build log warnings to console for developer use
diff --git a/contentcuration/contentcuration/frontend/shared/data/index.js b/contentcuration/contentcuration/frontend/shared/data/index.js
index 9665588e44..017fcfed46 100644
--- a/contentcuration/contentcuration/frontend/shared/data/index.js
+++ b/contentcuration/contentcuration/frontend/shared/data/index.js
@@ -1,4 +1,5 @@
 import Dexie from 'dexie';
+import * as Sentry from '@sentry/vue';
 import mapValues from 'lodash/mapValues';
 import channel from './broadcastChannel';
 import { CHANGES_TABLE, IGNORED_SOURCE, TABLE_NAMES } from './constants';
@@ -47,11 +48,25 @@ function runElection() {
   elector.awaitLeadership().then(startSyncing);
   elector.onduplicate = () => {
     stopSyncing();
-    elector.die().then(runElection);
+    elector
+      .die()
+      .then(() => {
+        // manually reset reference to dead elector on the channel
+        // which is set within `createLeaderElection` and whose
+        // presence is also validated against, requiring its removal
+        channel._leaderElector = null;
+        return runElection();
+      })
+      .catch(Sentry.captureException);
   };
 }
 
-export function initializeDB() {
-  setupSchema();
-  return db.open().then(runElection);
+export async function initializeDB() {
+  try {
+    setupSchema();
+    await db.open();
+    await runElection();
+  } catch (e) {
+    Sentry.captureException(e);
+  }
 }
diff --git a/contentcuration/contentcuration/frontend/shared/data/serverSync.js b/contentcuration/contentcuration/frontend/shared/data/serverSync.js
index e114b8cd19..68b0e7f02b 100644
--- a/contentcuration/contentcuration/frontend/shared/data/serverSync.js
+++ b/contentcuration/contentcuration/frontend/shared/data/serverSync.js
@@ -1,3 +1,4 @@
+import * as Sentry from '@sentry/vue';
 import debounce from 'lodash/debounce';
 import findLastIndex from 'lodash/findLastIndex';
 import get from 'lodash/get';
@@ -99,6 +100,20 @@ function handleDisallowed(response) {
   // that were rejected.
   const disallowed = get(response, ['data', 'disallowed'], []);
   if (disallowed.length) {
+    // Capture occurrences of the api disallowing changes
+    if (process.env.NODE_ENV === 'production') {
+      Sentry.withScope(function(scope) {
+        scope.addAttachment({
+          filename: 'disallowed.json',
+          data: JSON.stringify(disallowed),
+          contentType: 'application/json',
+        });
+        Sentry.captureException(new Error('/api/sync returned disallowed changes'));
+      });
+    } else {
+      console.warn('/api/sync returned disallowed changes:', disallowed); // eslint-disable-line no-console
+    }
+    // Collect all disallowed
     const disallowedRevs = disallowed.map(d => Number(d.rev));
     // Set the return error data onto the changes - this will update the change
diff --git a/contentcuration/contentcuration/migrations/0142_add_task_signature.py b/contentcuration/contentcuration/migrations/0142_add_task_signature.py
new file mode 100644
index 0000000000..194580211c
--- /dev/null
+++ b/contentcuration/contentcuration/migrations/0142_add_task_signature.py
@@ -0,0 +1,28 @@
+# Generated by Django 3.2.14 on 2022-12-09 16:09
+from django.db import migrations
+from django.db import models
+
+
+class Migration(migrations.Migration):
+
+    replaces = [('django_celery_results', '0140_delete_task'),]
+
+    def __init__(self, name, app_label):
+        super(Migration, self).__init__(name, 'django_celery_results')
+
+    dependencies = [
+        ('contentcuration', '0141_soft_delete_user'),
+        ('django_celery_results', '0011_taskresult_periodic_task_name'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='taskresult',
+            name='signature',
+            field=models.CharField(max_length=32, null=True),
+        ),
+        migrations.AddIndex(
+            model_name='taskresult',
+            index=models.Index(condition=models.Q(('status__in', frozenset(['STARTED', 'REJECTED', 'RETRY', 'RECEIVED', 'PENDING']))), fields=['signature'], name='task_result_signature_idx'),
+        ),
+    ]
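The constructor override above is what lets a migration that lives in `contentcuration/migrations` apply against a third-party app's models. A minimal sketch of the same pattern with hypothetical app and model names (`app_a`, `app_b`, `somemodel` are invented for illustration):

```python
# Hypothetical: a migration kept in app_a/migrations/ that Django
# should treat as belonging to app_b, mirroring the override above.
from django.db import migrations, models


class Migration(migrations.Migration):

    def __init__(self, name, app_label):
        # ignore the app_label Django infers from the file's location
        # and pin the migration to the app that actually owns the model
        super().__init__(name, 'app_b')

    dependencies = [
        ('app_b', '0001_initial'),
    ]

    operations = [
        migrations.AddField(
            model_name='somemodel',  # hypothetical model in app_b
            name='extra_note',
            field=models.CharField(max_length=32, null=True),
        ),
    ]
```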
diff --git a/contentcuration/contentcuration/models.py b/contentcuration/contentcuration/models.py
index 1317453267..4240b19844 100644
--- a/contentcuration/contentcuration/models.py
+++ b/contentcuration/contentcuration/models.py
@@ -7,6 +7,7 @@
 from datetime import datetime
 
 import pytz
+from celery import states as celery_states
 from django.conf import settings
 from django.contrib.auth.base_user import AbstractBaseUser
 from django.contrib.auth.base_user import BaseUserManager
@@ -75,6 +76,7 @@
 from contentcuration.db.models.manager import CustomManager
 from contentcuration.statistics import record_channel_stats
 from contentcuration.utils.cache import delete_public_channel_cache_keys
+from contentcuration.utils.celery.tasks import generate_task_signature
 from contentcuration.utils.parser import load_json_string
 from contentcuration.viewsets.sync.constants import ALL_CHANGES
 from contentcuration.viewsets.sync.constants import ALL_TABLES
@@ -2541,13 +2543,20 @@ def serialize_to_change_dict(self):
 class TaskResultCustom(object):
     """
     Custom fields to add to django_celery_results's TaskResult model
+
+    If adding fields to this class, run `makemigrations` then move the generated migration from the
+    `django_celery_results` app to the `contentcuration` app and override the constructor to change
+    the app_label. See `0142_add_task_signature` for an example
     """
     # user shouldn't be null, but in order to append the field, this needs to be allowed
     user = models.ForeignKey(settings.AUTH_USER_MODEL, related_name="tasks", on_delete=models.CASCADE, null=True)
     channel_id = DjangoUUIDField(db_index=True, null=True, blank=True)
     progress = models.IntegerField(null=True, blank=True, validators=[MinValueValidator(0), MaxValueValidator(100)])
+    # a hash of the task name and kwargs for identifying repeat tasks
+    signature = models.CharField(null=True, blank=False, max_length=32)
 
     super_as_dict = TaskResult.as_dict
+    super_save = TaskResult.save
 
     def as_dict(self):
         """
@@ -2561,6 +2570,22 @@ def as_dict(self):
         )
         return super_dict
 
+    def set_signature(self):
+        """
+        Generates and sets the signature for the task if it isn't set
+        """
+        if self.signature is not None:
+            # nothing to do
+            return
+        self.signature = generate_task_signature(self.task_name, task_kwargs=self.task_kwargs, channel_id=self.channel_id)
+
+    def save(self, *args, **kwargs):
+        """
+        Override save to ensure signature is generated
+        """
+        self.set_signature()
+        return self.super_save(*args, **kwargs)
+
     @classmethod
     def contribute_to_class(cls, model_class=TaskResult):
         """
@@ -2568,9 +2593,22 @@ def contribute_to_class(cls, model_class=TaskResult):
         :param model_class: TaskResult model
         """
         for field in dir(cls):
-            if not field.startswith("_"):
+            if not field.startswith("_") and field not in ('contribute_to_class', 'Meta'):
                 model_class.add_to_class(field, getattr(cls, field))
 
+        # manually add Meta afterwards
+        setattr(model_class._meta, 'indexes', getattr(model_class._meta, 'indexes', []) + cls.Meta.indexes)
+
+    class Meta:
+        indexes = [
+            # add index that matches query usage for signature
+            models.Index(
+                fields=['signature'],
+                name='task_result_signature_idx',
+                condition=Q(status__in=celery_states.UNREADY_STATES),
+            ),
+        ]
+
 
 # trigger class contributions immediately
 TaskResultCustom.contribute_to_class()
diff --git a/contentcuration/contentcuration/tests/test_asynctask.py b/contentcuration/contentcuration/tests/test_asynctask.py
index e24e04933e..bbd2714f47 100644
--- a/contentcuration/contentcuration/tests/test_asynctask.py
+++ b/contentcuration/contentcuration/tests/test_asynctask.py
@@ -234,7 +234,8 @@ def test_fetch_or_enqueue_task__channel_id__uuid_then_hex(self):
         self.assertEqual(expected_task.task_id, async_result.task_id)
 
     def test_requeue_task(self):
-        existing_task_ids = requeue_test_task.find_ids()
+        signature = requeue_test_task._generate_signature({})
+        existing_task_ids = requeue_test_task.find_ids(signature)
         self.assertEqual(len(existing_task_ids), 0)
 
         first_async_result = requeue_test_task.enqueue(self.user, requeue=True)
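`contribute_to_class` here leans on Django's `Model.add_to_class`, which installs fields and plain attributes onto an already-defined model class after the fact. A stripped-down sketch of the pattern, assuming it runs inside a configured Django project (the `Note` model and `extra` field are hypothetical, not Studio code):

```python
# Hypothetical illustration of patching fields onto an existing model,
# in the spirit of TaskResultCustom.contribute_to_class above.
from django.db import models


class Note(models.Model):  # stands in for django_celery_results' TaskResult
    body = models.TextField()


class NoteCustom(object):
    extra = models.CharField(max_length=32, null=True)

    def shout(self):
        return (self.extra or self.body).upper()

    @classmethod
    def contribute_to_class(cls, model_class=Note):
        for field in dir(cls):
            if not field.startswith("_") and field != 'contribute_to_class':
                # add_to_class wires Field instances into model_class._meta
                # and sets plain functions as ordinary methods
                model_class.add_to_class(field, getattr(cls, field))


NoteCustom.contribute_to_class()
```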
diff --git a/contentcuration/contentcuration/utils/celery/tasks.py b/contentcuration/contentcuration/utils/celery/tasks.py
index 630fa92d0b..5b55d83bd9 100644
--- a/contentcuration/contentcuration/utils/celery/tasks.py
+++ b/contentcuration/contentcuration/utils/celery/tasks.py
@@ -1,11 +1,18 @@
+import contextlib
+import hashlib
 import logging
 import math
 import uuid
+import zlib
+from collections import OrderedDict
 
 from celery import states
 from celery.app.task import Task
 from celery.result import AsyncResult
+from django.db import transaction
 
+from contentcuration.constants.locking import TASK_LOCK
+from contentcuration.db.advisory_lock import advisory_lock
 from contentcuration.utils.sentry import report_exception
 
 
@@ -66,6 +73,22 @@ def get_task_model(ref, task_id):
     return ref.backend.TaskModel.objects.get_task(task_id)
 
 
+def generate_task_signature(task_name, task_kwargs=None, channel_id=None):
+    """
+    :type task_name: str
+    :param task_kwargs: the celery encoded/serialized form of the task_kwargs dict
+    :type task_kwargs: str|None
+    :type channel_id: str|None
+    :return: A hex string, md5
+    :rtype: str
+    """
+    md5 = hashlib.md5()
+    md5.update(task_name.encode('utf-8'))
+    md5.update((task_kwargs or '').encode('utf-8'))
+    md5.update((channel_id or '').encode('utf-8'))
+    return md5.hexdigest()
+
+
 class CeleryTask(Task):
     """
     This is set as the Task class on our Celery app, so to track progress on a task, mark it
@@ -107,35 +130,56 @@ def shadow_name(self, *args, **kwargs):
         """
         return super(CeleryTask, self).shadow_name(*args, **kwargs)
 
-    def find_ids(self, channel_id=None, **kwargs):
+    def _prepare_kwargs(self, kwargs):
         """
-        :param channel_id:
-        :param kwargs: Keyword arguments sent to the task, which will be matched against
-        :return: A TaskResult queryset
-        :rtype: django.db.models.query.QuerySet
+        Prepares kwargs, converting UUID to their hex value
         """
-        task_qs = self.TaskModel.objects.filter(task_name=self.name)
+        return OrderedDict(
+            (key, value.hex if isinstance(value, uuid.UUID) else value)
+            for key, value in kwargs.items()
+        )
 
-        # add channel filter since we have dedicated field
-        if channel_id:
-            task_qs = task_qs.filter(channel_id=channel_id)
-        else:
-            task_qs = task_qs.filter(channel_id__isnull=True)
+    def _generate_signature(self, kwargs):
+        """
+        :param kwargs: A dictionary of task kwargs
+        :return: A hex string representing an md5 hash of task metadata
+        """
+        prepared_kwargs = self._prepare_kwargs(kwargs)
+        return generate_task_signature(
+            self.name,
+            task_kwargs=self.backend.encode(prepared_kwargs),
+            channel_id=prepared_kwargs.get('channel_id')
+        )
 
-        # search for task args in values
-        for value in kwargs.values():
-            task_qs = task_qs.filter(task_kwargs__contains=self.backend.encode(value))
+    @contextlib.contextmanager
+    def _lock_signature(self, signature):
+        """
+        Opens a transaction and creates an advisory lock for its duration, based off a crc32 hash to convert
+        the signature into an integer which postgres' lock function requires
+        :param signature: A hex string representing an md5 hash of task metadata
+        """
+        with transaction.atomic():
+            # compute crc32 to turn signature into integer
+            key2 = zlib.crc32(signature.encode('utf-8'))
+            advisory_lock(TASK_LOCK, key2=key2)
+            yield
 
-        return task_qs.values_list("task_id", flat=True)
+    def find_ids(self, signature):
+        """
+        :param signature: A hex string representing an md5 hash of task metadata
+        :return: A TaskResult queryset
+        :rtype: django.db.models.query.QuerySet
+        """
+        return self.TaskModel.objects.filter(signature=signature)\
+            .values_list("task_id", flat=True)
 
-    def find_incomplete_ids(self, channel_id=None, **kwargs):
+    def find_incomplete_ids(self, signature):
        """
-        :param channel_id:
-        :param kwargs:
+        :param signature: A hex string representing an md5 hash of task metadata
         :return: A TaskResult queryset
         :rtype: django.db.models.query.QuerySet
         """
-        return self.find_ids(channel_id=channel_id, **kwargs).exclude(status__in=states.READY_STATES)
+        return self.find_ids(signature).filter(status__in=states.UNREADY_STATES)
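A quick sketch of what `generate_task_signature` buys (the task name and kwargs strings below are made up): identical name, serialized kwargs, and channel collapse to the same digest, so a repeat request is detectable by a single indexed equality lookup rather than the old `task_kwargs__contains` scans.

```python
import hashlib

def generate_task_signature(task_name, task_kwargs=None, channel_id=None):
    # same hashing scheme as the helper above
    md5 = hashlib.md5()
    md5.update(task_name.encode('utf-8'))
    md5.update((task_kwargs or '').encode('utf-8'))
    md5.update((channel_id or '').encode('utf-8'))
    return md5.hexdigest()

# hypothetical task name and serialized kwargs
a = generate_task_signature('export-channel', task_kwargs='{"notify": true}', channel_id='abc123')
b = generate_task_signature('export-channel', task_kwargs='{"notify": true}', channel_id='abc123')
c = generate_task_signature('export-channel', task_kwargs='{"notify": false}', channel_id='abc123')

assert a == b  # identical metadata -> identical signature -> duplicate task
assert a != c  # any kwarg difference -> different signature
```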
 
     def fetch(self, task_id):
         """
@@ -146,27 +190,6 @@ def fetch(self, task_id):
         """
         return self.AsyncResult(task_id)
 
-    def _fetch_match(self, task_id, **kwargs):
-        """
-        Gets the result object for a task, assuming it was called async, and ensures it was called with kwargs and
-        assumes that kwargs is has been decoded from an prepared form
-        :param task_id: The hex task ID
-        :param kwargs: The kwargs the task was called with, which must match when fetching
-        :return: A CeleryAsyncResult
-        :rtype: CeleryAsyncResult
-        """
-        async_result = self.fetch(task_id)
-        # the task kwargs are serialized in the DB so just ensure that args actually match
-        if async_result.kwargs == kwargs:
-            return async_result
-        return None
-
-    def _prepare_kwargs(self, kwargs):
-        return self.backend.encode({
-            key: value.hex if isinstance(value, uuid.UUID) else value
-            for key, value in kwargs.items()
-        })
-
     def enqueue(self, user, **kwargs):
         """
         Enqueues the task called with `kwargs`, and requires the user who wants to enqueue it. If `channel_id` is
@@ -182,17 +205,20 @@ def enqueue(self, user, **kwargs):
         if user is None or not isinstance(user, User):
             raise TypeError("All tasks must be assigned to a user.")
 
+        signature = kwargs.pop('signature', None)
+        if signature is None:
+            signature = self._generate_signature(kwargs)
+
         task_id = uuid.uuid4().hex
         prepared_kwargs = self._prepare_kwargs(kwargs)
-        transcoded_kwargs = self.backend.decode(prepared_kwargs)
-        channel_id = transcoded_kwargs.get("channel_id")
+        channel_id = prepared_kwargs.get("channel_id")
 
-        logging.info(f"Enqueuing task:id {self.name}:{task_id} for user:channel {user.pk}:{channel_id} | {prepared_kwargs}")
+        logging.info(f"Enqueuing task:id {self.name}:{task_id} for user:channel {user.pk}:{channel_id} | {signature}")
 
         # returns a CeleryAsyncResult
         async_result = self.apply_async(
             task_id=task_id,
-            kwargs=transcoded_kwargs,
+            kwargs=prepared_kwargs,
         )
 
         # ensure the result is saved to the backend (database)
@@ -201,9 +227,10 @@ def enqueue(self, user, **kwargs):
         # after calling apply, we should have task result model, so get it and set our custom fields
         task_result = get_task_model(self, task_id)
         task_result.task_name = self.name
-        task_result.task_kwargs = prepared_kwargs
+        task_result.task_kwargs = self.backend.encode(prepared_kwargs)
         task_result.user = user
         task_result.channel_id = channel_id
+        task_result.signature = signature
         task_result.save()
 
         return async_result
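Deduplication then hinges on serializing concurrent enqueuers per signature. A rough sketch of the key derivation, under the assumption (consistent with `execute_lock` above) that `advisory_lock` ends up calling a two-key `pg_advisory_xact_lock(int, int)`: the md5 hex is far too wide for Postgres' integer lock keys, so crc32 folds it into a 32-bit value paired with the `TASK_LOCK` namespace constant.

```python
import zlib

TASK_LOCK = 1002  # namespace constant from constants/locking.py

signature = '9f86d081884c7d659a2feaa0c55ad015'  # hypothetical md5 hex
key2 = zlib.crc32(signature.encode('utf-8'))    # deterministic unsigned 32-bit int

# crc32 can exceed the signed 32-bit limit, which is exactly the case
# _prepare_keys in advisory_lock.py folds into the negative range.
# Inside a transaction, (TASK_LOCK, key2) serializes competing enqueuers:
# whoever holds the lock checks for incomplete tasks with this signature
# and either returns the existing one or inserts a new TaskResult.
print(key2)  # same signature -> same key2 on every worker
```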
{kwargs}") - return self.enqueue(user, **kwargs) + logging.info(f"Didn't fetch matching task {self.name} for user {user.pk} | {signature}") + kwargs.update(signature=signature) + return self.enqueue(user, **kwargs) def requeue(self, **kwargs): """ @@ -243,8 +278,9 @@ def requeue(self, **kwargs): task_result = get_task_model(self, request.id) task_kwargs = request.kwargs.copy() task_kwargs.update(kwargs) - logging.info(f"Re-queuing task {self.name} for user {task_result.user.pk} from {request.id} | {task_kwargs}") - return self.enqueue(task_result.user, **task_kwargs) + signature = self._generate_signature(kwargs) + logging.info(f"Re-queuing task {self.name} for user {task_result.user.pk} from {request.id} | {signature}") + return self.enqueue(task_result.user, signature=signature, **task_kwargs) def revoke(self, exclude_task_ids=None, **kwargs): """ @@ -253,7 +289,9 @@ def revoke(self, exclude_task_ids=None, **kwargs): :param kwargs: Task keyword arguments that will be used to match against tasks :return: The number of tasks revoked """ - task_ids = self.find_incomplete_ids(**self.backend.decode(self._prepare_kwargs(kwargs))) + signature = self._generate_signature(kwargs) + task_ids = self.find_incomplete_ids(signature) + if exclude_task_ids is not None: task_ids = task_ids.exclude(task_id__in=task_ids) count = 0 diff --git a/contentcuration/contentcuration/utils/publish.py b/contentcuration/contentcuration/utils/publish.py index 93bee0bf40..2918628ebd 100644 --- a/contentcuration/contentcuration/utils/publish.py +++ b/contentcuration/contentcuration/utils/publish.py @@ -869,10 +869,7 @@ def sync_contentnode_and_channel_tsvectors(channel_id): # Insert newly created nodes. # "set_contentnode_tsvectors" command is defined in "search/management/commands" directory. 
diff --git a/contentcuration/contentcuration/utils/publish.py b/contentcuration/contentcuration/utils/publish.py
index 93bee0bf40..2918628ebd 100644
--- a/contentcuration/contentcuration/utils/publish.py
+++ b/contentcuration/contentcuration/utils/publish.py
@@ -869,10 +869,7 @@ def sync_contentnode_and_channel_tsvectors(channel_id):
     # Insert newly created nodes.
     # "set_contentnode_tsvectors" command is defined in "search/management/commands" directory.
-    call_command("set_contentnode_tsvectors",
-                 "--channel-id={}".format(channel_id),
-                 "--tree-id={}".format(channel["main_tree__tree_id"]),
-                 "--complete")
+    call_command("set_contentnode_tsvectors", "--channel-id={}".format(channel_id))
 
 
 @delay_user_storage_calculation
diff --git a/contentcuration/locale/es_ES/LC_MESSAGES/contentcuration-messages.json b/contentcuration/locale/es_ES/LC_MESSAGES/contentcuration-messages.json
index c24f068db5..fc815840b6 100644
--- a/contentcuration/locale/es_ES/LC_MESSAGES/contentcuration-messages.json
+++ b/contentcuration/locale/es_ES/LC_MESSAGES/contentcuration-messages.json
@@ -484,7 +484,7 @@
     "CommonMetadataStrings.readReference": "Referencia",
     "CommonMetadataStrings.readingAndWriting": "Lectura y escritura",
     "CommonMetadataStrings.readingComprehension": "Comprensión lectora",
-    "CommonMetadataStrings.reflect": "Reflejar",
+    "CommonMetadataStrings.reflect": "Reflexionar",
     "CommonMetadataStrings.school": "Escuela",
     "CommonMetadataStrings.sciences": "Ciencias",
     "CommonMetadataStrings.shortActivity": "Actividad corta",
diff --git a/contentcuration/search/management/commands/set_contentnode_tsvectors.py b/contentcuration/search/management/commands/set_contentnode_tsvectors.py
index c5e78ca8d1..067a956f62 100644
--- a/contentcuration/search/management/commands/set_contentnode_tsvectors.py
+++ b/contentcuration/search/management/commands/set_contentnode_tsvectors.py
@@ -10,9 +10,11 @@
 from search.models import ContentNodeFullTextSearch
 from search.utils import get_fts_annotated_contentnode_qs
 
+from contentcuration.models import Channel
+
 logmodule.basicConfig(level=logmodule.INFO)
-logging = logmodule.getLogger("command")
+logging = logmodule.getLogger(__name__)
 
 CHUNKSIZE = 10000
 
@@ -20,55 +22,53 @@ class Command(BaseCommand):
 
     def add_arguments(self, parser):
         parser.add_argument("--channel-id", type=str, dest="channel_id",
-                            help="The channel_id to annotate to the nodes. If not specified then each node's channel_id is queried and then annotated.")
-        parser.add_argument("--tree-id", type=int, dest="tree_id",
-                            help="Set tsvectors for a specific tree_id nodes only. If not specified then tsvectors for all nodes of ContentNode table are set.")
-        parser.add_argument("--published", dest="published", action="store_true", help="Filters on whether node is published or not.")
-        parser.add_argument("--complete", dest="complete", action="store_true", help="Filters on whether node is complete or not.")
-
-    def get_tsvector_nodes_queryset(self, *args, **options):
-        tsvector_nodes_queryset = get_fts_annotated_contentnode_qs(channel_id=options["channel_id"])
+                            help="The channel_id for which tsvectors need to be generated.\
+                            If not specified then tsvectors is generated for all published channels.")
+        parser.add_argument("--published", dest="published", action="store_true",
+                            help="Filters on whether channel's contentnodes are published or not.")
 
-        if options["tree_id"]:
-            tsvector_nodes_queryset = tsvector_nodes_queryset.filter(tree_id=options["tree_id"])
+    def handle(self, *args, **options):
+        start = time.time()
 
-        if options["complete"]:
-            tsvector_nodes_queryset = tsvector_nodes_queryset.filter(complete=True)
+        if options["channel_id"]:
+            generate_tsv_for_channels = list(Channel.objects.filter(id=options["channel_id"]).values("id", "main_tree__tree_id"))
+        else:
+            generate_tsv_for_channels = list(Channel.objects.filter(main_tree__published=True, deleted=False).values("id", "main_tree__tree_id"))
 
         if options["published"]:
-            tsvector_nodes_queryset = tsvector_nodes_queryset.filter(published=True)
-
-        tsvector_not_already_inserted_query = ~Exists(ContentNodeFullTextSearch.objects.filter(contentnode_id=OuterRef("id")))
-        tsvector_nodes_queryset = (tsvector_nodes_queryset
-                                   .filter(tsvector_not_already_inserted_query, channel_id__isnull=False)
-                                   .values("id", "channel_id", "keywords_tsvector", "author_tsvector").order_by())
+            publish_filter_dict = dict(published=True)
+        else:
+            publish_filter_dict = dict()
 
-        return tsvector_nodes_queryset
+        total_tsvectors_inserted = 0
 
-    def handle(self, *args, **options):
-        start = time.time()
+        for channel in generate_tsv_for_channels:
+            tsvector_not_already_inserted_query = ~Exists(ContentNodeFullTextSearch.objects.filter(contentnode_id=OuterRef("id")))
+            tsvector_nodes_query = (get_fts_annotated_contentnode_qs(channel["id"])
+                                    .filter(tsvector_not_already_inserted_query, tree_id=channel["main_tree__tree_id"], complete=True, **publish_filter_dict)
+                                    .values("id", "channel_id", "keywords_tsvector", "author_tsvector")
+                                    .order_by())
 
-        tsvector_nodes_queryset = self.get_tsvector_nodes_queryset(*args, **options)
+            insertable_nodes_tsvector = list(tsvector_nodes_query[:CHUNKSIZE])
 
-        insertable_nodes_tsvector = list(tsvector_nodes_queryset[:CHUNKSIZE])
-        total_tsvectors_inserted = 0
+            logging.info("Inserting contentnode tsvectors of channel {}.".format(channel["id"]))
 
-        while insertable_nodes_tsvector:
-            logging.info("Inserting contentnode tsvectors.")
+            while insertable_nodes_tsvector:
+                insert_objs = list()
+                for node in insertable_nodes_tsvector:
+                    obj = ContentNodeFullTextSearch(contentnode_id=node["id"], channel_id=node["channel_id"],
+                                                    keywords_tsvector=node["keywords_tsvector"], author_tsvector=node["author_tsvector"])
+                    insert_objs.append(obj)
 
-            insert_objs = list()
-            for node in insertable_nodes_tsvector:
-                obj = ContentNodeFullTextSearch(contentnode_id=node["id"], channel_id=node["channel_id"],
-                                                keywords_tsvector=node["keywords_tsvector"], author_tsvector=node["author_tsvector"])
-                insert_objs.append(obj)
+                inserted_objs_list = ContentNodeFullTextSearch.objects.bulk_create(insert_objs)
-            inserted_objs_list = ContentNodeFullTextSearch.objects.bulk_create(insert_objs)
+                current_inserts_count = len(inserted_objs_list)
+                total_tsvectors_inserted = total_tsvectors_inserted + current_inserts_count
 
-            current_inserts_count = len(inserted_objs_list)
-            total_tsvectors_inserted = total_tsvectors_inserted + current_inserts_count
+                logging.info("Inserted {} contentnode tsvectors of channel {}.".format(current_inserts_count, channel["id"]))
 
-            logging.info("Inserted {} contentnode tsvectors.".format(current_inserts_count))
+                insertable_nodes_tsvector = list(tsvector_nodes_query[:CHUNKSIZE])
 
-            insertable_nodes_tsvector = list(tsvector_nodes_queryset[:CHUNKSIZE])
+            logging.info("Insertion complete for channel {}.".format(channel["id"]))
 
         logging.info("Completed! Successfully inserted total of {} contentnode tsvectors in {} seconds.".format(total_tsvectors_inserted, time.time() - start))
diff --git a/contentcuration/search/utils.py b/contentcuration/search/utils.py
index 4f6768f650..8519fde49d 100644
--- a/contentcuration/search/utils.py
+++ b/contentcuration/search/utils.py
@@ -14,28 +14,19 @@ def get_fts_search_query(value):
     return SearchQuery(value=value, config=POSTGRES_FTS_CONFIG)
 
 
-def get_fts_annotated_contentnode_qs(channel_id=None):
+def get_fts_annotated_contentnode_qs(channel_id):
     """
     Returns a `ContentNode` queryset annotated with fields required for full text search.
-
-    If `channel_id` is provided, annotates that specific `channel_id` else annotates
-    the `channel_id` to which the contentnode belongs.
     """
     from contentcuration.models import ContentNode
 
-    if channel_id:
-        queryset = ContentNode.objects.annotate(channel_id=Value(channel_id))
-    else:
-        queryset = ContentNode._annotate_channel_id(ContentNode.objects)
-
-    queryset = queryset.annotate(
+    return ContentNode.objects.annotate(
+        channel_id=Value(channel_id),
         contentnode_tags=StringAgg("tags__tag_name", delimiter=" "),
         keywords_tsvector=CONTENTNODE_KEYWORDS_TSVECTOR,
        author_tsvector=CONTENTNODE_AUTHOR_TSVECTOR
    )
 
-    return queryset
-
 
 def get_fts_annotated_channel_qs():
     """
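The rewritten command's loop is a standard "query a slice until it comes back empty" pattern: because each `bulk_create` removes rows from the `~Exists(...)` filter, re-evaluating the same sliced queryset naturally pages through the remaining work without offset bookkeeping. A generic sketch of the idea (`pending_qs` and `insert_batch` are hypothetical stand-ins):

```python
# Hypothetical skeleton of the drain-by-slice loop used above.
CHUNKSIZE = 10000

def drain(pending_qs, insert_batch):
    """
    pending_qs: a queryset whose filter excludes already-inserted rows,
                so it shrinks as insert_batch() writes rows.
    insert_batch: a callable persisting a list of rows, e.g. bulk_create.
    """
    total = 0
    batch = list(pending_qs[:CHUNKSIZE])
    while batch:
        insert_batch(batch)
        total += len(batch)
        # re-evaluate the slice: previously inserted rows no longer match
        batch = list(pending_qs[:CHUNKSIZE])
    return total
```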