
Add a docstring for compile #838

Merged Sep 5, 2023 (1 commit)
86 changes: 86 additions & 0 deletions keras_core/trainers/trainer.py
@@ -34,6 +34,92 @@ def compile(
steps_per_execution=1,
jit_compile="auto",
):
"""Configures the model for training.

Example:

```python
model.compile(
optimizer=keras_core.optimizers.Adam(learning_rate=1e-3),
loss=keras_core.losses.BinaryCrossentropy(),
metrics=[
keras_core.metrics.BinaryAccuracy(),
keras_core.metrics.FalseNegatives(),
],
)
```

Args:
optimizer: String (name of optimizer) or optimizer instance. See
`keras_core.optimizers`.
loss: Loss function. May be a string (name of loss function), or
a `keras_core.losses.Loss` instance. See `keras_core.losses`. A
loss function is any callable with the signature
`loss = fn(y_true, y_pred)`, where `y_true` are the ground truth
values, and `y_pred` are the model's predictions.
`y_true` should have shape `(batch_size, d0, .. dN)`
(except in the case of sparse loss functions such as
sparse categorical crossentropy which expects integer arrays of
shape `(batch_size, d0, .. dN-1)`).
`y_pred` should have shape `(batch_size, d0, .. dN)`.
The loss function should return a float tensor.
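The loss-callable contract described above (any callable `fn(y_true, y_pred)` returning a float) can be sketched in plain Python; `mse_loss` is an illustrative name, not part of the Keras API:

```python
# Minimal sketch of the loss-callable contract: loss = fn(y_true, y_pred).
# Uses flat Python lists instead of tensors, for illustration only.
def mse_loss(y_true, y_pred):
    """Mean squared error over flat lists of ground-truth and predicted values."""
    assert len(y_true) == len(y_pred)
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

# A loss near 0 means predictions are close to the targets.
print(mse_loss([1.0, 0.0, 1.0], [0.9, 0.1, 0.8]))  # ~0.02
```

A real custom loss would operate on backend tensors, but the signature is the same.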
loss_weights: Optional list or dictionary specifying scalar
coefficients (Python floats) to weight the loss contributions of
different model outputs. The loss value that will be minimized
by the model will then be the *weighted sum* of all individual
losses, weighted by the `loss_weights` coefficients. If a list,
it is expected to have a 1:1 mapping to the model's outputs. If
a dict, it is expected to map output names (strings) to scalar
coefficients.
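The weighted-sum behavior described above can be sketched in plain Python (the output names and loss values here are illustrative, not Keras internals):

```python
# Sketch of how loss_weights combine per-output losses into the total
# loss that the model minimizes.
per_output_loss = {"a": 0.5, "b": 2.0}   # hypothetical loss value per model output
loss_weights = {"a": 1.0, "b": 0.25}     # scalar coefficients, as in the dict form

# Total minimized loss is the weighted sum of the individual losses.
total = sum(loss_weights[name] * value for name, value in per_output_loss.items())
print(total)  # 0.5 * 1.0 + 2.0 * 0.25 = 1.0
```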
metrics: List of metrics to be evaluated by the model during
training and testing. Each of these can be a string (name of a
built-in function), function or a `keras_core.metrics.Metric`
instance. See `keras_core.metrics`. Typically you will use
`metrics=['accuracy']`. A function is any callable with the
signature `result = fn(y_true, y_pred)`. To specify different
metrics for different outputs of a multi-output model, you could
also pass a dictionary, such as
`metrics={'a':'accuracy', 'b':['accuracy', 'mse']}`.
You can also pass a list to specify a metric or a list of
metrics for each output, such as
`metrics=[['accuracy'], ['accuracy', 'mse']]`
or `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass
the strings 'accuracy' or 'acc', we convert this to one of
`keras_core.metrics.BinaryAccuracy`,
`keras_core.metrics.CategoricalAccuracy`,
`keras_core.metrics.SparseCategoricalAccuracy` based on the
shapes of the targets and of the model output. We do a similar
conversion for the strings 'crossentropy' and 'ce' as well.
The metrics passed here are evaluated without sample weighting;
if you would like sample weighting to apply, you can specify
your metrics via the `weighted_metrics` argument instead.
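The metric-callable signature `result = fn(y_true, y_pred)` can be sketched in plain Python (illustrative only; `keras_core.metrics.BinaryAccuracy` is the real, tensor-based implementation):

```python
# Sketch of a metric callable: result = fn(y_true, y_pred).
def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of predictions whose thresholded value matches the binary label."""
    hits = sum(1 for t, p in zip(y_true, y_pred) if (p > threshold) == bool(t))
    return hits / len(y_true)

# Three of four predictions land on the right side of the threshold.
print(binary_accuracy([1, 0, 1, 1], [0.9, 0.2, 0.4, 0.8]))  # 0.75
```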
weighted_metrics: List of metrics to be evaluated and weighted by
`sample_weight` or `class_weight` during training and testing.
run_eagerly: Bool. If `True`, this `Model`'s logic will never be
compiled (e.g. with `tf.function` or `jax.jit`). Recommended to
leave this as `False` when training for best performance, and
`True` when debugging.
steps_per_execution: Int. The number of batches to run
during a single compiled function call. Running multiple
batches inside a single compiled function call can
greatly improve performance on TPUs or small models with a large
Python overhead. At most, one full epoch will be run each
execution. If a number larger than the size of the epoch is
passed, the execution will be truncated to the size of the
epoch. Note that if `steps_per_execution` is set to `N`,
`Callback.on_batch_begin` and `Callback.on_batch_end` methods
will only be called every `N` batches (i.e. before/after
each compiled function execution).
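The every-`N`-batches callback cadence described above can be simulated in plain Python (a sketch of the schedule only, not the actual trainer loop):

```python
# Sketch: with steps_per_execution=N, on_batch_begin/on_batch_end fire once
# per compiled call, i.e. every N batches, truncated at the end of the epoch.
def run_epoch(num_batches, steps_per_execution):
    calls = []
    batch = 0
    while batch < num_batches:
        chunk = min(steps_per_execution, num_batches - batch)  # truncate at epoch end
        calls.append(("on_batch_begin", batch))
        batch += chunk  # one compiled call runs `chunk` batches
        calls.append(("on_batch_end", batch - 1))
    return calls

# 10 batches with N=4 -> compiled calls covering batches 0-3, 4-7, 8-9.
print(run_epoch(num_batches=10, steps_per_execution=4))
```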
jit_compile: Bool or `"auto"`. Whether to use XLA compilation when
compiling a model. This value should currently never be `True`
on the torch backend, and should always be `True` or `"auto"` on
the jax backend. On tensorflow, this value can be `True` or
Review comment (Member): This is not entirely accurate. You can pass `jit_compile=False` with JAX, which means that the model will train/evaluate/predict in eager execution (so `jit_compile` is equivalent to `not run_eagerly`). TF is special in the sense that it supports the unique combination `jit_compile == False` and `run_eagerly == False`, which means "classic graph mode".
`False`, and will toggle the `jit_compile` option for any
`tf.function` owned by the model. See
https://www.tensorflow.org/xla/tutorials/jit_compile for more
details. If `"auto"`, XLA compilation will be enabled if the
backend supports it, and disabled otherwise.
"""
self.optimizer = optimizers.get(optimizer)
if hasattr(self, "output_names"):
output_names = self.output_names