Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lazy loading for builtin plugins #306

Merged
merged 3 commits into from
Jun 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Incorrect image layout on saving and a problem with ecoding on loading (<https://github.com/openvinotoolkit/datumaro/pull/284>)
- An error when xpath fiter is applied to the dataset or its subset (<https://github.com/openvinotoolkit/datumaro/issues/259>)
- Improved CLI startup time in several cases (<https://github.com/openvinotoolkit/datumaro/pull/306>)

### Security
-
Expand Down
170 changes: 83 additions & 87 deletions datumaro/components/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,18 @@
import logging as log
import os
import os.path as osp
from typing import Dict, Iterable

from datumaro.components.config import Config
from datumaro.components.config_model import Model, Source
from datumaro.util.os_util import import_foreign_module


class Registry:
def __init__(self, config=None, item_type=None):
self.item_type = item_type

def __init__(self):
self.items = {}

if config is not None:
self.load(config)

def load(self, config):
pass

def register(self, name, value):
if self.item_type:
value = self.item_type(value)
self.items[name] = value
return value

Expand All @@ -46,44 +37,34 @@ def __getitem__(self, key):
def __contains__(self, key):
return key in self.items

def __iter__(self):
return iter(self.items)

class ModelRegistry(Registry):
def __init__(self, config=None):
super().__init__(config, item_type=Model)

def load(self, config):
# TODO: list default dir, insert values
if 'models' in config:
for name, model in config.models.items():
self.register(name, model)
def batch_register(self, items: Dict[str, Model]):
for name, model in items.items():
self.register(name, model)


class SourceRegistry(Registry):
def __init__(self, config=None):
super().__init__(config, item_type=Source)

def load(self, config):
# TODO: list default dir, insert values
if 'sources' in config:
for name, source in config.sources.items():
self.register(name, source)

def batch_register(self, items: Dict[str, Source]):
for name, source in items.items():
self.register(name, source)

class PluginRegistry(Registry):
def __init__(self, config=None, builtin=None, local=None):
super().__init__(config)
def __init__(self, filter=None): #pylint: disable=redefined-builtin
super().__init__()
self.filter = filter

def batch_register(self, values: Iterable):
from datumaro.components.cli_plugin import CliPlugin

if builtin is not None:
for v in builtin:
k = CliPlugin._get_name(v)
self.register(k, v)
if local is not None:
for v in local:
k = CliPlugin._get_name(v)
self.register(k, v)
for v in values:
if self.filter and not self.filter(v):
continue
name = CliPlugin._get_name(v)

self.register(name, v)

class GitWrapper:
def __init__(self, config=None):
Expand Down Expand Up @@ -128,55 +109,60 @@ def remove_submodule(self, name, **kwargs):

class Environment:
_builtin_plugins = None
PROJECT_EXTRACTOR_NAME = 'datumaro_project'

def __init__(self, config=None):
from datumaro.components.project import (
PROJECT_DEFAULT_CONFIG, PROJECT_SCHEMA, load_project_as_dataset)
PROJECT_DEFAULT_CONFIG, PROJECT_SCHEMA)
config = Config(config,
fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA)

self.models = ModelRegistry(config)
self.sources = SourceRegistry(config)
self.models = ModelRegistry()
self.sources = SourceRegistry()

self.git = GitWrapper(config)

env_dir = osp.join(config.project_dir, config.env_dir)
builtin = self._load_builtin_plugins()
custom = self._load_plugins2(osp.join(env_dir, config.plugins_dir))
select = lambda seq, t: [e for e in seq if issubclass(e, t)]
def _filter(accept, skip=None):
accept = (accept, ) if inspect.isclass(accept) else tuple(accept)
skip = {skip} if inspect.isclass(skip) else set(skip or [])
skip = tuple(skip | set(accept))
return lambda t: issubclass(t, accept) and t not in skip

from datumaro.components.converter import Converter
from datumaro.components.extractor import (Importer, Extractor,
Transform)
SourceExtractor, Transform)
from datumaro.components.launcher import Launcher
from datumaro.components.validator import Validator
self.extractors = PluginRegistry(
builtin=select(builtin, Extractor),
local=select(custom, Extractor)
)
self.extractors.register(self.PROJECT_EXTRACTOR_NAME,
load_project_as_dataset)

self.importers = PluginRegistry(
builtin=select(builtin, Importer),
local=select(custom, Importer)
)
self.launchers = PluginRegistry(
builtin=select(builtin, Launcher),
local=select(custom, Launcher)
)
self.converters = PluginRegistry(
builtin=select(builtin, Converter),
local=select(custom, Converter)
)
self.transforms = PluginRegistry(
builtin=select(builtin, Transform),
local=select(custom, Transform)
)
self.validators = PluginRegistry(
builtin=select(builtin, Validator),
local=select(custom, Validator)
)
self._extractors = PluginRegistry(_filter(Extractor, SourceExtractor))
self._importers = PluginRegistry(_filter(Importer))
self._launchers = PluginRegistry(_filter(Launcher))
self._converters = PluginRegistry(_filter(Converter))
self._transforms = PluginRegistry(_filter(Transform))
self._builtins_initialized = False

def _get_plugin_registry(self, name):
if not self._builtins_initialized:
self._builtins_initialized = True
self._register_builtin_plugins()
return getattr(self, name)

@property
def extractors(self) -> PluginRegistry:
return self._get_plugin_registry('_extractors')

@property
def importers(self) -> PluginRegistry:
return self._get_plugin_registry('_importers')

@property
def launchers(self) -> PluginRegistry:
return self._get_plugin_registry('_launchers')

@property
def converters(self) -> PluginRegistry:
return self._get_plugin_registry('_converters')

@property
def transforms(self) -> PluginRegistry:
return self._get_plugin_registry('_transforms')

@staticmethod
def _find_plugins(plugins_dir):
Expand Down Expand Up @@ -216,7 +202,14 @@ def _import_module(cls, module_dir, module_name, types, package=None):
return exports

@classmethod
def _load_plugins(cls, plugins_dir, types):
def _load_plugins(cls, plugins_dir, types=None):
if not types:
from datumaro.components.converter import Converter
from datumaro.components.extractor import (Extractor, Importer,
Transform)
from datumaro.components.launcher import Launcher
types = [Extractor, Converter, Importer, Launcher, Transform]

types = tuple(types)

plugins = cls._find_plugins(plugins_dir)
Expand Down Expand Up @@ -252,25 +245,28 @@ def _load_plugins(cls, plugins_dir, types):

@classmethod
def _load_builtin_plugins(cls):
if not cls._builtin_plugins:
if cls._builtin_plugins is None:
plugins_dir = osp.join(
__file__[: __file__.rfind(osp.join('datumaro', 'components'))],
osp.join('datumaro', 'plugins')
)
assert osp.isdir(plugins_dir), plugins_dir
cls._builtin_plugins = cls._load_plugins2(plugins_dir)
cls._builtin_plugins = cls._load_plugins(plugins_dir)
return cls._builtin_plugins

@classmethod
def _load_plugins2(cls, plugins_dir):
from datumaro.components.converter import Converter
from datumaro.components.extractor import (Extractor, Importer,
Transform)
from datumaro.components.launcher import Launcher
from datumaro.components.validator import Validator
types = [Extractor, Converter, Importer, Launcher, Transform, Validator]
def load_plugins(self, plugins_dir):
plugins = self._load_plugins(plugins_dir)
self._register_plugins(plugins)

def _register_builtin_plugins(self):
self._register_plugins(self._load_builtin_plugins())

return cls._load_plugins(plugins_dir, types)
def _register_plugins(self, plugins):
self.extractors.batch_register(plugins)
self.importers.batch_register(plugins)
self.launchers.batch_register(plugins)
self.converters.batch_register(plugins)
self.transforms.batch_register(plugins)

def make_extractor(self, name, *args, **kwargs):
return self.extractors.get(name)(*args, **kwargs)
Expand Down
20 changes: 11 additions & 9 deletions datumaro/components/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,17 @@ def __init__(self, project):

sources = {}
for s_name, source in config.sources.items():
s_format = source.format or env.PROJECT_EXTRACTOR_NAME
s_format = source.format

url = source.url
if not source.url:
url = osp.join(config.project_dir, config.sources_dir, s_name)
sources[s_name] = Dataset.import_from(url,
format=s_format, env=env, **source.options)

if s_format:
sources[s_name] = Dataset.import_from(url,
format=s_format, env=env, **source.options)
else:
sources[s_name] = Project.load(url).make_dataset()
self._sources = sources

own_source = None
Expand Down Expand Up @@ -91,8 +95,7 @@ def __init__(self, project):
item = ExactMerge.merge_items(existing_item, item, path=path)
else:
s_config = config.sources[source_name]
if s_config and \
s_config.format != env.PROJECT_EXTRACTOR_NAME:
if s_config and s_config.format:
# NOTE: consider imported sources as our own dataset
path = None
else:
Expand Down Expand Up @@ -152,7 +155,7 @@ def put(self, item, id=None, subset=None, \
if path:
source = path[0]
# TODO: reverse remapping
self._sources[source].put(item, id=id, subset=subset)
self._sources[source].put(item, id=id, subset=subset, path=path[1:])

if id is None:
id = item.id
Expand Down Expand Up @@ -415,6 +418,8 @@ def __init__(self, config=None, env=None):
fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA)
if env is None:
env = Environment(self.config)
env.models.batch_register(self.config.models)
env.sources.batch_register(self.config.sources)
elif config is not None:
raise ValueError("env can only be provided when no config provided")
self.env = env
Expand Down Expand Up @@ -485,6 +490,3 @@ def local_model_dir(self, model_name):

def local_source_dir(self, source_name):
return osp.join(self.config.sources_dir, source_name)

def load_project_as_dataset(url):
return Project.load(url).make_dataset()