Skip to content

Commit

Permalink
Orca: add save and load model doc to PyTorch Estimator quick-start. (#…
Browse files Browse the repository at this point in the history
…5504)

* feat: add save and load model doc to pytorch estimator quickstart.

* fix: fix typo.

* fix: fix typo

* fix: fix typo.

* feat: add doc to ray backend quickstart

* fix: fix typo.

* feat: add doc to ray backend quickstart
  • Loading branch information
lalalapotter authored Aug 24, 2022
1 parent 584ee60 commit 57d45e5
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,18 @@ for r in result:
print(r, ":", result[r])
```

### **Step 5: Save and Load the Model**

Save the Estimator states (including model and optimizer) to the provided model path.

```python
est.save("mnist_model")
```

Load the Estimator states (model and possibly with optimizer) from provided model path.

```python
est.load("mnist_model")
```

**Note:** You should call `stop_orca_context()` when your application finishes.
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,18 @@ for r in result:
print(r, ":", result[r])
```

### **Step 5: Save and Load the Model**

Save the Estimator states (including model and optimizer) to the provided model path.

```python
est.save("mnist_model")
```

Load the Estimator states (model and possibly with optimizer) from the provided model path.

```python
est.load("mnist_model")
```

**Note:** You should call `stop_orca_context()` when your application finishes.
Original file line number Diff line number Diff line change
Expand Up @@ -130,20 +130,20 @@ You could also save the model to Keras H5 format by passing `save_format='h5'`

```python
# save model in SavedModel format
estimator.save("/tmp/cifar10_model")
est.save("/tmp/cifar10_model")

# load model
estimator.load("/tmp/cifar10_model")
est.load("/tmp/cifar10_model")
```

**2. HDF5 format**

```python
# save model in H5 format
estimator.save("/tmp/cifar10_model.h5", save_format='h5')
est.save("/tmp/cifar10_model.h5", save_format='h5')

# load model
estimator.load("/tmp/cifar10_model.h5")
est.load("/tmp/cifar10_model.h5")
```

That's it, the same code can run seamlessly in your local laptop and to distribute K8s or Hadoop cluster.
Expand Down

0 comments on commit 57d45e5

Please sign in to comment.