Skip to content

Commit

Permalink
Updated README + added tests to CI
Browse files Browse the repository at this point in the history
  • Loading branch information
andreped committed Aug 28, 2023
1 parent 7806558 commit df7564c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,6 @@ jobs:

- name: Test library accessibility
run: python -c "from t_loss import TLoss"

- name: Run tests
run: pytest -v tests/
23 changes: 19 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,30 @@ pip install git+https://github.com/andreped/t-loss-tf.git
## Usage
As the t-loss contains a trainable parameter, in keras the loss needed to be implemented as a custom layer.
Hence, instead of setting the loss as normally through `model.compile(loss=[...])`, just add it to the model
at appropriate place.
at appropriate place. An example can be seen below:

```python
import tensorflow as tf
from t_loss import TLoss

model = tf.keras.Sequential()
model.add(TLoss(image_size=512))
[...]
input_shape = (16, 16, 1)
# create dummy inputs and GTs
x = tf.ones((32,) + input_shape, dtype="float32")
y = tf.ones((32,) + input_shape, dtype="float32")

input_x = tf.keras.Input(shape=input_shape)
input_y = tf.keras.Input(shape=input_shape)

z = tf.keras.layers.Conv2D(filters=4, kernel_size=(1, 1), activation="relu")(input_x)
z = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(z)
z = tf.keras.layers.UpSampling2D(size=(2, 2))(z)
z = tf.keras.layers.Conv2D(filters=1, kernel_size=(1, 1), activation="sigmoid")(z)
z = TLoss(tensor_shape=input_shape, image_size=input_shape[0])(z, input_y)
model = tf.keras.Model(inputs=[input_x, input_y], outputs=[z])
print(model.summary())

model.compile(optimizer="adam")
model.fit(x=[x, y], y=y, batch_size=2, epochs=1, verbose="auto")
```

## License
Expand Down

0 comments on commit df7564c

Please sign in to comment.