From 3b30a32de3224099b9be7d8f5f4b9680fbcff3fc Mon Sep 17 00:00:00 2001
From: Marcel Klehr
Date: Wed, 21 Sep 2022 14:07:33 +0200
Subject: [PATCH] Allow scheduling specific classifier/crawl jobs per model

Signed-off-by: Marcel Klehr
---
Notes:
  The 'models' job argument is a list of classifier MODEL_NAME strings.
  SchedulerJob forwards it to every StorageCrawlJob it schedules, and
  StorageCrawlJob uses it both to pick the mime types to crawl and to decide
  which queues a file is inserted into, so every membership test on $models
  compares against ::MODEL_NAME.

 lib/BackgroundJobs/SchedulerJob.php    | 14 +++++++
 lib/BackgroundJobs/StorageCrawlJob.php | 51 +++++++++++++++++++++++---
 lib/Controller/AdminController.php     | 27 ++++++++++++++
 3 files changed, 87 insertions(+), 5 deletions(-)

diff --git a/lib/BackgroundJobs/SchedulerJob.php b/lib/BackgroundJobs/SchedulerJob.php
index 32ddce27..83015792 100644
--- a/lib/BackgroundJobs/SchedulerJob.php
+++ b/lib/BackgroundJobs/SchedulerJob.php
@@ -9,6 +9,11 @@
 
 use OC\Files\Cache\CacheQueryBuilder;
 use OC\SystemConfig;
+use OCA\Recognize\Classifiers\Audio\MusicnnClassifier;
+use OCA\Recognize\Classifiers\Images\ClusteringFaceClassifier;
+use OCA\Recognize\Classifiers\Images\ImagenetClassifier;
+use OCA\Recognize\Classifiers\Images\LandmarksClassifier;
+use OCA\Recognize\Classifiers\Video\MovinetClassifier;
 use OCA\Recognize\Service\Logger;
 use OCP\AppFramework\Utility\ITimeFactory;
 use OCP\BackgroundJob\IJobList;
@@ -49,6 +54,14 @@ public function __construct(ITimeFactory $timeFactory, Logger $logger, IDBConnec
 	}
 
 	protected function run($argument): void {
+		$models = $argument['models'] ?? [
+			ClusteringFaceClassifier::MODEL_NAME,
+			ImagenetClassifier::MODEL_NAME,
+			LandmarksClassifier::MODEL_NAME,
+			MovinetClassifier::MODEL_NAME,
+			MusicnnClassifier::MODEL_NAME,
+		];
+
 		$qb = $this->db->getQueryBuilder();
 		$qb->select('root_id', 'storage_id', 'mount_provider_class')
 			->from('mounts')
@@ -80,6 +93,7 @@ protected function run($argument): void {
 				'root_id' => $rootId,
 				'override_root' => $overrideRoot,
 				'last_file_id' => 0,
+				'models' => $models,
 			]);
 		}
 
diff --git a/lib/BackgroundJobs/StorageCrawlJob.php b/lib/BackgroundJobs/StorageCrawlJob.php
index 7f2481e4..215df39a 100644
--- a/lib/BackgroundJobs/StorageCrawlJob.php
+++ b/lib/BackgroundJobs/StorageCrawlJob.php
@@ -12,11 +12,13 @@
 use OCA\Recognize\Classifiers\Audio\MusicnnClassifier;
 use OCA\Recognize\Classifiers\Images\ClusteringFaceClassifier;
 use OCA\Recognize\Classifiers\Images\ImagenetClassifier;
+use OCA\Recognize\Classifiers\Images\LandmarksClassifier;
 use OCA\Recognize\Classifiers\Video\MovinetClassifier;
 use OCA\Recognize\Constants;
 use OCA\Recognize\Db\QueueFile;
 use OCA\Recognize\Service\Logger;
 use OCA\Recognize\Service\QueueService;
+use OCA\Recognize\Service\TagManager;
 use OCP\AppFramework\Utility\ITimeFactory;
 use OCP\BackgroundJob\IJobList;
 use OCP\BackgroundJob\QueuedJob;
