Skip to content

Commit

Permalink
fix cli
Browse files Browse the repository at this point in the history
  • Loading branch information
chensun committed Apr 11, 2023
1 parent 5eed9c9 commit 71af82d
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 88 deletions.
4 changes: 2 additions & 2 deletions sdk/python/kfp/cli/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def archive(ctx: click.Context, experiment_id: str, experiment_name: str):

if not experiment_id:
experiment = client_obj.get_experiment(experiment_name=experiment_name)
experiment_id = experiment.id
experiment_id = experiment.experiment_id

client_obj.archive_experiment(experiment_id=experiment_id)
if experiment_id:
Expand Down Expand Up @@ -162,7 +162,7 @@ def unarchive(ctx: click.Context, experiment_id: str, experiment_name: str):

if not experiment_id:
experiment = client_obj.get_experiment(experiment_name=experiment_name)
experiment_id = experiment.id
experiment_id = experiment.experiment_id

client_obj.unarchive_experiment(experiment_id=experiment_id)
if experiment_id:
Expand Down
90 changes: 33 additions & 57 deletions sdk/python/kfp/cli/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,6 @@ class OutputFormat(enum.Enum):
json = 'json'


RUN_STORAGE_STATE_MAP = {
kfp_server_api.ApiRunStorageState.AVAILABLE: 'Available',
kfp_server_api.ApiRunStorageState.ARCHIVED: 'Archived',
}
EXPERIMENT_STORAGE_STATE_MAP = {
kfp_server_api.ApiExperimentStorageState.AVAILABLE: 'Available',
kfp_server_api.ApiExperimentStorageState.ARCHIVED: 'Archived',
kfp_server_api.ApiExperimentStorageState.UNSPECIFIED: 'Unspecified',
}


