Skip to content

Commit

Permalink
Allow scheduling specific classifier/crawl jobs per model
Browse files Browse the repository at this point in the history
Signed-off-by: Marcel Klehr <[email protected]>
  • Loading branch information
marcelklehr committed Sep 23, 2022
1 parent 7645628 commit 3b30a32
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 5 deletions.
14 changes: 14 additions & 0 deletions lib/BackgroundJobs/SchedulerJob.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@

use OC\Files\Cache\CacheQueryBuilder;
use OC\SystemConfig;
use OCA\Recognize\Classifiers\Audio\MusicnnClassifier;
use OCA\Recognize\Classifiers\Images\ClusteringFaceClassifier;
use OCA\Recognize\Classifiers\Images\ImagenetClassifier;
use OCA\Recognize\Classifiers\Images\LandmarksClassifier;
use OCA\Recognize\Classifiers\Video\MovinetClassifier;
use OCA\Recognize\Service\Logger;
use OCP\AppFramework\Utility\ITimeFactory;
use OCP\BackgroundJob\IJobList;
Expand Down Expand Up @@ -49,6 +54,14 @@ public function __construct(ITimeFactory $timeFactory, Logger $logger, IDBConnec
}

protected function run($argument): void {
$models = $argument['models'] ?? [
ClusteringFaceClassifier::MODEL_NAME,
ImagenetClassifier::MODEL_NAME,
LandmarksClassifier::MODEL_NAME,
MovinetClassifier::MODEL_NAME,
MusicnnClassifier::MODEL_NAME,
];

$qb = $this->db->getQueryBuilder();
$qb->select('root_id', 'storage_id', 'mount_provider_class')
->from('mounts')
Expand Down Expand Up @@ -80,6 +93,7 @@ protected function run($argument): void {
'root_id' => $rootId,
'override_root' => $overrideRoot,
'last_file_id' => 0,
'models' => $models,
]);
}

