Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Orca: add save and load model doc to PyTorch Estimator quick-start. #5504

Merged
merged 7 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,18 @@ for r in result:
print(r, ":", result[r])
```

### **Step 5: Save and Load the Model **

Save the Estimator state (including model and optimizer) to the provided model path.
hkvision marked this conversation as resolved.
Show resolved Hide resolved

```python
est.save("mnist_model")
```

Load the Estimator state (model and possibly with optimizer) from the provided model path.

```python
est.load("mnist_model")
```

**Note:** You should call `stop_orca_context()` when your application finishes.
Original file line number Diff line number Diff line change
Expand Up @@ -130,20 +130,20 @@ You could also save the model to Keras H5 format by passing `save_format='h5'`

```python
# save model in SavedModel format
estimator.save("/tmp/cifar10_model")
est.save("/tmp/cifar10_model")

# load model
estimator.load("/tmp/cifar10_model")
est.load("/tmp/cifar10_model")
```

**2. HDF5 format**

```python
# save model in H5 format
estimator.save("/tmp/cifar10_model.h5", save_format='h5')
est.save("/tmp/cifar10_model.h5", save_format='h5')

# load model
estimator.load("/tmp/cifar10_model.h5")
est.load("/tmp/cifar10_model.h5")
```

That's it, the same code can run seamlessly in your local laptop and to distribute K8s or Hadoop cluster.
Expand Down
30 changes: 30 additions & 0 deletions python/orca/colab-notebook/quickstart/pytorch_lenet_mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,36 @@
"The accuracy of this model has reached 98%."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qhLhJ4vWc-95"
},
"source": [
"### **Step 5: Save the Model**"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DmmtHcPodLA3"
},
"source": [
"Save the Estimator state (including model and optimizer) to the provided model path.\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "WUbsriiBdprT"
},
"source": [
"est.save(\"mnist_model\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
Expand Down