-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: environment variables specified in Model or Estimator are not passed through to SageMaker ModelStep #160
Changes from all commits
597fe1d
ffbcf53
04cd813
ea3e482
f8008e8
6a4a826
dda1dfb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -159,6 +159,8 @@ def get_expected_model(self, model_name=None): | |
model.name = model_name | ||
else: | ||
model.name = self.job_name | ||
if self.estimator.environment: | ||
model.env = self.estimator.environment | ||
model.model_data = self.output()["ModelArtifacts"]["S3ModelArtifacts"] | ||
return model | ||
|
||
|
@@ -295,7 +297,7 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No | |
'ExecutionRoleArn': model.role, | ||
'ModelName': model_name or model.name, | ||
'PrimaryContainer': { | ||
'Environment': {}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the original change that this one superseded (#84) only modified this line. the rationale/need for this change isn't quite captured in the issue this is resolving. Although it's supported, we need to get into why we are making this change. As it modifies your control flow, also suggest adding tests for "all model types" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I opted for this change since I saw it was supported in sagemaker sdk since the issue was opened, but I agree that we should not introduce breaking changes - even more if the issue does not capture the need to. I'll revert the changes and go for the solution that was proposed in #84 and add test to validate that behaviour is the same There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
We have tests that use a FrameworkModel (ex: test_training_step_creation_with_framework() and others that use a Model (ex: test_training_step_creation_with_model()), but none that pass both. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I didn't mean in a single test (is that even possible). was just verifying that both parts of the control flow are being tested. I would assume that if they already existed, one of them would have broken with the attempt to introduce a breaking change. if that didn't happen, i think it surfaces a gap in testing and we should use this opportunity to plug it in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a test in the latest commit that confirms we are consistent |
||
'Environment': model.env, | ||
'Image': model.image_uri, | ||
'ModelDataUrl': model.model_data | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,6 +68,41 @@ def pca_estimator(): | |
return pca | ||
|
||
|
||
@pytest.fixture | ||
def pca_estimator_with_env(): | ||
s3_output_location = 's3://sagemaker/models' | ||
|
||
pca = sagemaker.estimator.Estimator( | ||
PCA_IMAGE, | ||
role=EXECUTION_ROLE, | ||
instance_count=1, | ||
instance_type='ml.c4.xlarge', | ||
output_path=s3_output_location, | ||
environment={ | ||
'JobName': "job_name", | ||
'ModelName': "model_name" | ||
}, | ||
subnets=[ | ||
'subnet-00000000000000000', | ||
'subnet-00000000000000001' | ||
] | ||
) | ||
|
||
pca.set_hyperparameters( | ||
feature_dim=50000, | ||
num_components=10, | ||
subtract_mean=True, | ||
algorithm_mode='randomized', | ||
mini_batch_size=200 | ||
) | ||
|
||
pca.sagemaker_session = MagicMock() | ||
pca.sagemaker_session.boto_region_name = 'us-east-1' | ||
pca.sagemaker_session._default_bucket = 'sagemaker' | ||
|
||
return pca | ||
|
||
|
||
@pytest.fixture | ||
def pca_estimator_with_debug_hook(): | ||
s3_output_location = 's3://sagemaker/models' | ||
|
@@ -156,6 +191,31 @@ def pca_model(): | |
) | ||
|
||
|
||
@pytest.fixture | ||
def pca_model_with_env(): | ||
ca-nguyen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
model_data = 's3://sagemaker/models/pca.tar.gz' | ||
return Model( | ||
model_data=model_data, | ||
image_uri=PCA_IMAGE, | ||
role=EXECUTION_ROLE, | ||
name='pca-model', | ||
env={ | ||
'JobName': "job_name", | ||
'ModelName': "model_name" | ||
}, | ||
vpc_config={ | ||
"SecurityGroupIds": ["sg-00000000000000000"], | ||
"Subnets": ["subnet-00000000000000000", "subnet-00000000000000001"] | ||
}, | ||
image_config={ | ||
"RepositoryAccessMode": "Vpc", | ||
"RepositoryAuthConfig": { | ||
"RepositoryCredentialsProviderArn": "arn" | ||
} | ||
} | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def pca_transformer(pca_model): | ||
return Transformer( | ||
|
@@ -537,6 +597,63 @@ def test_training_step_creation_with_model(pca_estimator): | |
} | ||
|
||
|
||
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) | ||
@patch.object(boto3.session.Session, 'region_name', 'us-east-1') | ||
def test_training_step_creation_with_model_with_env(pca_estimator_with_env): | ||
training_step = TrainingStep('Training', estimator=pca_estimator_with_env, job_name='TrainingJob') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test is using a new model with a defined env, vpc_config and image_config There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (not blocking: but please follow this guidance for future contributions) You're following the existing conventions here, but let's try to structure the tests as documentation so they are easier to read in the future.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ACK |
||
model_step = ModelStep('Training - Save Model', training_step.get_expected_model(model_name=training_step.output()['TrainingJobName'])) | ||
training_step.next(model_step) | ||
assert training_step.to_dict() == { | ||
'Type': 'Task', | ||
'Parameters': { | ||
'AlgorithmSpecification': { | ||
'TrainingImage': PCA_IMAGE, | ||
'TrainingInputMode': 'File' | ||
}, | ||
'OutputDataConfig': { | ||
'S3OutputPath': 's3://sagemaker/models' | ||
}, | ||
'StoppingCondition': { | ||
'MaxRuntimeInSeconds': 86400 | ||
}, | ||
'ResourceConfig': { | ||
'InstanceCount': 1, | ||
'InstanceType': 'ml.c4.xlarge', | ||
'VolumeSizeInGB': 30 | ||
}, | ||
'RoleArn': EXECUTION_ROLE, | ||
'HyperParameters': { | ||
'feature_dim': '50000', | ||
'num_components': '10', | ||
'subtract_mean': 'True', | ||
'algorithm_mode': 'randomized', | ||
'mini_batch_size': '200' | ||
}, | ||
'TrainingJobName': 'TrainingJob' | ||
}, | ||
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync', | ||
'Next': 'Training - Save Model' | ||
} | ||
|
||
assert model_step.to_dict() == { | ||
'Type': 'Task', | ||
'Resource': 'arn:aws:states:::sagemaker:createModel', | ||
'Parameters': { | ||
'ExecutionRoleArn': EXECUTION_ROLE, | ||
'ModelName.$': "$['TrainingJobName']", | ||
'PrimaryContainer': { | ||
'Environment': { | ||
'JobName': 'job_name', | ||
'ModelName': 'model_name' | ||
}, | ||
ca-nguyen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
'Image': PCA_IMAGE, | ||
'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']" | ||
} | ||
}, | ||
'End': True | ||
} | ||
|
||
|
||
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) | ||
@patch.object(boto3.session.Session, 'region_name', 'us-east-1') | ||
def test_training_step_creation_with_framework(tensorflow_estimator): | ||
|
@@ -806,6 +923,31 @@ def test_get_expected_model(pca_estimator): | |
} | ||
|
||
|
||
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) | ||
@patch.object(boto3.session.Session, 'region_name', 'us-east-1') | ||
def test_get_expected_model_with_env(pca_estimator_with_env): | ||
training_step = TrainingStep('Training', estimator=pca_estimator_with_env, job_name='TrainingJob') | ||
expected_model = training_step.get_expected_model() | ||
model_step = ModelStep('Create model', model=expected_model, model_name='pca-model') | ||
assert model_step.to_dict() == { | ||
'Type': 'Task', | ||
'Parameters': { | ||
'ExecutionRoleArn': EXECUTION_ROLE, | ||
'ModelName': 'pca-model', | ||
'PrimaryContainer': { | ||
'Environment': { | ||
'JobName': 'job_name', | ||
'ModelName': 'model_name' | ||
}, | ||
'Image': expected_model.image_uri, | ||
'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']" | ||
} | ||
}, | ||
'Resource': 'arn:aws:states:::sagemaker:createModel', | ||
'End': True | ||
} | ||
|
||
|
||
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) | ||
@patch.object(boto3.session.Session, 'region_name', 'us-east-1') | ||
def test_get_expected_model_with_framework_estimator(tensorflow_estimator): | ||
|
@@ -859,6 +1001,29 @@ def test_model_step_creation(pca_model): | |
} | ||
|
||
|
||
@patch.object(boto3.session.Session, 'region_name', 'us-east-1') | ||
def test_model_step_creation_with_env(pca_model_with_env): | ||
step = ModelStep('Create model', model=pca_model_with_env, model_name='pca-model', tags=DEFAULT_TAGS) | ||
assert step.to_dict() == { | ||
'Type': 'Task', | ||
'Parameters': { | ||
'ExecutionRoleArn': EXECUTION_ROLE, | ||
'ModelName': 'pca-model', | ||
'PrimaryContainer': { | ||
'Environment': { | ||
'JobName': 'job_name', | ||
'ModelName': 'model_name' | ||
}, | ||
'Image': pca_model_with_env.image_uri, | ||
'ModelDataUrl': pca_model_with_env.model_data | ||
}, | ||
'Tags': DEFAULT_TAGS_LIST | ||
}, | ||
'Resource': 'arn:aws:states:::sagemaker:createModel', | ||
'End': True | ||
} | ||
|
||
|
||
@patch.object(boto3.session.Session, 'region_name', 'us-east-1') | ||
def test_endpoint_config_step_creation(pca_model): | ||
data_capture_config = DataCaptureConfig( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CR description should also capture why we're doing this. It's not strictly related to the bug reported, but is solving an issue. Typically better to fix things separately, but when there are multiple fixes, they need to be captured in the commit summary, along with rationale and any testing that was performed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated the description with the uncovered bug and what was done to fix it
I am open to move this fix to a separate PR - will go with what reviewers prefer since the review process has already begun
What do you think? @shivlaks @wong-a
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given where we currently are, i'm not strongly opinionated because it feels more practical to go with it. It feels like it might just be busy work splitting them up in the current state.
Generally and going forward, I think we should keep changes contained and specific to the bug they address / feature they introduce for a few reasons:
having said that, i would probably still split it up if I were the one doing it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it's already implemented and works, let's keep it in here, since it's a small enough change. Just update the PR description accordingly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done :)