Expand Down
51 changes: 46 additions & 5 deletions lib/BackgroundJobs/StorageCrawlJob.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
use OCA\Recognize\Classifiers\Audio\MusicnnClassifier;
use OCA\Recognize\Classifiers\Images\ClusteringFaceClassifier;
use OCA\Recognize\Classifiers\Images\ImagenetClassifier;
use OCA\Recognize\Classifiers\Images\LandmarksClassifier;
use OCA\Recognize\Classifiers\Video\MovinetClassifier;
use OCA\Recognize\Constants;
use OCA\Recognize\Db\QueueFile;
use OCA\Recognize\Service\Logger;
use OCA\Recognize\Service\QueueService;
use OCA\Recognize\Service\TagManager;
use OCP\AppFramework\Utility\ITimeFactory;
use OCP\BackgroundJob\IJobList;
use OCP\BackgroundJob\QueuedJob;
Expand All @@ -33,22 +35,32 @@ class StorageCrawlJob extends QueuedJob {
private IJobList $jobList;
private IDBConnection $db;
private SystemConfig $systemConfig;
private TagManager $tagManager;

public function __construct(ITimeFactory $timeFactory, Logger $logger, IMimeTypeLoader $mimeTypes, QueueService $queue, IJobList $jobList, IDBConnection $db, SystemConfig $systemConfig) {
public function __construct(ITimeFactory $timeFactory, Logger $logger, IMimeTypeLoader $mimeTypes, QueueService $queue, IJobList $jobList, IDBConnection $db, SystemConfig $systemConfig, TagManager $tagManager) {
parent::__construct($timeFactory);
$this->logger = $logger;
$this->mimeTypes = $mimeTypes;
$this->queue = $queue;
$this->jobList = $jobList;
$this->db = $db;
$this->systemConfig = $systemConfig;
$this->tagManager = $tagManager;
}

protected function run($argument): void {
$storageId = $argument['storage_id'];
$rootId = $argument['root_id'];
$overrideRoot = $argument['override_root'];
$lastFileId = $argument['last_file_id'];
$models = $argument['models'] ?? [
ClusteringFaceClassifier::MODEL_NAME,
ImagenetClassifier::MODEL_NAME,
LandmarksClassifier::MODEL_NAME,
MovinetClassifier::MODEL_NAME,
MusicnnClassifier::MODEL_NAME,
];

$qb = new CacheQueryBuilder($this->db, $this->systemConfig, $this->logger);
try {
$root = $qb->selectFileCache()
Expand All @@ -63,13 +75,26 @@ protected function run($argument): void {
$videoTypes = array_map(fn ($mimeType) => $this->mimeTypes->getId($mimeType), Constants::VIDEO_FORMATS);
$audioTypes = array_map(fn ($mimeType) => $this->mimeTypes->getId($mimeType), Constants::AUDIO_FORMATS);

$mimeTypes = [];
if (in_array(ClusteringFaceClassifier::MODEL_NAME, $models) ||
in_array(ImagenetClassifier::MODEL_NAME, $models) ||
in_array(LandmarksClassifier::MODEL_NAME, $models)) {
$mimeTypes = array_merge($imageTypes, $mimeTypes);
}
if (in_array(MovinetClassifier::MODEL_NAME, $models)) {
$mimeTypes = array_merge($videoTypes, $mimeTypes);
}
if (in_array(MusicnnClassifier::MODEL_NAME, $models)) {
$mimeTypes = array_merge($audioTypes, $mimeTypes);
}

try {
$qb = new CacheQueryBuilder($this->db, $this->systemConfig, $this->logger);
$files = $qb->selectFileCache()
->whereStorageId($storageId)
->andWhere($qb->expr()->like('path', $qb->createNamedParameter($root['path'] . '/%')))
->andWhere($qb->expr()->eq('storage', $qb->createNamedParameter($storageId)))
->andWhere($qb->expr()->in('mimetype', $qb->createNamedParameter(array_merge($imageTypes, $videoTypes, $audioTypes), IQueryBuilder::PARAM_INT_ARRAY)))
->andWhere($qb->expr()->in('mimetype', $qb->createNamedParameter($mimeTypes, IQueryBuilder::PARAM_INT_ARRAY)))
->andWhere($qb->expr()->gt('filecache.fileid', $qb->createNamedParameter($lastFileId)))
->orderBy('filecache.fileid', 'ASC')
->setMaxResults(100)
Expand All @@ -92,8 +117,23 @@ protected function run($argument): void {
$queueFile->setUpdate(false);
try {
if (in_array($file['mimetype'], $imageTypes)) {
$this->queue->insertIntoQueue(ImagenetClassifier::MODEL_NAME, $queueFile);
$this->queue->insertIntoQueue(ClusteringFaceClassifier::MODEL_NAME, $queueFile);
if (in_array(ImagenetClassifier::class, $models)) {
$this->queue->insertIntoQueue(ImagenetClassifier::MODEL_NAME, $queueFile);
}
if (!in_array(ImagenetClassifier::class, $models) && in_array(LandmarksClassifier::class, $models)) {
$tags = $this->tagManager->getTagsForFiles([$queueFile->getFileId()]);
/** @var \OCP\SystemTag\ISystemTag[] $fileTags */
$fileTags = $tags[$queueFile->getFileId()];
$landmarkTags = array_filter($fileTags, function ($tag) {
return in_array($tag->getName(), LandmarksClassifier::PRECONDITION_TAGS);
});
if (count($landmarkTags) > 0) {
$this->queue->insertIntoQueue(LandmarksClassifier::MODEL_NAME, $queueFile);
}
}
if (in_array(ClusteringFaceClassifier::class, $models)) {
$this->queue->insertIntoQueue(ClusteringFaceClassifier::MODEL_NAME, $queueFile);
}
}
if (in_array($file['mimetype'], $videoTypes)) {
$this->queue->insertIntoQueue(MovinetClassifier::MODEL_NAME, $queueFile);
Expand All @@ -113,7 +153,8 @@ protected function run($argument): void {
'storage_id' => $storageId,
'root_id' => $rootId,
'override_root' => $overrideRoot,
'last_file_id' => $queueFile->getFileId()
'last_file_id' => $queueFile->getFileId(),
'models' => $models,
]);
}
}
Expand Down
27 changes: 27 additions & 0 deletions lib/Controller/AdminController.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
namespace OCA\Recognize\Controller;

use OCA\Recognize\BackgroundJobs\SchedulerJob;
use OCA\Recognize\Classifiers\Audio\MusicnnClassifier;
use OCA\Recognize\Classifiers\Images\ClusteringFaceClassifier;
use OCA\Recognize\Classifiers\Images\ImagenetClassifier;
use OCA\Recognize\Classifiers\Images\LandmarksClassifier;
use OCA\Recognize\Classifiers\Video\MovinetClassifier;
use OCA\Recognize\Service\QueueService;
use OCA\Recognize\Service\TagManager;
use OCP\AppFramework\Controller;
Expand Down Expand Up @@ -109,6 +114,28 @@ public function musl(): JSONResponse {
}

public function setSetting(string $setting, $value) {
if ($value === true && $this->config->getAppValue('recognize', $setting, 'false') === 'false') {
// Additional model enabled: Schedule new crawl run for the affected mime types
switch ($setting) {
case ClusteringFaceClassifier::MODEL_NAME . '.enabled':
$this->jobList->add(SchedulerJob::class, ['models' => [ClusteringFaceClassifier::MODEL_NAME]]);
break;
case ImagenetClassifier::MODEL_NAME . '.enabled':
$this->jobList->add(SchedulerJob::class, ['models' => [ImagenetClassifier::MODEL_NAME]]);
// no break
case LandmarksClassifier::MODEL_NAME . '.enabled':
$this->jobList->add(SchedulerJob::class, ['models' => [LandmarksClassifier::MODEL_NAME]]);
break;
case MovinetClassifier::MODEL_NAME . '.enabled':
$this->jobList->add(SchedulerJob::class, ['models' => [MovinetClassifier::MODEL_NAME]]);
break;
case MusicnnClassifier::MODEL_NAME . '.enabled':
$this->jobList->add(SchedulerJob::class, ['models' => [MusicnnClassifier::MODEL_NAME]]);
break;
default:
break;
}
}
$this->config->setAppValue('recognize', $setting, $value);
}

Expand Down

0 comments on commit 3b30a32

Please sign in to comment.