Skip to content

Commit

Permalink
Merge pull request #2 from dwt0317/weda/demo
Browse files Browse the repository at this point in the history
Add entry and yaml files
  • Loading branch information
dwt0317 authored Dec 19, 2019
2 parents 7c0a6a3 + a7aa230 commit cd805bb
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 0 deletions.
74 changes: 74 additions & 0 deletions azureml-designer-modules/entries/stratified_splitter_entry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import argparse

from azureml.studio.core.logger import module_logger as logger
from reco_utils.dataset.python_splitters import python_stratified_split
from azureml.studio.core.data_frame_schema import DataFrameSchema
from azureml.studio.core.io.data_frame_directory import load_data_frame_from_directory, save_data_frame_to_directory


if __name__ == '__main__':
parser = argparse.ArgumentParser()

parser.add_argument(
'--input-path',
help='The input directory.',
)

parser.add_argument(
'--ratio', type=float,
help='A float parameter.',
)

parser.add_argument(
'--col-user', type=str,
help='A string parameter.',
)

parser.add_argument(
'--col-item', type=str,
help='A string parameter.',
)

parser.add_argument(
'--seed', type=int,
help='An int parameter.',
)

parser.add_argument(
'--output-train',
help='The output training data directory.',
)
parser.add_argument(
'--output-test',
help='The output test data directory.',
)

args, _ = parser.parse_known_args()

input_df = load_data_frame_from_directory(args.input_path).data

#logger.info(f"Hello world from {PACKAGE_NAME} {VERSION}")

ratio = args.ratio
col_user = args.col_user
col_item = args.col_item
seed = args.seed

logger.debug(f"Received parameters:")
logger.debug(f"Ratio: {ratio}")
logger.debug(f"User: {col_user}")
logger.debug(f"Item: {col_item}")
logger.debug(f"Seed: {seed}")

logger.debug(f"Input path: {args.input_path}")
logger.debug(f"Shape of loaded DataFrame: {input_df.shape}")
logger.debug(f"Cols of DataFrame: {input_df.columns}")

output_train, output_test = python_stratified_split(input_df, ratio=args.ratio, col_user=args.col_user, col_item=args.col_item, seed=args.seed)

logger.debug(f"Output path: {args.output_train}")
logger.debug(f"Output path: {args.output_test}")

save_data_frame_to_directory(args.output_train, output_train, schema=DataFrameSchema.data_frame_to_dict(output_train))
save_data_frame_to_directory(args.output_test, output_test, schema=DataFrameSchema.data_frame_to_dict(output_test))

69 changes: 69 additions & 0 deletions azureml-designer-modules/module_specs/stratified_splitter.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
name: Stratified Splitter
id: efd1af54-0d31-42e1-b3d5-ce3b7c538705
version: 0.0.22
category: Experimentation
description: "Python stratified splitter from CAT Recommender repo: https://github.com/Microsoft/Recommenders/tree/master/."
inputs:
- name: Input path
type: DataFrameDirectory
description: The directory contains dataframe.
port: true
- name: Ratio
type: Float
optional: True
description: >
Ratio for splitting data. If it is a single float number,
it splits data into two halves and the ratio argument indicates the ratio of
training data set; if it is a list of float numbers, the splitter splits
data into several portions corresponding to the split ratios. If a list is
provided and the ratios are not summed to 1, they will be normalized.
- name: User column
type: String
description: Column name of user IDs.
- name: Item column
type: String
description: Column name of item IDs.
- name: Seed
type: Int
min: 1
max: 100
default: 42
description: Seed.
outputs:
- name: Output train data
type: DataFrameDirectory
description: The output directory contains a training dataframe.
port: true
- name: Output test data
type: DataFrameDirectory
description: The output directory contains a test dataframe.
port: true
implementation:
container:
conda:
name: CAT_module_environment
channels:
- defaults
dependencies:
- python=3.7
- pip:
- azureml-designer-core==0.0.26.*
- azureml-designer-classic-modules==0.0.105
command:
- python
- azureml-designer-modules/entries/stratified_splitter_entry.py
args:
- --input-path
- inputPath: Input path
- --ratio
- inputValue: Ratio
- --col-user
- inputValue: User column
- --col-item
- inputValue: Item column
- --seed
- inputValue: Seed
- --output-train
- outputPath: Output train data
- --output-test
- outputPath: Output test data

0 comments on commit cd805bb

Please sign in to comment.