Commit 785940e

Run black on components in pre-commit (#511)
This has been bothering me for a while, so I wanted to take advantage of
the lack of open PRs to finally activate black for the components :)
RobbeSneyders authored Oct 11, 2023
1 parent 1c770a7 commit 785940e
Showing 32 changed files with 393 additions and 256 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -40,6 +40,7 @@ repos:
       name: black
       files: |
         (?x)^(
+        components/.*|
         src/.*|
         examples/.*|
         tests/.*|
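Note: with components/.* added to the pattern above, the hook can be checked locally (assuming a standard pre-commit setup, where "black" is the hook id of the entry above) with "pre-commit run black --all-files", or by running black directly against the components directory.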
47 changes: 27 additions & 20 deletions components/caption_images/src/main.py
@@ -13,14 +13,14 @@

 logger = logging.getLogger(__name__)

-os.environ['TORCH_CUDNN_V8_API_DISABLED'] = "1"
+os.environ["TORCH_CUDNN_V8_API_DISABLED"] = "1"


 def process_image_batch(
-        images: np.ndarray,
-        *,
-        processor: BlipProcessor,
-        device: str,
+    images: np.ndarray,
+    *,
+    processor: BlipProcessor,
+    device: str,
 ) -> t.List[torch.Tensor]:
     """
     Process image in batches to a list of tensors.
@@ -49,16 +49,19 @@ def transform(img: Image) -> BatchEncoding:

 @torch.no_grad()
 def caption_image_batch(
-        image_batch: t.List[torch.Tensor],
-        *,
-        model: BlipForConditionalGeneration,
-        processor: BlipProcessor,
-        max_new_tokens: int,
-        index: pd.Series,
+    image_batch: t.List[torch.Tensor],
+    *,
+    model: BlipForConditionalGeneration,
+    processor: BlipProcessor,
+    max_new_tokens: int,
+    index: pd.Series,
 ) -> pd.Series:
     """Caption a batch of images."""
     input_batch = torch.cat(image_batch)
-    output_batch = model.generate(pixel_values=input_batch, max_new_tokens=max_new_tokens)
+    output_batch = model.generate(
+        pixel_values=input_batch,
+        max_new_tokens=max_new_tokens,
+    )
     captions_batch = processor.batch_decode(output_batch, skip_special_tokens=True)

     return pd.Series(captions_batch, index=index)
@@ -68,28 +71,32 @@ class CaptionImagesComponent(PandasTransformComponent):
     """Component that captions images using a model from the Hugging Face hub."""

     def __init__(
-            self,
-            *_,
-            model_id: str,
-            batch_size: int,
-            max_new_tokens: int,
+        self,
+        *_,
+        model_id: str,
+        batch_size: int,
+        max_new_tokens: int,
     ):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Device: {self.device}")

         logger.info("Initialize model '%s'", model_id)
         self.processor = BlipProcessor.from_pretrained(model_id)
-        self.model = BlipForConditionalGeneration.from_pretrained(model_id).to(self.device)
+        self.model = BlipForConditionalGeneration.from_pretrained(model_id).to(
+            self.device,
+        )

         self.batch_size = batch_size
         self.max_new_tokens = max_new_tokens

     def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
-
         images = dataframe["images"]["data"]

         results: t.List[pd.Series] = []
-        for batch in np.split(images, np.arange(self.batch_size, len(images), self.batch_size)):
+        for batch in np.split(
+            images,
+            np.arange(self.batch_size, len(images), self.batch_size),
+        ):
             if not batch.empty:
                 image_tensors = process_image_batch(
                     batch,
15 changes: 8 additions & 7 deletions components/caption_images/tests/test_caption_images.py
@@ -10,17 +10,18 @@ def test_image_caption_component():
         "https://cdn.pixabay.com/photo/2023/07/19/18/56/japanese-beetle-8137606_1280.png",
     ]
     input_dataframe = pd.DataFrame(
-        {"images": {"data": [requests.get(url).content for url in image_urls]}})
+        {"images": {"data": [requests.get(url).content for url in image_urls]}},
+    )

     expected_output_dataframe = pd.DataFrame(
-            data={("captions", "text"): {0: "a motorcycle", 1: "a beetle"}},
-    )
+        data={("captions", "text"): {0: "a motorcycle", 1: "a beetle"}},
+    )

     component = CaptionImagesComponent(
-            model_id="Salesforce/blip-image-captioning-base",
-            batch_size=4,
-            max_new_tokens=2,
-    )
+        model_id="Salesforce/blip-image-captioning-base",
+        batch_size=4,
+        max_new_tokens=2,
+    )

     output_dataframe = component.transform(input_dataframe)
68 changes: 45 additions & 23 deletions components/download_images/src/main.py
@@ -15,23 +15,24 @@

 logger = logging.getLogger(__name__)

-dask.config.set(scheduler='processes')
+dask.config.set(scheduler="processes")


 class DownloadImagesComponent(PandasTransformComponent):
     """Component that downloads images based on URLs."""

-    def __init__(self,
-                 *_,
-                 timeout: int,
-                 retries: int,
-                 n_connections: int,
-                 image_size: int,
-                 resize_mode: str,
-                 resize_only_if_bigger: bool,
-                 min_image_size: int,
-                 max_aspect_ratio: float,
-                 ):
+    def __init__(
+        self,
+        *_,
+        timeout: int,
+        retries: int,
+        n_connections: int,
+        image_size: int,
+        resize_mode: str,
+        resize_only_if_bigger: bool,
+        min_image_size: int,
+        max_aspect_ratio: float,
+    ):
         """Component that downloads images from a list of URLs and executes filtering and resizing.

         Args:
@@ -60,29 +61,46 @@ def __init__(self,
             max_aspect_ratio=max_aspect_ratio,
         )

-    async def download_image(self, url: str, *, semaphore: asyncio.Semaphore) -> t.Optional[bytes]:
+    async def download_image(
+        self,
+        url: str,
+        *,
+        semaphore: asyncio.Semaphore,
+    ) -> t.Optional[bytes]:
         url = url.strip()

         user_agent_string = (
-            "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
+            "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) "
+            "Gecko/20100101 Firefox/72.0 "
+            "(compatible; +https://github.com/ml6team/fondant)"
         )
-        user_agent_string += " (compatible; +https://github.com/ml6team/fondant)"

         transport = httpx.AsyncHTTPTransport(retries=self.retries)
-        async with httpx.AsyncClient(transport=transport, follow_redirects=True) as client:
+        async with httpx.AsyncClient(
+            transport=transport,
+            follow_redirects=True,
+        ) as client:
             try:
                 async with semaphore:
-                    response = await client.get(url, timeout=self.timeout,
-                                                headers={"User-Agent": user_agent_string})
+                    response = await client.get(
+                        url,
+                        timeout=self.timeout,
+                        headers={"User-Agent": user_agent_string},
+                    )
                     image_stream = response.content
             except Exception as e:
                 logger.warning(f"Skipping {url}: {repr(e)}")
                 image_stream = None

         return image_stream

-    async def download_and_resize_image(self, id_: str, url: str, *, semaphore: asyncio.Semaphore) \
-            -> t.Tuple[str, t.Optional[bytes], t.Optional[int], t.Optional[int]]:
+    async def download_and_resize_image(
+        self,
+        id_: str,
+        url: str,
+        *,
+        semaphore: asyncio.Semaphore,
+    ) -> t.Tuple[str, t.Optional[bytes], t.Optional[int], t.Optional[int]]:
         image_stream = await self.download_image(url, semaphore=semaphore)
         if image_stream is not None:
             image_stream, width, height = self.resizer(io.BytesIO(image_stream))
@@ -99,8 +117,10 @@ async def download_dataframe() -> None:
             semaphore = asyncio.Semaphore(self.n_connections)

             images = await asyncio.gather(
-                *[self.download_and_resize_image(id_, url, semaphore=semaphore)
-                  for id_, url in zip(dataframe.index, dataframe["images"]["url"])],
+                *[
+                    self.download_and_resize_image(id_, url, semaphore=semaphore)
+                    for id_, url in zip(dataframe.index, dataframe["images"]["url"])
+                ],
             )
             results.extend(images)
@@ -114,6 +134,8 @@

         results_df = results_df.dropna()
         results_df = results_df.set_index("id", drop=True)
-        results_df.columns = pd.MultiIndex.from_product([["images"], results_df.columns])
+        results_df.columns = pd.MultiIndex.from_product(
+            [["images"], results_df.columns],
+        )

         return results_df
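A short standalone sketch (hypothetical data; not part of the commit) of what the MultiIndex assignment above does — it nests the existing columns under an "images" prefix so the result matches the ("images", "data")-style columns used by the other components:

import pandas as pd

results_df = pd.DataFrame({"data": [b"..."], "width": [256], "height": [256]})
results_df.columns = pd.MultiIndex.from_product([["images"], results_df.columns])
print(results_df.columns.tolist())
# [('images', 'data'), ('images', 'width'), ('images', 'height')]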
4 changes: 2 additions & 2 deletions components/download_images/tests/test_component.py
@@ -25,9 +25,9 @@ def test_transform(respx_mock):

     # Mock httpx to prevent network calls and return test images
     image_dir = "tests/images"
     images = [
-        open(os.path.join(image_dir, image), "rb").read() for image in os.listdir(image_dir)  # noqa
+        open(os.path.join(image_dir, image), "rb").read()  # noqa
+        for image in os.listdir(image_dir)
     ]
     for url, image in zip(urls, images):
         respx_mock.get(url).mock(return_value=Response(200, content=image))
36 changes: 20 additions & 16 deletions components/embed_images/src/main.py
@@ -13,14 +13,14 @@

 logger = logging.getLogger(__name__)

-os.environ['TORCH_CUDNN_V8_API_DISABLED'] = "1"
+os.environ["TORCH_CUDNN_V8_API_DISABLED"] = "1"


 def process_image_batch(
-        images: np.ndarray,
-        *,
-        processor: CLIPProcessor,
-        device: str,
+    images: np.ndarray,
+    *,
+    processor: CLIPProcessor,
+    device: str,
 ) -> t.List[torch.Tensor]:
     """
     Process image in batches to a list of tensors.
@@ -51,10 +51,10 @@ def transform(img: Image) -> BatchEncoding:

 @torch.no_grad()
 def embed_image_batch(
-        image_batch: t.List[torch.Tensor],
-        *,
-        model: CLIPVisionModelWithProjection,
-        index: pd.Series,
+    image_batch: t.List[torch.Tensor],
+    *,
+    model: CLIPVisionModelWithProjection,
+    index: pd.Series,
 ) -> pd.Series:
     """Embed a batch of images."""
     input_batch = torch.cat(image_batch)
@@ -67,10 +67,10 @@ class EmbedImagesComponent(PandasTransformComponent):
     """Component that embeds images using a CLIP model from the Hugging Face hub."""

     def __init__(
-            self,
-            *_,
-            model_id: str,
-            batch_size: int,
+        self,
+        *_,
+        model_id: str,
+        batch_size: int,
     ):
         """
         Args:
@@ -82,7 +82,9 @@ def __init__(

         logger.info("Initialize model '%s'", model_id)
         self.processor = CLIPProcessor.from_pretrained(model_id)
-        self.model = CLIPVisionModelWithProjection.from_pretrained(model_id).to(self.device)
+        self.model = CLIPVisionModelWithProjection.from_pretrained(model_id).to(
+            self.device,
+        )
         logger.info("Model initialized")

         self.batch_size = batch_size
@@ -91,7 +93,10 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
         images = dataframe["images"]["data"]

         results: t.List[pd.Series] = []
-        for batch in np.split(images, np.arange(self.batch_size, len(images), self.batch_size)):
+        for batch in np.split(
+            images,
+            np.arange(self.batch_size, len(images), self.batch_size),
+        ):
             if not batch.empty:
                 image_tensors = process_image_batch(
                     batch,
@@ -105,5 +110,4 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
             ).T
             results.append(embeddings)

-
         return pd.concat(results).to_frame(name=("embeddings", "data"))
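For reference, a minimal standalone demo of the batching idiom used in the transform methods above (hypothetical values; the components apply the same call to a pandas Series and skip chunks where batch.empty is True):

import numpy as np

images = np.arange(10)  # stand-in for the image column
batch_size = 4
# Splitting at indices batch_size, 2 * batch_size, ... yields chunks of at
# most batch_size elements; the trailing chunk may be shorter.
batches = np.split(images, np.arange(batch_size, len(images), batch_size))
print([len(batch) for batch in batches])  # [4, 4, 2]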
19 changes: 11 additions & 8 deletions components/embedding_based_laion_retrieval/src/main.py
@@ -16,11 +16,11 @@ class LAIONRetrievalComponent(PandasTransformComponent):
     """Component that retrieves image URLs from LAION-5B based on a set of CLIP embeddings."""

     def __init__(
-            self,
-            *,
-            num_images: int,
-            aesthetic_score: int,
-            aesthetic_weight: float,
+        self,
+        *,
+        num_images: int,
+        aesthetic_score: int,
+        aesthetic_weight: float,
     ) -> None:
         """
@@ -41,8 +41,8 @@ def __init__(
         )

     def transform(
-            self,
-            dataframe: pd.DataFrame,
+        self,
+        dataframe: pd.DataFrame,
     ) -> pd.DataFrame:
         """Asynchronously retrieve image URLs and ids based on prompts in the provided dataframe."""
         results: t.List[t.Tuple[str]] = []
@@ -53,7 +53,10 @@ async def async_query():
            futures = [
                loop.run_in_executor(
                    executor,
-                    functools.partial(self.client.query, embedding_input=embedding.tolist()),
+                    functools.partial(
+                        self.client.query,
+                        embedding_input=embedding.tolist(),
+                    ),
                )
                for embedding in dataframe["embeddings"]["data"]
            ]
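A minimal standalone sketch of the pattern in this hunk (hypothetical names; not part of the commit): run_in_executor only forwards positional arguments, so functools.partial binds the keyword argument up front before the blocking call is fanned out to the executor.

import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor

def blocking_query(*, embedding_input):  # stand-in for self.client.query
    return len(embedding_input)

async def main() -> None:
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as executor:
        futures = [
            loop.run_in_executor(
                executor,
                functools.partial(blocking_query, embedding_input=[0.0] * n),
            )
            for n in (2, 3, 5)
        ]
        print(await asyncio.gather(*futures))  # [2, 3, 5]

asyncio.run(main())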
20 changes: 15 additions & 5 deletions components/filter_comments/src/main.py
@@ -19,14 +19,24 @@ class FilterCommentsComponent(PandasTransformComponent):
         max_comments_ratio: The maximum code to comment ratio
     """

-    def __init__(self, *args, min_comments_ratio: float, max_comments_ratio: float) -> None:
+    def __init__(
+        self,
+        *args,
+        min_comments_ratio: float,
+        max_comments_ratio: float,
+    ) -> None:
         self.min_comments_ratio = min_comments_ratio
         self.max_comments_ratio = max_comments_ratio

     def transform(
-            self,
-            dataframe: pd.DataFrame,
+        self,
+        dataframe: pd.DataFrame,
     ) -> pd.DataFrame:
-        comment_to_code_ratio = dataframe["code"]["content"].apply(get_comments_to_code_ratio)
-        mask = comment_to_code_ratio.between(self.min_comments_ratio, self.max_comments_ratio)
+        comment_to_code_ratio = dataframe["code"]["content"].apply(
+            get_comments_to_code_ratio,
+        )
+        mask = comment_to_code_ratio.between(
+            self.min_comments_ratio,
+            self.max_comments_ratio,
+        )
         return dataframe[mask]