@@ -33,8 +35,9 @@ class StorageCrawlJob extends QueuedJob {
 	private IJobList $jobList;
 	private IDBConnection $db;
 	private SystemConfig $systemConfig;
+	private TagManager $tagManager;
 
-	public function __construct(ITimeFactory $timeFactory, Logger $logger, IMimeTypeLoader $mimeTypes, QueueService $queue, IJobList $jobList, IDBConnection $db, SystemConfig $systemConfig) {
+	public function __construct(ITimeFactory $timeFactory, Logger $logger, IMimeTypeLoader $mimeTypes, QueueService $queue, IJobList $jobList, IDBConnection $db, SystemConfig $systemConfig, TagManager $tagManager) {
 		parent::__construct($timeFactory);
 		$this->logger = $logger;
 		$this->mimeTypes = $mimeTypes;
@@ -42,6 +45,7 @@ public function __construct(ITimeFactory $timeFactory, Logger $logger, IMimeType
 		$this->jobList = $jobList;
 		$this->db = $db;
 		$this->systemConfig = $systemConfig;
+		$this->tagManager = $tagManager;
 	}
 
 	protected function run($argument): void {
@@ -49,6 +53,14 @@ protected function run($argument): void {
 		$rootId = $argument['root_id'];
 		$overrideRoot = $argument['override_root'];
 		$lastFileId = $argument['last_file_id'];
+		$models = $argument['models'] ?? [
+			ClusteringFaceClassifier::MODEL_NAME,
+			ImagenetClassifier::MODEL_NAME,
+			LandmarksClassifier::MODEL_NAME,
+			MovinetClassifier::MODEL_NAME,
+			MusicnnClassifier::MODEL_NAME,
+		];
+
 		$qb = new CacheQueryBuilder($this->db, $this->systemConfig, $this->logger);
 		try {
 			$root = $qb->selectFileCache()
@@ -63,13 +75,26 @@ protected function run($argument): void {
 		$videoTypes = array_map(fn ($mimeType) => $this->mimeTypes->getId($mimeType), Constants::VIDEO_FORMATS);
 		$audioTypes = array_map(fn ($mimeType) => $this->mimeTypes->getId($mimeType), Constants::AUDIO_FORMATS);
 
+		$mimeTypes = [];
+		if (in_array(ClusteringFaceClassifier::MODEL_NAME, $models) ||
+			in_array(ImagenetClassifier::MODEL_NAME, $models) ||
+			in_array(LandmarksClassifier::MODEL_NAME, $models)) {
+			$mimeTypes = array_merge($imageTypes, $mimeTypes);
+		}
+		if (in_array(MovinetClassifier::MODEL_NAME, $models)) {
+			$mimeTypes = array_merge($videoTypes, $mimeTypes);
+		}
+		if (in_array(MusicnnClassifier::MODEL_NAME, $models)) {
+			$mimeTypes = array_merge($audioTypes, $mimeTypes);
+		}
+
 		try {
 			$qb = new CacheQueryBuilder($this->db, $this->systemConfig, $this->logger);
 			$files = $qb->selectFileCache()
 				->whereStorageId($storageId)
 				->andWhere($qb->expr()->like('path', $qb->createNamedParameter($root['path'] . '/%')))
 				->andWhere($qb->expr()->eq('storage', $qb->createNamedParameter($storageId)))
-				->andWhere($qb->expr()->in('mimetype', $qb->createNamedParameter(array_merge($imageTypes, $videoTypes, $audioTypes), IQueryBuilder::PARAM_INT_ARRAY)))
+				->andWhere($qb->expr()->in('mimetype', $qb->createNamedParameter($mimeTypes, IQueryBuilder::PARAM_INT_ARRAY)))
 				->andWhere($qb->expr()->gt('filecache.fileid', $qb->createNamedParameter($lastFileId)))
 				->orderBy('filecache.fileid', 'ASC')
 				->setMaxResults(100)
@@ -92,8 +117,23 @@ protected function run($argument): void {
 			$queueFile->setUpdate(false);
 			try {
 				if (in_array($file['mimetype'], $imageTypes)) {
-					$this->queue->insertIntoQueue(ImagenetClassifier::MODEL_NAME, $queueFile);
-					$this->queue->insertIntoQueue(ClusteringFaceClassifier::MODEL_NAME, $queueFile);
+					if (in_array(ImagenetClassifier::MODEL_NAME, $models)) {
+						$this->queue->insertIntoQueue(ImagenetClassifier::MODEL_NAME, $queueFile);
+					}
+					if (!in_array(ImagenetClassifier::MODEL_NAME, $models) && in_array(LandmarksClassifier::MODEL_NAME, $models)) {
+						$tags = $this->tagManager->getTagsForFiles([$queueFile->getFileId()]);
+						/** @var \OCP\SystemTag\ISystemTag[] $fileTags */
+						$fileTags = $tags[$queueFile->getFileId()] ?? [];
+						$landmarkTags = array_filter($fileTags, function ($tag) {
+							return in_array($tag->getName(), LandmarksClassifier::PRECONDITION_TAGS);
+						});
+						if (count($landmarkTags) > 0) {
+							$this->queue->insertIntoQueue(LandmarksClassifier::MODEL_NAME, $queueFile);
+						}
+					}
+					if (in_array(ClusteringFaceClassifier::MODEL_NAME, $models)) {
+						$this->queue->insertIntoQueue(ClusteringFaceClassifier::MODEL_NAME, $queueFile);
+					}
 				}
 				if (in_array($file['mimetype'], $videoTypes)) {
 					$this->queue->insertIntoQueue(MovinetClassifier::MODEL_NAME, $queueFile);
@@ -113,7 +153,8 @@ protected function run($argument): void {
 				'storage_id' => $storageId,
 				'root_id' => $rootId,
 				'override_root' => $overrideRoot,
-				'last_file_id' => $queueFile->getFileId()
+				'last_file_id' => $queueFile->getFileId(),
+				'models' => $models,
 			]);
 		}
 	}
diff --git a/lib/Controller/AdminController.php b/lib/Controller/AdminController.php
index 1edf753b..461c1411 100644
--- a/lib/Controller/AdminController.php
+++ b/lib/Controller/AdminController.php
@@ -3,6 +3,11 @@
 namespace OCA\Recognize\Controller;
 
 use OCA\Recognize\BackgroundJobs\SchedulerJob;
+use OCA\Recognize\Classifiers\Audio\MusicnnClassifier;
+use OCA\Recognize\Classifiers\Images\ClusteringFaceClassifier;
+use OCA\Recognize\Classifiers\Images\ImagenetClassifier;
+use OCA\Recognize\Classifiers\Images\LandmarksClassifier;
+use OCA\Recognize\Classifiers\Video\MovinetClassifier;
 use OCA\Recognize\Service\QueueService;
 use OCA\Recognize\Service\TagManager;
 use OCP\AppFramework\Controller;
@@ -109,6 +114,28 @@ public function musl(): JSONResponse {
 	}
 
 	public function setSetting(string $setting, $value) {
+		if ($value === true && $this->config->getAppValue('recognize', $setting, 'false') === 'false') {
+			// Additional model enabled: Schedule new crawl run for the affected mime types
+			switch ($setting) {
+				case ClusteringFaceClassifier::MODEL_NAME . '.enabled':
+					$this->jobList->add(SchedulerJob::class, ['models' => [ClusteringFaceClassifier::MODEL_NAME]]);
+					break;
+				case ImagenetClassifier::MODEL_NAME . '.enabled':
+					$this->jobList->add(SchedulerJob::class, ['models' => [ImagenetClassifier::MODEL_NAME]]);
+					// no break
+				case LandmarksClassifier::MODEL_NAME . '.enabled':
+					$this->jobList->add(SchedulerJob::class, ['models' => [LandmarksClassifier::MODEL_NAME]]);
+					break;
+				case MovinetClassifier::MODEL_NAME . '.enabled':
+					$this->jobList->add(SchedulerJob::class, ['models' => [MovinetClassifier::MODEL_NAME]]);
+					break;
+				case MusicnnClassifier::MODEL_NAME . '.enabled':
+					$this->jobList->add(SchedulerJob::class, ['models' => [MusicnnClassifier::MODEL_NAME]]);
+					break;
+				default:
+					break;
+			}
+		}
 		$this->config->setAppValue('recognize', $setting, $value);
 	}
 