diff --git a/keras_core/trainers/trainer.py b/keras_core/trainers/trainer.py
index 956d7428a..457b27793 100644
--- a/keras_core/trainers/trainer.py
+++ b/keras_core/trainers/trainer.py
@@ -34,6 +34,92 @@ def compile(
         steps_per_execution=1,
         jit_compile="auto",
     ):
+        """Configures the model for training.
+
+        Example:
+
+        ```python
+        model.compile(
+            optimizer=keras_core.optimizers.Adam(learning_rate=1e-3),
+            loss=keras_core.losses.BinaryCrossentropy(),
+            metrics=[
+                keras_core.metrics.BinaryAccuracy(),
+                keras_core.metrics.FalseNegatives(),
+            ],
+        )
+        ```
+
+        Args:
+            optimizer: String (name of optimizer) or optimizer instance. See
+                `keras_core.optimizers`.
+            loss: Loss function. May be a string (name of loss function), or
+                a `keras_core.losses.Loss` instance. See `keras_core.losses`. A
+                loss function is any callable with the signature
+                `loss = fn(y_true, y_pred)`, where `y_true` are the ground truth
+                values, and `y_pred` are the model's predictions.
+                `y_true` should have shape `(batch_size, d0, .. dN)`
+                (except in the case of sparse loss functions such as
+                sparse categorical crossentropy which expects integer arrays of
+                shape `(batch_size, d0, .. dN-1)`).
+                `y_pred` should have shape `(batch_size, d0, .. dN)`.
+                The loss function should return a float tensor.
+            loss_weights: Optional list or dictionary specifying scalar
+                coefficients (Python floats) to weight the loss contributions of
+                different model outputs. The loss value that will be minimized
+                by the model will then be the *weighted sum* of all individual
+                losses, weighted by the `loss_weights` coefficients. If a list,
+                it is expected to have a 1:1 mapping to the model's outputs. If
+                a dict, it is expected to map output names (strings) to scalar
+                coefficients.
+            metrics: List of metrics to be evaluated by the model during
+                training and testing. Each of these can be a string (name of a
+                built-in function), function or a `keras_core.metrics.Metric`
+                instance. See `keras_core.metrics`. Typically you will use
+                `metrics=['accuracy']`. A function is any callable with the
+                signature `result = fn(y_true, y_pred)`. To specify different
+                metrics for different outputs of a multi-output model, you could
+                also pass a dictionary, such as
+                `metrics={'a':'accuracy', 'b':['accuracy', 'mse']}`.
+                You can also pass a list to specify a metric or a list of
+                metrics for each output, such as
+                `metrics=[['accuracy'], ['accuracy', 'mse']]`
+                or `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass
+                the strings 'accuracy' or 'acc', we convert this to one of
+                `keras_core.metrics.BinaryAccuracy`,
+                `keras_core.metrics.CategoricalAccuracy`,
+                `keras_core.metrics.SparseCategoricalAccuracy` based on the
+                shapes of the targets and of the model output. We do a similar
+                conversion for the strings 'crossentropy' and 'ce' as well.
+                The metrics passed here are evaluated without sample weighting;
+                if you would like sample weighting to apply, you can specify
+                your metrics via the `weighted_metrics` argument instead.
+            weighted_metrics: List of metrics to be evaluated and weighted by
+                `sample_weight` or `class_weight` during training and testing.
+            run_eagerly: Bool. If `True`, this `Model`'s logic will never be
+                compiled (e.g. with `tf.function` or `jax.jit`). Recommended to
+                leave this as `False` when training for best performance, and
+                `True` when debugging.
+            steps_per_execution: Int. The number of batches to run
+                during a single compiled function call. Running multiple
+                batches inside a single compiled function call can
+                greatly improve performance on TPUs or small models with a large
+                Python overhead. At most, one full epoch will be run each
+                execution. If a number larger than the size of the epoch is
+                passed, the execution will be truncated to the size of the
+                epoch. Note that if `steps_per_execution` is set to `N`,
+                `Callback.on_batch_begin` and `Callback.on_batch_end` methods
+                will only be called every `N` batches (i.e. before/after
+                each compiled function execution).
+            jit_compile: Bool or `"auto"`. Whether to use XLA compilation when
+                compiling a model. This value should currently never be `True`
+                on the torch backend, and should always be `True` or `"auto"` on
+                the jax backend. On the tensorflow backend, this value can be
+                `True` or `False`, and will toggle the `jit_compile` option for
+                any `tf.function` owned by the model. See
+                https://www.tensorflow.org/xla/tutorials/jit_compile for more
+                details. If `"auto"`, XLA compilation will be enabled if the
+                backend supports it, and disabled otherwise.
+        """
         self.optimizer = optimizers.get(optimizer)
         if hasattr(self, "output_names"):
             output_names = self.output_names
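
As a quick illustration of how the documented arguments compose, here is a minimal sketch. The two-output model, the output names `"a"` / `"b"`, and the specific argument values are hypothetical, chosen only to exercise the dict forms of `loss`, `loss_weights`, and `metrics` described in the docstring:

```python
import keras_core

# Hypothetical two-output functional model, purely for illustration.
inputs = keras_core.Input(shape=(16,))
a = keras_core.layers.Dense(1, activation="sigmoid", name="a")(inputs)
b = keras_core.layers.Dense(4, name="b")(inputs)
model = keras_core.Model(inputs=inputs, outputs=[a, b])

model.compile(
    optimizer="adam",  # string shorthand resolved via `keras_core.optimizers`
    loss={"a": "binary_crossentropy", "b": "mse"},  # per-output losses by name
    loss_weights={"a": 1.0, "b": 0.5},  # total loss = 1.0 * loss_a + 0.5 * loss_b
    metrics={"a": ["accuracy"], "b": ["mse"]},  # per-output metric lists
    steps_per_execution=4,  # run 4 batches per compiled function call
    jit_compile="auto",  # enable XLA only if the backend supports it
)
```

Per the docstring, `jit_compile="auto"` is the portable choice here, since it resolves to the right setting on the torch, jax, and tensorflow backends without per-backend code.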