Skip to content

Commit

Permalink
Add lazy loading for builtin plugins (cvat-ai#306)
Browse files Browse the repository at this point in the history
* Refactor env code

* Load builtin plugins lazily

* update changelog
  • Loading branch information
Maxim Zhiltsov authored Jun 20, 2021
1 parent 935a502 commit 9a2f3f2
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 96 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Incorrect image layout on saving and a problem with encoding on loading (<https://github.com/openvinotoolkit/datumaro/pull/284>)
- An error when xpath filter is applied to the dataset or its subset (<https://github.com/openvinotoolkit/datumaro/issues/259>)
- Improved CLI startup time in several cases (<https://github.com/openvinotoolkit/datumaro/pull/306>)

### Security
-
Expand Down
170 changes: 83 additions & 87 deletions datumaro/components/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,18 @@
import logging as log
import os
import os.path as osp
from typing import Dict, Iterable

from datumaro.components.config import Config
from datumaro.components.config_model import Model, Source
from datumaro.util.os_util import import_foreign_module


class Registry:
def __init__(self, config=None, item_type=None):
self.item_type = item_type

def __init__(self):
self.items = {}

if config is not None:
self.load(config)

def load(self, config):
pass

def register(self, name, value):
if self.item_type:
value = self.item_type(value)
self.items[name] = value
return value

Expand All @@ -46,44 +37,34 @@ def __getitem__(self, key):
def __contains__(self, key):
return key in self.items

def __iter__(self):
return iter(self.items)

class ModelRegistry(Registry):
def __init__(self, config=None):
super().__init__(config, item_type=Model)

def load(self, config):
# TODO: list default dir, insert values
if 'models' in config:
for name, model in config.models.items():
self.register(name, model)
def batch_register(self, items: Dict[str, Model]):
    """Register every entry of a name-to-model mapping.

    Entries are passed to ``self.register`` one by one, in the order
    produced by ``items.items()``.
    """
    named_models = items.items()
    for key, entry in named_models:
        self.register(key, entry)


class SourceRegistry(Registry):
def __init__(self, config=None):
super().__init__(config, item_type=Source)

def load(self, config):
# TODO: list default dir, insert values
if 'sources' in config:
for name, source in config.sources.items():
self.register(name, source)

def batch_register(self, items: Dict[str, Source]):
    """Register every entry of a name-to-source mapping.

    Entries are passed to ``self.register`` one by one, in the order
    produced by ``items.items()``.
    """
    named_sources = items.items()
    for key, entry in named_sources:
        self.register(key, entry)

class PluginRegistry(Registry):
def __init__(self, config=None, builtin=None, local=None):
super().__init__(config)
def __init__(self, filter=None): #pylint: disable=redefined-builtin
super().__init__()
self.filter = filter

def batch_register(self, values: Iterable):
from datumaro.components.cli_plugin import CliPlugin

if builtin is not None:
for v in builtin:
k = CliPlugin._get_name(v)
self.register(k, v)
if local is not None:
for v in local:
k = CliPlugin._get_name(v)
self.register(k, v)
for v in values:
if self.filter and not self.filter(v):
continue
name = CliPlugin._get_name(v)

self.register(name, v)

class GitWrapper:
def __init__(self, config=None):
Expand Down Expand Up @@ -128,55 +109,60 @@ def remove_submodule(self, name, **kwargs):

class Environment:
_builtin_plugins = None
PROJECT_EXTRACTOR_NAME = 'datumaro_project'

def __init__(self, config=None):
from datumaro.components.project import (
PROJECT_DEFAULT_CONFIG, PROJECT_SCHEMA, load_project_as_dataset)
PROJECT_DEFAULT_CONFIG, PROJECT_SCHEMA)
config = Config(config,
fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA)

self.models = ModelRegistry(config)
self.sources = SourceRegistry(config)
self.models = ModelRegistry()
self.sources = SourceRegistry()

self.git = GitWrapper(config)

env_dir = osp.join(config.project_dir, config.env_dir)
builtin = self._load_builtin_plugins()
custom = self._load_plugins2(osp.join(env_dir, config.plugins_dir))
select = lambda seq, t: [e for e in seq if issubclass(e, t)]
def _filter(accept, skip=None):
accept = (accept, ) if inspect.isclass(accept) else tuple(accept)
skip = {skip} if inspect.isclass(skip) else set(skip or [])
skip = tuple(skip | set(accept))
return lambda t: issubclass(t, accept) and t not in skip

from datumaro.components.converter import Converter
from datumaro.components.extractor import (Importer, Extractor,
Transform)
SourceExtractor, Transform)
from datumaro.components.launcher import Launcher
from datumaro.components.validator import Validator
self.extractors = PluginRegistry(
builtin=select(builtin, Extractor),
local=select(custom, Extractor)
)
self.extractors.register(self.PROJECT_EXTRACTOR_NAME,
load_project_as_dataset)

self.importers = PluginRegistry(
builtin=select(builtin, Importer),
local=select(custom, Importer)
)
self.launchers = PluginRegistry(
builtin=select(builtin, Launcher),
local=select(custom, Launcher)
)
self.converters = PluginRegistry(
builtin=select(builtin, Converter),
local=select(custom, Converter)
)
self.transforms = PluginRegistry(
builtin=select(builtin, Transform),
local=select(custom, Transform)
)
self.validators = PluginRegistry(
builtin=select(builtin, Validator),
local=select(custom, Validator)
)
self._extractors = PluginRegistry(_filter(Extractor, SourceExtractor))
self._importers = PluginRegistry(_filter(Importer))
self._launchers = PluginRegistry(_filter(Launcher))
self._converters = PluginRegistry(_filter(Converter))
self._transforms = PluginRegistry(_filter(Transform))
self._builtins_initialized = False

def _get_plugin_registry(self, name):
if not self._builtins_initialized:
self._builtins_initialized = True
self._register_builtin_plugins()
return getattr(self, name)

@property
def extractors(self) -> PluginRegistry:
    """Registry kept in ``_extractors``, fetched via ``_get_plugin_registry``."""
    registry = self._get_plugin_registry('_extractors')
    return registry

@property
def importers(self) -> PluginRegistry:
    """Registry kept in ``_importers``, fetched via ``_get_plugin_registry``."""
    registry = self._get_plugin_registry('_importers')
    return registry

@property
def launchers(self) -> PluginRegistry:
    """Registry kept in ``_launchers``, fetched via ``_get_plugin_registry``."""
    registry = self._get_plugin_registry('_launchers')
    return registry

@property
def converters(self) -> PluginRegistry:
    """Registry kept in ``_converters``, fetched via ``_get_plugin_registry``."""
    registry = self._get_plugin_registry('_converters')
    return registry

@property
def transforms(self) -> PluginRegistry:
    """Registry kept in ``_transforms``, fetched via ``_get_plugin_registry``."""
    registry = self._get_plugin_registry('_transforms')
    return registry

@staticmethod
def _find_plugins(plugins_dir):
Expand Down Expand Up @@ -216,7 +202,14 @@ def _import_module(cls, module_dir, module_name, types, package=None):
return exports

@classmethod
def _load_plugins(cls, plugins_dir, types):
def _load_plugins(cls, plugins_dir, types=None):
if not types:
from datumaro.components.converter import Converter
from datumaro.components.extractor import (Extractor, Importer,
Transform)
from datumaro.components.launcher import Launcher
types = [Extractor, Converter, Importer, Launcher, Transform]

types = tuple(types)

plugins = cls._find_plugins(plugins_dir)
Expand Down Expand Up @@ -252,25 +245,28 @@ def _load_plugins(cls, plugins_dir, types):

@classmethod
def _load_builtin_plugins(cls):
if not cls._builtin_plugins:
if cls._builtin_plugins is None:
plugins_dir = osp.join(
__file__[: __file__.rfind(osp.join('datumaro', 'components'))],
osp.join('datumaro', 'plugins')
)
assert osp.isdir(plugins_dir), plugins_dir
cls._builtin_plugins = cls._load_plugins2(plugins_dir)
cls._builtin_plugins = cls._load_plugins(plugins_dir)
return cls._builtin_plugins

@classmethod
def _load_plugins2(cls, plugins_dir):
from datumaro.components.converter import Converter
from datumaro.components.extractor import (Extractor, Importer,
Transform)
from datumaro.components.launcher import Launcher
from datumaro.components.validator import Validator
types = [Extractor, Converter, Importer, Launcher, Transform, Validator]
def load_plugins(self, plugins_dir):
    """Load plugins from ``plugins_dir`` and register them.

    The result of ``self._load_plugins(plugins_dir)`` is handed directly
    to ``self._register_plugins``.
    """
    self._register_plugins(self._load_plugins(plugins_dir))

def _register_builtin_plugins(self):
self._register_plugins(self._load_builtin_plugins())

return cls._load_plugins(plugins_dir, types)
def _register_plugins(self, plugins):
self.extractors.batch_register(plugins)
self.importers.batch_register(plugins)
self.launchers.batch_register(plugins)
self.converters.batch_register(plugins)
self.transforms.batch_register(plugins)

def make_extractor(self, name, *args, **kwargs):
    """Instantiate the extractor registered under ``name``.

    The entry is looked up with ``self.extractors.get(name)`` and called
    with the remaining positional and keyword arguments; its result is
    returned.
    """
    extractor_factory = self.extractors.get(name)
    return extractor_factory(*args, **kwargs)
Expand Down
20 changes: 11 additions & 9 deletions datumaro/components/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,17 @@ def __init__(self, project):

sources = {}
for s_name, source in config.sources.items():
s_format = source.format or env.PROJECT_EXTRACTOR_NAME
s_format = source.format

url = source.url
if not source.url:
url = osp.join(config.project_dir, config.sources_dir, s_name)
sources[s_name] = Dataset.import_from(url,
format=s_format, env=env, **source.options)

if s_format:
sources[s_name] = Dataset.import_from(url,
format=s_format, env=env, **source.options)
else:
sources[s_name] = Project.load(url).make_dataset()
self._sources = sources

own_source = None
Expand Down Expand Up @@ -91,8 +95,7 @@ def __init__(self, project):
item = ExactMerge.merge_items(existing_item, item, path=path)
else:
s_config = config.sources[source_name]
if s_config and \
s_config.format != env.PROJECT_EXTRACTOR_NAME:
if s_config and s_config.format:
# NOTE: consider imported sources as our own dataset
path = None
else:
Expand Down Expand Up @@ -150,7 +153,7 @@ def put(self, item, id=None, subset=None, path=None):
if path:
source = path[0]
# TODO: reverse remapping
self._sources[source].put(item, id=id, subset=subset)
self._sources[source].put(item, id=id, subset=subset, path=path[1:])

if id is None:
id = item.id
Expand Down Expand Up @@ -412,6 +415,8 @@ def __init__(self, config=None, env=None):
fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA)
if env is None:
env = Environment(self.config)
env.models.batch_register(self.config.models)
env.sources.batch_register(self.config.sources)
elif config is not None:
raise ValueError("env can only be provided when no config provided")
self.env = env
Expand Down Expand Up @@ -482,6 +487,3 @@ def local_model_dir(self, model_name):

def local_source_dir(self, source_name):
return osp.join(self.config.sources_dir, source_name)

def load_project_as_dataset(url):
return Project.load(url).make_dataset()

0 comments on commit 9a2f3f2

Please sign in to comment.