Skip to content

Commit

Permalink
Orca: add save and load model doc to PyTorch Estimator quick-start. (i…
Browse files Browse the repository at this point in the history
…ntel-analytics#5504)

* feat: add save and load model doc to pytorch estimator quickstart.

* fix: fix typo.

* fix: fix typo

* fix: fix typo.

* feat: add doc to ray backend quickstart

* fix: fix typo.

* feat: add doc to ray backend quickstart
  • Loading branch information
lalalapotter authored and ForJadeForest committed Sep 20, 2022
1 parent f765583 commit e855423
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,18 @@ for r in result:
print(r, ":", result[r])
```

### **Step 5: Save and Load the Model**

Save the Estimator states (including model and optimizer) to the provided model path.

```python
est.save("mnist_model")
```

Load the Estimator states (model and possibly with optimizer) from provided model path.

```python
est.load("mnist_model")
```

**Note:** You should call `stop_orca_context()` when your application finishes.
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,18 @@ for r in result:
print(r, ":", result[r])
```

### **Step 5: Save and Load the Model**

Save the Estimator states (including model and optimizer) to the provided model path.

```python
est.save("mnist_model")
```

Load the Estimator states (model and possibly with optimizer) from the provided model path.

```python
est.load("mnist_model")
```

**Note:** You should call `stop_orca_context()` when your application finishes.
Original file line number Diff line number Diff line change
Expand Up @@ -130,20 +130,20 @@ You could also save the model to Keras H5 format by passing `save_format='h5'`

```python
# save model in SavedModel format
estimator.save("/tmp/cifar10_model")
est.save("/tmp/cifar10_model")

# load model
estimator.load("/tmp/cifar10_model")
est.load("/tmp/cifar10_model")
```

**2. HDF5 format**

```python
# save model in H5 format
estimator.save("/tmp/cifar10_model.h5", save_format='h5')
est.save("/tmp/cifar10_model.h5", save_format='h5')

# load model
estimator.load("/tmp/cifar10_model.h5")
est.load("/tmp/cifar10_model.h5")
```

That's it, the same code can run seamlessly in your local laptop and to distribute K8s or Hadoop cluster.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,35 @@
"The accuracy of this model has reached 98%."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uL9QMvCRLjae"
},
"source": [
"### **Step 5: Save the model**"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "icshLUG2LlR6"
},
"source": [
"Save the Estimator states (including model and optimizer) to the provided model path."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sQjoelcfL-Mv"
},
"outputs": [],
"source": [
"est.save(\"mnist_model\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
29 changes: 29 additions & 0 deletions python/orca/colab-notebook/quickstart/pytorch_lenet_mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,35 @@
"The accuracy of this model has reached 98%."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qhLhJ4vWc-95"
},
"source": [
"### **Step 5: Save the Model**"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DmmtHcPodLA3"
},
"source": [
"Save the Estimator states (including model and optimizer) to the provided model path."
]
},
{
"cell_type": "code",
"metadata": {
"id": "WUbsriiBdprT"
},
"source": [
"est.save(\"mnist_model\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
Expand Down

0 comments on commit e855423

Please sign in to comment.