def snake_to_header(string: str) -> str:
"""Converts a snake case string to a table header by replacing underscores
with spaces and making uppercase.
Expand All @@ -74,39 +63,34 @@ class ExperimentData:
id: str
name: str
created_at: str
state: str
storage_state: str


def transform_experiment(exp: kfp_server_api.ApiExperiment) -> Dict[str, Any]:
def transform_experiment(
exp: kfp_server_api.V2beta1Experiment) -> Dict[str, Any]:
return dataclasses.asdict(
ExperimentData(
id=exp.id,
name=exp.name,
id=exp.experiment_id,
name=exp.display_name,
created_at=exp.created_at.isoformat(),
state=EXPERIMENT_STORAGE_STATE_MAP.get(
exp.storage_state, EXPERIMENT_STORAGE_STATE_MAP[
kfp_server_api.ApiExperimentStorageState.AVAILABLE])))
storage_state=exp.storage_state))


@dataclasses.dataclass
class PipelineData:
id: str
name: str
created_at: str
default_version: str


def transform_pipeline(pipeline: kfp_server_api.ApiPipeline) -> Dict[str, Any]:
default_version_id = pipeline.default_version.id if hasattr(
pipeline,
'default_version') and pipeline.default_version is not None and hasattr(
pipeline.default_version, 'id') else None
def transform_pipeline(
pipeline: kfp_server_api.V2beta1Pipeline) -> Dict[str, Any]:
return dataclasses.asdict(
PipelineData(
id=pipeline.id,
name=pipeline.name,
id=pipeline.pipeline_id,
name=pipeline.display_name,
created_at=pipeline.created_at.isoformat(),
default_version=default_version_id))
))


@dataclasses.dataclass
Expand All @@ -118,16 +102,14 @@ class PipelineVersionData:


def transform_pipeline_version(
pipeline_version: kfp_server_api.ApiPipelineVersion) -> Dict[str, Any]:
parent_id = next(
rr for rr in pipeline_version.resource_references
if rr.relationship == kfp_server_api.ApiRelationship.OWNER).key.id
pipeline_version: kfp_server_api.V2beta1PipelineVersion
) -> Dict[str, Any]:
return dataclasses.asdict(
PipelineVersionData(
id=pipeline_version.id,
name=pipeline_version.name,
id=pipeline_version.pipeline_version_id,
name=pipeline_version.display_name,
created_at=pipeline_version.created_at.isoformat(),
parent_id=parent_id,
parent_id=pipeline_version.pipeline_id,
))


Expand All @@ -136,43 +118,37 @@ class RunData:
id: str
name: str
created_at: str
status: str
state: str
storage_state: str


def transform_run(
run: Union[kfp_server_api.ApiRun, kfp_server_api.ApiRunDetail]
) -> Dict[str, Any]:
def transform_run(run: kfp_server_api.V2beta1Run) -> Dict[str, Any]:
return dataclasses.asdict((RunData(
id=run.id,
name=run.name,
id=run.run_id,
name=run.display_name,
created_at=run.created_at.isoformat(),
status=run.status,
state=RUN_STORAGE_STATE_MAP.get(
run.storage_state,
RUN_STORAGE_STATE_MAP[kfp_server_api.ApiRunStorageState.AVAILABLE]))
))
state=run.state,
storage_state=run.storage_state,
)))


@dataclasses.dataclass
class JobData:
class RecurringRunData:
id: str
name: str
created_at: str
experiment_id: str
status: str


def transform_job(recurring_run: kfp_server_api.ApiJob) -> Dict[str, Any]:
experiment_id = next(
rr for rr in recurring_run.resource_references
if rr.key.type == kfp_server_api.ApiResourceType.EXPERIMENT).key.id
def transform_recurring_run(
recurring_run: kfp_server_api.V2beta1RecurringRun) -> Dict[str, Any]:
return dataclasses.asdict(
JobData(
id=recurring_run.id,
name=recurring_run.name,
RecurringRunData(
id=recurring_run.recurring_run_id,
name=recurring_run.display_name,
created_at=recurring_run.created_at.isoformat(),
experiment_id=experiment_id,
experiment_id=recurring_run.experiment_id,
status=recurring_run.status))


Expand All @@ -183,23 +159,23 @@ class ModelType(enum.Enum):
PIPELINE = 'PIPELINE'
PIPELINE_VERSION = 'PIPELINE_VERSION'
RUN = 'RUN'
JOB = 'JOB'
RECURRING_RUN = 'RECURRING_RUN'


transformer_map = {
ModelType.EXPERIMENT: transform_experiment,
ModelType.PIPELINE: transform_pipeline,
ModelType.PIPELINE_VERSION: transform_pipeline_version,
ModelType.RUN: transform_run,
ModelType.JOB: transform_job,
ModelType.RECURRING_RUN: transform_recurring_run,
}

dataclass_map = {
ModelType.EXPERIMENT: ExperimentData,
ModelType.PIPELINE: PipelineData,
ModelType.PIPELINE_VERSION: PipelineVersionData,
ModelType.RUN: RunData,
ModelType.JOB: JobData,
ModelType.RECURRING_RUN: RecurringRunData,
}


Expand Down
12 changes: 8 additions & 4 deletions sdk/python/kfp/cli/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,10 @@ def list_versions(ctx: click.Context, pipeline_id: str, page_token: str,


@pipeline.command()
@click.argument('pipeline-id')
@click.argument('version-id')
@click.pass_context
def delete_version(ctx: click.Context, version_id: str):
def delete_version(ctx: click.Context, pipeline_id: str, version_id: str):
"""Delete a version of a pipeline."""
confirmation = f'Are you sure you want to delete pipeline version {version_id}?'
if not click.confirm(confirmation):
Expand All @@ -209,7 +210,8 @@ def delete_version(ctx: click.Context, version_id: str):
client_obj: client.Client = ctx.obj['client']
output_format = ctx.obj['output']

client_obj.delete_pipeline_version(version_id)
client_obj.delete_pipeline_version(
pipeline_id=pipeline_id, pipeline_version_id=version_id)
output.print_deleted_text('pipeline version', version_id, output_format)


Expand All @@ -230,14 +232,16 @@ def get(ctx: click.Context, pipeline_id: str):


@pipeline.command()
@click.argument('pipeline-id')
@click.argument('version-id')
@click.pass_context
def get_version(ctx: click.Context, version_id: str):
def get_version(ctx: click.Context, pipeline_id: str, version_id: str):
"""Get information about a version of a pipeline."""
client_obj: client.Client = ctx.obj['client']
output_format = ctx.obj['output']

version = client_obj.get_pipeline_version(version_id=version_id)
version = client_obj.get_pipeline_version(
pipeline_id=pipeline_id, pipeline_version_id=version_id)
output.print_output(
version,
output.ModelType.PIPELINE,
Expand Down
46 changes: 24 additions & 22 deletions sdk/python/kfp/cli/recurring_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def create(ctx: click.Context,
version_id=version_id)
output.print_output(
recurring_run,
output.ModelType.JOB,
output.ModelType.RECURRING_RUN,
output_format,
)

Expand Down Expand Up @@ -202,73 +202,75 @@ def list(ctx: click.Context, experiment_id: str, page_token: str, max_size: int,
sort_by=sort_by,
filter=filter)
output.print_output(
response.jobs or [],
output.ModelType.JOB,
response.recurring_runs or [],
output.ModelType.RECURRING_RUN,
output_format,
)


@recurring_run.command()
@click.argument('job-id')
@click.argument('recurring-run-id')
@click.pass_context
def get(ctx: click.Context, job_id: str):
def get(ctx: click.Context, recurring_run_id: str):
"""Get information about a recurring run."""
client_obj: client.Client = ctx.obj['client']
output_format = ctx.obj['output']

recurring_run = client_obj.get_recurring_run(job_id)
recurring_run = client_obj.get_recurring_run(recurring_run_id)
output.print_output(
recurring_run,
output.ModelType.JOB,
output.ModelType.RECURRING_RUN,
output_format,
)


@recurring_run.command()
@click.argument('job-id')
@click.argument('recurring-run-id')
@click.pass_context
def delete(ctx: click.Context, job_id: str):
def delete(ctx: click.Context, recurring_run_id: str):
"""Delete a recurring run."""
client_obj: client.Client = ctx.obj['client']
output_format = ctx.obj['output']
confirmation = f'Are you sure you want to delete job {job_id}?'
confirmation = f'Are you sure you want to delete job {recurring_run_id}?'
if not click.confirm(confirmation):
return
client_obj.delete_job(job_id)
output.print_deleted_text('job', job_id, output_format)
client_obj.delete_recurring_run(recurring_run_id)
output.print_deleted_text('recurring_run', recurring_run_id, output_format)


@recurring_run.command()
@click.argument('job-id')
@click.argument('recurring-run-id')
@click.pass_context
def enable(ctx: click.Context, job_id: str):
def enable(ctx: click.Context, recurring_run_id: str):
"""Enable a recurring run."""
client_obj: client.Client = ctx.obj['client']
output_format = ctx.obj['output']

client_obj.enable_job(job_id=job_id)
client_obj.enable_recurring_run(recurring_run_id=recurring_run_id)
# TODO: add wait option, since enable takes time to complete
recurring_run = client_obj.get_recurring_run(job_id=job_id)
recurring_run = client_obj.get_recurring_run(
recurring_run_id=recurring_run_id)
output.print_output(
recurring_run,
output.ModelType.JOB,
output.ModelType.RECURRING_RUN,
output_format,
)


@recurring_run.command()
@click.argument('job-id')
@click.argument('recurring-run-id')
@click.pass_context
def disable(ctx: click.Context, job_id: str):
def disable(ctx: click.Context, recurring_run_id: str):
"""Disable a recurring run."""
client_obj: client.Client = ctx.obj['client']
output_format = ctx.obj['output']

client_obj.disable_job(job_id=job_id)
client_obj.disable_recurring_run(recurring_run_id=recurring_run_id)
# TODO: add wait option, since disable takes time to complete
recurring_run = client_obj.get_recurring_run(job_id=job_id)
recurring_run = client_obj.get_recurring_run(
recurring_run_id=recurring_run_id)
output.print_output(
recurring_run,
output.ModelType.JOB,
output.ModelType.RECURRING_RUN,
output_format,
)
6 changes: 3 additions & 3 deletions sdk/python/kfp/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,21 @@ def create(ctx: click.Context, experiment_name: str, run_name: str,

experiment = client_obj.create_experiment(experiment_name)
run = client_obj.run_pipeline(
experiment_id=experiment.id,
experiment_id=experiment.experiment_id,
job_name=run_name,
pipeline_package_path=package_file,
params=arg_dict,
pipeline_id=pipeline_id,
version_id=version)
if timeout > 0:
run_detail = client_obj.wait_for_run_completion(run.id, timeout)
run_detail = client_obj.wait_for_run_completion(run.run_id, timeout)
output.print_output(
run_detail.run,
output.ModelType.RUN,
output_format,
)
else:
display_run(client_obj, namespace, run.id, watch, output_format)
display_run(client_obj, namespace, run.run_id, watch, output_format)


@run.command()
Expand Down

0 comments on commit 71af82d

Please sign in to comment.