Skip to content

Commit

Permalink
dvclive: reformat Python code
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgeorpinel committed Mar 28, 2021
1 parent 5814b32 commit 3617e6f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 18 deletions.
31 changes: 17 additions & 14 deletions content/docs/dvclive/dvclive-with-dvc.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,29 @@
Even though Dvclive does not require DVC, they can integrate in several useful
ways.

> In this section we will modify the [basic usage example](/doc/dvclive/usage)
> to see how DVC can cooperate with Dvclive module.
> In this section we reuse the finished
> [basic usage example](/doc/dvclive/usage) to see how DVC can cooperate with
> Dvclive.
```python
# train.py

import dvclive
from keras.callbacks import Callback
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.utils import np_utils


class MetricsCallback(Callback):
def on_epoch_end(self, epoch: int, logs: dict = None):
logs = logs or {}
for metric, value in logs.items():
dvclive.log(metric, value)
dvclive.next_step()


def load_data():
(x_train, y_train), (x_test, y_test) = mnist.load_data()

Expand All @@ -23,35 +35,26 @@ def load_data():
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

classes = 10
y_train = np_utils.to_categorical(y_train, classes)
y_test = np_utils.to_categorical(y_test, classes)
return (x_train, y_train), (x_test, y_test)


def get_model():
model = Sequential()

model.add(Dense(512, input_dim=784))
model.add(Activation('relu'))

model.add(Dense(10, input_dim=512))

model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
metrics=['accuracy'], optimizer='sgd')
return model


from keras.callbacks import Callback
import dvclive

class MetricsCallback(Callback):
def on_epoch_end(self, epoch: int, logs: dict = None):
logs = logs or {}
for metric, value in logs.items():
dvclive.log(metric, value)
dvclive.next_step()

(x_train, y_train), (x_test, y_test) = load_data()
model = get_model()

Expand Down
13 changes: 9 additions & 4 deletions content/docs/dvclive/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.utils import np_utils


def load_data():
(x_train, y_train), (x_test, y_test) = mnist.load_data()

Expand All @@ -23,18 +24,19 @@ def load_data():
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

classes = 10
y_train = np_utils.to_categorical(y_train, classes)
y_test = np_utils.to_categorical(y_test, classes)
return (x_train, y_train), (x_test, y_test)


def get_model():
model = Sequential()

model.add(Dense(512, input_dim=784))
model.add(Activation('relu'))

model.add(Dense(10, input_dim=512))

model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
Expand All @@ -59,11 +61,14 @@ log the `accuracy`, `loss`, `validation_accuracy` and `validation_loss` after
each epoch, so that we can observe how the training progresses.

In order to do that, we will provide a
[`Callback`](https://keras.io/api/callbacks/) for the `fit` method call:
[`Callback`](https://keras.io/api/callbacks/) for the `fit` method call (add
this to `train.py`):

```python
from keras.callbacks import Callback
import dvclive
from keras.callbacks import Callback


class MetricsCallback(Callback):
def on_epoch_end(self, epoch: int, logs: dict = None):
logs = logs or {}
Expand Down

0 comments on commit 3617e6f

Please sign in to comment.