Skip to content

Commit

Permalink
Support image and audio information in task summaries (#1819)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw authored Sep 10, 2024
1 parent 6d19bb3 commit d0bb822
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 15 deletions.
4 changes: 4 additions & 0 deletions keras_nlp/src/layers/preprocessing/audio_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class AudioConverter(PreprocessingLayer):

backbone_cls = None

def audio_shape(self):
    """Preprocessed shape of a single audio sample.

    The base class cannot know the output length, so the single
    dimension is reported as unknown; subclasses override this.
    """
    unknown_length = None
    return (unknown_length,)

@classproperty
def presets(cls):
"""List built-in presets for an `AudioConverter` subclass."""
Expand Down
4 changes: 4 additions & 0 deletions keras_nlp/src/layers/preprocessing/image_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class ImageConverter(PreprocessingLayer):

backbone_cls = None

def image_size(self):
    """Default (height, width) of a single image.

    The base class does not resize, so both dimensions are unknown.
    """
    unknown = None
    return (unknown, unknown)

@classproperty
def presets(cls):
"""List built-in presets for an `ImageConverter` subclass."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,17 @@ def __init__(
# By default, we just do a simple resize. Any model can subclass this
# layer for preprocessing of a raw image to a model image input.
self.resizing = keras.layers.Resizing(
height,
width,
height=height,
width=width,
crop_to_aspect_ratio=crop_to_aspect_ratio,
interpolation=interpolation,
data_format=data_format,
)

def image_size(self):
    """Preprocessed (height, width) of a single image.

    Reads the target size from the underlying `Resizing` layer.
    """
    resize_layer = self.resizing
    return (resize_layer.height, resize_layer.width)

@preprocessing_function
def call(self, inputs):
    """Resize raw image inputs to the configured model size."""
    resized = self.resizing(inputs)
    return resized
Expand Down
49 changes: 36 additions & 13 deletions keras_nlp/src/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,14 +293,20 @@ def summary(
print_fn = print_msg

def highlight_number(x):
    """Wrap a number in rich color markup for the summary table.

    `None` (an unknown dimension) gets its own color and no comma
    formatting, since `f"{None:,}"` would raise a TypeError.
    """
    if x is None:
        # Bug fix: the `return` was missing, so `None` fell through
        # to the `{x:,}` format below and crashed.
        return f"[color(45)]{x}[/]"
    return f"[color(34)]{x:,}[/]"  # Format number with commas.

def highlight_symbol(x):
    """Wrap a symbol (e.g. a class name) in rich color markup."""
    return "[color(33)]{}[/]".format(x)

def bold_text(x):
    """Wrap text in rich bold markup."""
    return "[bold]{}[/]".format(x)

def highlight_shape(shape):
    """Render a shape tuple with each dimension color-highlighted."""
    inner = ", ".join(highlight_number(dim) for dim in shape)
    return f"({inner})"

if self.preprocessor:
# Create a rich console for printing. Capture for non-interactive logging.
if print_fn:
Expand All @@ -312,27 +318,44 @@ def bold_text(x):
console = rich_console.Console(highlight=False)

column_1 = rich_table.Column(
"Tokenizer (type)",
"Layer (type)",
justify="left",
width=int(0.5 * line_length),
width=int(0.6 * line_length),
)
column_2 = rich_table.Column(
"Vocab #",
"Config",
justify="right",
width=int(0.5 * line_length),
width=int(0.4 * line_length),
)
table = rich_table.Table(
column_1, column_2, width=line_length, show_lines=True
)

def add_layer(layer, info):
    """Add one preprocessing layer to the table as a "name (Class) | info" row."""
    escaped_name = markup.escape(layer.name)
    escaped_class = highlight_symbol(markup.escape(layer.__class__.__name__))
    row_label = f"{escaped_name} ({escaped_class})"
    table.add_row(row_label, info)

tokenizer = self.preprocessor.tokenizer
tokenizer_name = markup.escape(tokenizer.name)
tokenizer_class = highlight_symbol(
markup.escape(tokenizer.__class__.__name__)
)
table.add_row(
f"{tokenizer_name} ({tokenizer_class})",
highlight_number(f"{tokenizer.vocabulary_size():,}"),
)
if tokenizer:
info = "Vocab size: "
info += highlight_number(tokenizer.vocabulary_size())
add_layer(tokenizer, info)
image_converter = self.preprocessor.image_converter
if image_converter:
info = "Image size: "
info += highlight_shape(image_converter.image_size())
add_layer(image_converter, info)
audio_converter = self.preprocessor.audio_converter
if audio_converter:
info = "Audio shape: "
info += highlight_shape(audio_converter.audio_shape())
add_layer(audio_converter, info)

# Print the to the console.
preprocessor_name = markup.escape(self.preprocessor.name)
Expand Down
4 changes: 4 additions & 0 deletions keras_nlp/src/models/whisper/whisper_audio_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def __init__(
# `(num_fft_bins // 2 + 1, num_mels).`
self.mel_filters = self._get_mel_filters()

def audio_shape(self):
    """Preprocessed shape of a single audio sample.

    Built from the converter's configured max audio length and
    number of mel bins.
    """
    frames = self.max_audio_length
    mel_bins = self.num_mels
    return (frames, mel_bins)

def _get_mel_filters(self):
"""
Adapted from Hugging Face
Expand Down

0 comments on commit d0bb822

Please sign in to comment.