-
Notifications
You must be signed in to change notification settings - Fork 817
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor ingest CLI for better code reuse (#1846)
### Description Much of the current CLI code is copy-paste across subcommands. To alleviate this, most of the duplicate code was moved into base classes for src and destination connector commands. This also allows for code reuse when a destination command is called and it no longer has to jump through hoops to dynamically recreate what _would_ have been called by a source command. The reason everything can't live in a single BaseCmd class is due to the need for a dynamic map to the source command. This runs into a circular dependency issue if it was all in one class. By splitting it into a `BaseSrcCmd` and a `BaseDestCmd` class, this helps avoid that issue.
- Loading branch information
Showing
36 changed files
with
433 additions
and
1,056 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import typing as t | ||
|
||
import click | ||
|
||
from unstructured.ingest.cli.cmds import base_dest_cmd_fns, base_src_cmd_fns | ||
|
||
src: t.List[click.Group] = [v().get_src_cmd() for v in base_src_cmd_fns] | ||
|
||
dest: t.List[click.Command] = [v().get_dest_cmd() for v in base_dest_cmd_fns] | ||
|
||
__all__ = [ | ||
"src", | ||
"dest", | ||
] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import typing as t | ||
from abc import ABC | ||
from dataclasses import dataclass, field | ||
|
||
from unstructured.ingest.cli.interfaces import CliMixin | ||
from unstructured.ingest.interfaces import BaseConfig | ||
|
||
|
||
@dataclass | ||
class BaseCmd(ABC): | ||
cmd_name: str | ||
cli_config: t.Optional[t.Type[BaseConfig]] = None | ||
additional_cli_options: t.List[t.Type[CliMixin]] = field(default_factory=list) | ||
addition_configs: t.Dict[str, t.Type[BaseConfig]] = field(default_factory=dict) | ||
is_fsspec: bool = False | ||
|
||
@property | ||
def cmd_name_key(self): | ||
return self.cmd_name.replace("-", "_") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import logging | ||
from dataclasses import dataclass | ||
|
||
import click | ||
|
||
from unstructured.ingest.cli.base.cmd import BaseCmd | ||
from unstructured.ingest.cli.cmd_factory import get_src_cmd | ||
from unstructured.ingest.cli.common import ( | ||
log_options, | ||
) | ||
from unstructured.ingest.cli.interfaces import ( | ||
CliFilesStorageConfig, | ||
) | ||
from unstructured.ingest.cli.utils import add_options, conform_click_options | ||
from unstructured.ingest.logger import ingest_log_streaming_init, logger | ||
|
||
|
||
@dataclass | ||
class BaseDestCmd(BaseCmd): | ||
def get_dest_runner(self, source_cmd: str, options: dict, parent_options: dict): | ||
src_cmd_fn = get_src_cmd(cmd_name=source_cmd) | ||
src_cmd = src_cmd_fn() | ||
runner = src_cmd.get_source_runner(options=parent_options) | ||
runner.writer_type = self.cmd_name_key | ||
runner.writer_kwargs = options | ||
return runner | ||
|
||
def check_dest_options(self, options: dict): | ||
self.cli_config.from_dict(options) | ||
|
||
def dest(self, ctx: click.Context, **options): | ||
if not ctx.parent: | ||
raise click.ClickException("destination command called without a parent") | ||
if not ctx.parent.info_name: | ||
raise click.ClickException("parent command missing info name") | ||
source_cmd = ctx.parent.info_name.replace("-", "_") | ||
parent_options: dict = ctx.parent.params if ctx.parent else {} | ||
conform_click_options(options) | ||
verbose = parent_options.get("verbose", False) | ||
ingest_log_streaming_init(logging.DEBUG if verbose else logging.INFO) | ||
log_options(parent_options, verbose=verbose) | ||
log_options(options, verbose=verbose) | ||
try: | ||
self.check_dest_options(options=options) | ||
runner = self.get_dest_runner( | ||
source_cmd=source_cmd, | ||
options=options, | ||
parent_options=parent_options, | ||
) | ||
runner.run(**parent_options) | ||
except Exception as e: | ||
logger.error(e, exc_info=True) | ||
raise click.ClickException(str(e)) from e | ||
|
||
def get_dest_cmd(self) -> click.Command: | ||
# Dynamically create the command without the use of click decorators | ||
fn = self.dest | ||
fn = click.pass_context(fn) | ||
cmd: click.Group = click.command(fn) | ||
cmd.name = self.cmd_name | ||
cmd.invoke_without_command = True | ||
options = [self.cli_config] if self.cli_config else [] | ||
options += self.additional_cli_options | ||
if self.is_fsspec and CliFilesStorageConfig not in options: | ||
options.append(CliFilesStorageConfig) | ||
add_options(cmd, extras=options, is_src=False) | ||
return cmd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import logging | ||
from dataclasses import dataclass | ||
|
||
import click | ||
|
||
from unstructured.ingest.cli.base.cmd import BaseCmd | ||
from unstructured.ingest.cli.common import ( | ||
log_options, | ||
) | ||
from unstructured.ingest.cli.interfaces import CliFilesStorageConfig | ||
from unstructured.ingest.cli.utils import Group, add_options, conform_click_options, extract_configs | ||
from unstructured.ingest.interfaces import FsspecConfig | ||
from unstructured.ingest.logger import ingest_log_streaming_init, logger | ||
from unstructured.ingest.runner import runner_map | ||
|
||
|
||
@dataclass | ||
class BaseSrcCmd(BaseCmd): | ||
def get_source_runner(self, options: dict): | ||
addition_configs = self.addition_configs | ||
if self.is_fsspec and "fsspec_config" not in addition_configs: | ||
addition_configs["fsspec_config"] = FsspecConfig | ||
configs = extract_configs( | ||
options, | ||
validate=[self.cli_config] if self.cli_config else None, | ||
extras=addition_configs, | ||
) | ||
runner = runner_map[self.cmd_name_key] | ||
return runner(**configs) # type: ignore | ||
|
||
def src(self, ctx: click.Context, **options): | ||
if ctx.invoked_subcommand: | ||
return | ||
|
||
conform_click_options(options) | ||
verbose = options.get("verbose", False) | ||
ingest_log_streaming_init(logging.DEBUG if verbose else logging.INFO) | ||
log_options(options, verbose=verbose) | ||
try: | ||
runner = self.get_source_runner(options=options) | ||
runner.run(**options) | ||
except Exception as e: | ||
logger.error(e, exc_info=True) | ||
raise click.ClickException(str(e)) from e | ||
|
||
def get_src_cmd(self) -> click.Group: | ||
# Dynamically create the command without the use of click decorators | ||
fn = self.src | ||
fn = click.pass_context(fn) | ||
cmd: click.Group = click.group(fn, cls=Group) | ||
cmd.name = self.cmd_name | ||
cmd.invoke_without_command = True | ||
extra_options = [self.cli_config] if self.cli_config else [] | ||
extra_options += self.additional_cli_options | ||
if self.is_fsspec and CliFilesStorageConfig not in extra_options: | ||
extra_options.append(CliFilesStorageConfig) | ||
add_options(cmd, extras=extra_options) | ||
return cmd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import typing as t | ||
|
||
from unstructured.ingest.cli.base.src import BaseSrcCmd | ||
from unstructured.ingest.cli.cmds import base_src_cmd_fns | ||
|
||
|
||
def get_src_cmd_map() -> t.Dict[str, t.Callable[[], BaseSrcCmd]]: | ||
return {b().cmd_name_key: b for b in base_src_cmd_fns} | ||
|
||
|
||
def get_src_cmd(cmd_name: str) -> t.Callable[[], BaseSrcCmd]: | ||
return get_src_cmd_map()[cmd_name] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,67 +1,97 @@ | ||
from __future__ import annotations | ||
|
||
import collections | ||
import typing as t | ||
|
||
import click | ||
from unstructured.ingest.cli.base.src import BaseSrcCmd | ||
|
||
from .airtable import get_base_src_cmd as airtable_base_src_cmd | ||
from .azure import get_base_src_cmd as azure_base_src_cmd | ||
from .azure_cognitive_search import get_base_dest_cmd as azure_cognitive_search_base_dest_cmd | ||
from .biomed import get_base_src_cmd as biomed_base_src_cmd | ||
from .box import get_base_src_cmd as box_base_src_cmd | ||
from .confluence import get_base_src_cmd as confluence_base_src_cmd | ||
from .delta_table import get_base_dest_cmd as delta_table_dest_cmd | ||
from .delta_table import get_base_src_cmd as delta_table_base_src_cmd | ||
from .discord import get_base_src_cmd as discord_base_src_cmd | ||
from .dropbox import get_base_src_cmd as dropbox_base_src_cmd | ||
from .elasticsearch import get_base_src_cmd as elasticsearch_base_src_cmd | ||
from .fsspec import get_base_src_cmd as fsspec_base_src_cmd | ||
from .gcs import get_base_src_cmd as gcs_base_src_cmd | ||
from .github import get_base_src_cmd as github_base_src_cmd | ||
from .gitlab import get_base_src_cmd as gitlab_base_src_cmd | ||
from .google_drive import get_base_src_cmd as google_drive_base_src_cmd | ||
from .jira import get_base_src_cmd as jira_base_src_cmd | ||
from .local import get_base_src_cmd as local_base_src_cmd | ||
from .notion import get_base_src_cmd as notion_base_src_cmd | ||
from .onedrive import get_base_src_cmd as onedrive_base_src_cmd | ||
from .outlook import get_base_src_cmd as outlook_base_src_cmd | ||
from .reddit import get_base_src_cmd as reddit_base_src_cmd | ||
from .s3 import get_base_dest_cmd as s3_base_dest_cmd | ||
from .s3 import get_base_src_cmd as s3_base_src_cmd | ||
from .salesforce import get_base_src_cmd as salesforce_base_src_cmd | ||
from .sharepoint import get_base_src_cmd as sharepoint_base_src_cmd | ||
from .slack import get_base_src_cmd as slack_base_src_cmd | ||
from .wikipedia import get_base_src_cmd as wikipedia_base_src_cmd | ||
|
||
if t.TYPE_CHECKING: | ||
from unstructured.ingest.cli.base.dest import BaseDestCmd | ||
|
||
base_src_cmd_fns: t.List[t.Callable[[], BaseSrcCmd]] = [ | ||
airtable_base_src_cmd, | ||
azure_base_src_cmd, | ||
biomed_base_src_cmd, | ||
box_base_src_cmd, | ||
confluence_base_src_cmd, | ||
delta_table_base_src_cmd, | ||
discord_base_src_cmd, | ||
dropbox_base_src_cmd, | ||
elasticsearch_base_src_cmd, | ||
fsspec_base_src_cmd, | ||
gcs_base_src_cmd, | ||
github_base_src_cmd, | ||
gitlab_base_src_cmd, | ||
google_drive_base_src_cmd, | ||
jira_base_src_cmd, | ||
local_base_src_cmd, | ||
notion_base_src_cmd, | ||
onedrive_base_src_cmd, | ||
outlook_base_src_cmd, | ||
reddit_base_src_cmd, | ||
salesforce_base_src_cmd, | ||
sharepoint_base_src_cmd, | ||
slack_base_src_cmd, | ||
s3_base_src_cmd, | ||
wikipedia_base_src_cmd, | ||
] | ||
|
||
from .airtable import get_source_cmd as airtable_src | ||
from .azure import get_source_cmd as azure_src | ||
from .azure_cognitive_search import get_dest_cmd as azure_cognitive_search_dest | ||
from .biomed import get_source_cmd as biomed_src | ||
from .box import get_source_cmd as box_src | ||
from .confluence import get_source_cmd as confluence_src | ||
from .delta_table import get_dest_cmd as delta_table_dest | ||
from .delta_table import get_source_cmd as delta_table_src | ||
from .discord import get_source_cmd as discord_src | ||
from .dropbox import get_source_cmd as dropbox_src | ||
from .elasticsearch import get_source_cmd as elasticsearch_src | ||
from .fsspec import get_source_cmd as fsspec_src | ||
from .gcs import get_source_cmd as gcs_src | ||
from .github import get_source_cmd as github_src | ||
from .gitlab import get_source_cmd as gitlab_src | ||
from .google_drive import get_source_cmd as google_drive_src | ||
from .jira import get_source_cmd as jira_src | ||
from .local import get_source_cmd as local_src | ||
from .notion import get_source_cmd as notion_src | ||
from .onedrive import get_source_cmd as onedrive_src | ||
from .outlook import get_source_cmd as outlook_src | ||
from .reddit import get_source_cmd as reddit_src | ||
from .s3 import get_dest_cmd as s3_dest | ||
from .s3 import get_source_cmd as s3_src | ||
from .salesforce import get_source_cmd as salesforce_src | ||
from .sharepoint import get_source_cmd as sharepoint_src | ||
from .slack import get_source_cmd as slack_src | ||
from .wikipedia import get_source_cmd as wikipedia_src | ||
# Make sure there are not overlapping names | ||
src_cmd_names = [b().cmd_name for b in base_src_cmd_fns] | ||
src_duplicates = [item for item, count in collections.Counter(src_cmd_names).items() if count > 1] | ||
if src_duplicates: | ||
raise ValueError( | ||
"multiple base src commands defined with the same names: {}".format( | ||
", ".join(src_duplicates), | ||
), | ||
) | ||
|
||
src: t.List[click.Group] = [ | ||
airtable_src(), | ||
azure_src(), | ||
biomed_src(), | ||
box_src(), | ||
confluence_src(), | ||
delta_table_src(), | ||
discord_src(), | ||
dropbox_src(), | ||
elasticsearch_src(), | ||
fsspec_src(), | ||
gcs_src(), | ||
github_src(), | ||
gitlab_src(), | ||
google_drive_src(), | ||
jira_src(), | ||
local_src(), | ||
notion_src(), | ||
onedrive_src(), | ||
outlook_src(), | ||
reddit_src(), | ||
salesforce_src(), | ||
sharepoint_src(), | ||
slack_src(), | ||
s3_src(), | ||
wikipedia_src(), | ||
base_dest_cmd_fns: t.List[t.Callable[[], "BaseDestCmd"]] = [ | ||
s3_base_dest_cmd, | ||
azure_cognitive_search_base_dest_cmd, | ||
delta_table_dest_cmd, | ||
] | ||
|
||
dest: t.List[click.Command] = [azure_cognitive_search_dest(), s3_dest(), delta_table_dest()] | ||
# Make sure there are not overlapping names | ||
dest_cmd_names = [b().cmd_name for b in base_dest_cmd_fns] | ||
dest_duplicates = [item for item, count in collections.Counter(dest_cmd_names).items() if count > 1] | ||
if dest_duplicates: | ||
raise ValueError( | ||
"multiple base dest commands defined with the same names: {}".format( | ||
", ".join(dest_duplicates), | ||
), | ||
) | ||
|
||
__all__ = [ | ||
"src", | ||
"dest", | ||
"base_src_cmd_fns", | ||
"base_dest_cmd_fns", | ||
] |
Oops, something went wrong.