Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Large scale controlnet #260

Merged
merged 14 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ args:
aesthetic_weight:
description: Weight of the aesthetic embedding when added to the query, between 0 and 1
type: float
url:
description: The url of the backend clip retrieval service, defaults to the public service
type: str
default: https://knn.laion.ai/knn-service
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't seem like the main.py script has a default

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The defaults defined here translate internally to defaults defined in the argparser since kfp always requires a given provided argument if specified and cannot be empty.

parser.add_argument("--url", default="https://knn.laion.ai/knn-service")

The values defined in the argument parser generally take precedence over the default values defined in the main.py file so adding them there can be a bit misleading (e.g. if the user attempts to change them, the default values won't be used).

6 changes: 4 additions & 2 deletions components/prompt_based_laion_retrieval/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def setup(
num_images: int,
aesthetic_score: int,
aesthetic_weight: float,
url: str,
) -> None:
"""

Expand All @@ -30,10 +31,11 @@ def setup(
between 0 and 9.
aesthetic_weight: weight of the aesthetic embedding to add to the query,
between 0 and 1.
url: The url of the backend clip retrieval service, defaults to the public clip url.
"""
self.client = ClipClient(
url="https://knn.laion.ai/knn-service",
indice_name="laion5B-L-14",
url=url,
indice_name="laion5B", #TODO:revert back to laion5b-L-14 after backend correction
num_images=num_images,
aesthetic_score=aesthetic_score,
aesthetic_weight=aesthetic_weight,
Expand Down
13 changes: 9 additions & 4 deletions components/segment_images/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
logger = logging.getLogger(__name__)


def convert_to_rgb(seg: np.array):
def convert_to_rgb(seg: np.array) -> bytes:
"""
Converts a 2D segmentation to a RGB one which makes it possible to visualize it.

Args:
seg: 2D segmentation map as a NumPy array.

Returns:
color_seg: 3D segmentation map contain RGB values for each pixel.
color_seg: the RGB segmentation map as a binary string
"""
color_seg = np.zeros(
(seg.shape[0], seg.shape[1], 3), dtype=np.uint8,
Expand All @@ -32,9 +32,13 @@ def convert_to_rgb(seg: np.array):
for label, color in enumerate(palette):
color_seg[seg == label, :] = color

color_seg = color_seg.astype(np.uint8).tobytes()
color_seg = color_seg.astype(np.uint8)
image = Image.fromarray(color_seg).convert('RGB')

return color_seg
crop_bytes = io.BytesIO()
image.save(crop_bytes, format="JPEG")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this actually save the image to disk?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this just saves it to crop_bytes which is a BytesIO object (in-memory buffer to store the image in binary format)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok makes sense, thanks!


return crop_bytes.getvalue()


def process_image(image: bytes, *, processor: SegformerImageProcessor, device: str) -> torch.Tensor:
Expand All @@ -46,6 +50,7 @@ def process_image(image: bytes, *, processor: SegformerImageProcessor, device: s
processor: The processor object for transforming the image.
device: The device to move the transformed image to.
"""

def load(img: bytes) -> Image:
"""Load the bytestring as an image."""
bytes_ = io.BytesIO(img)
Expand Down
3 changes: 2 additions & 1 deletion components/write_to_hf_hub/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

# Define the schema for the struct using PyArrow
import huggingface_hub
from datasets.features.features import generate_from_arrow_type
from PIL import Image

from fondant.component import WriteComponent
Expand Down Expand Up @@ -71,7 +72,7 @@ def write(
if image_column_names and column_name in image_column_names:
schema_dict[column_name] = datasets.Image()
else:
schema_dict[column_name] = datasets.Value(str(field.type.value))
schema_dict[column_name] = generate_from_arrow_type(field.type.value)

schema = datasets.Features(schema_dict).arrow_schema
dataframe = dataframe[write_columns]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ consumes:
segmentations:
fields:
data:
type: array
items:
type: binary
type: binary

args:
hf_token:
Expand Down
9 changes: 6 additions & 3 deletions examples/pipelines/controlnet-interior-design/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
)
laion_retrieval_op = ComponentOp.from_registry(
name="prompt_based_laion_retrieval",
arguments={"num_images": 2, "aesthetic_score": 9, "aesthetic_weight": 0.5},
arguments={
"num_images": 2,
"aesthetic_score": 9,
"aesthetic_weight": 0.5,
"url": None,
},
)
download_images_op = ComponentOp.from_registry(
name="download_images",
Expand Down Expand Up @@ -63,8 +68,6 @@
"hf_token": "hf_token",
"image_column_names": ["images_data"],
},
number_of_gpus=1,
node_pool_name="model-inference-pool",
)

pipeline = Pipeline(pipeline_name=pipeline_name, base_path=PipelineConfigs.BASE_PATH)
Expand Down