Skip to content

Commit

Permalink
add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Aug 5, 2024
1 parent f006764 commit 5b6c620
Showing 1 changed file with 53 additions and 3 deletions.
56 changes: 53 additions & 3 deletions samples/python/multinomial_causal_lm/multinomial_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,78 @@


class IterableStreamer(openvino_genai.StreamerBase):
"""
A custom streamer class for handling token streaming and detokenization with buffering.
Attributes:
tokenizer (Tokenizer): The tokenizer used for encoding and decoding tokens.
tokens_cache (list): A buffer to accumulate tokens for detokenization.
text_queue (Queue): A synchronized queue for storing decoded text chunks.
print_len (int): The length of the printed text to manage incremental decoding.
"""

def __init__(self, tokenizer):
    """
    Construct an IterableStreamer around the given tokenizer.

    Args:
        tokenizer (Tokenizer): Tokenizer used to encode and decode tokens.
    """
    super().__init__()
    # Buffer of raw token ids awaiting detokenization.
    self.tokens_cache = []
    # Number of characters already emitted, used for incremental decoding.
    self.print_len = 0
    # Thread-safe channel carrying decoded text chunks to the consumer.
    self.text_queue = queue.Queue()
    self.tokenizer = tokenizer

def __iter__(self):
    """Make the streamer its own iterator so it can drive a ``for`` loop directly."""
    return self

def __next__(self):
    """
    Return the next decoded text chunk from the queue.

    Blocks until a chunk is available.

    Returns:
        str: The next decoded text chunk.

    Raises:
        StopIteration: When ``None`` is dequeued, signalling the end of the stream.
    """
    # Exactly one get() per __next__: the duplicated pre-docstring call in the
    # diff-rendered source would silently discard every other chunk.
    value = self.text_queue.get()  # get() blocks until an item is available.
    if value is None:
        raise StopIteration
    return value

def get_stop_flag(self):
    """
    Report whether the generation process should be stopped.

    Returns:
        bool: Always ``False`` — this streamer never requests cancellation.
    """
    stop_requested = False
    return stop_requested

def put_word(self, word: str):
    """
    Enqueue a decoded text fragment for the consuming iterator.

    Args:
        word (str): Decoded text to hand over to the iterating thread.
    """
    self.text_queue.put(word)

def put(self, token_id: int) -> bool:
"""
Processes a token and manages the decoding buffer. Adds decoded text to the queue.
Args:
token_id (int): The token_id to process.
Returns:
bool: True if generation should be stopped, False otherwise.
"""
self.tokens_cache.append(token_id)
text = self.tokenizer.decode(self.tokens_cache)

Expand Down Expand Up @@ -60,7 +108,9 @@ def put(self, token_id: int) -> bool:
return False # False means continue generation

def end(self):
# Flush residual tokens from the buffer.
"""
Flushes residual tokens from the buffer and puts a None value in the queue to signal the end.
"""
text = self.tokenizer.decode(self.tokens_cache)
if len(text) > self.print_len:
word = text[self.print_len:]
Expand Down

0 comments on commit 5b6c620

Please sign in to comment.