diff --git a/html2docx/image.py b/html2docx/image.py index 145808d..a8465aa 100644 --- a/html2docx/image.py +++ b/html2docx/image.py @@ -1,10 +1,12 @@ +import base64 +import binascii import http import io import pathlib import time import urllib.error import urllib.request -from typing import Dict, Optional +from typing import Dict, Optional, cast from docx.image.exceptions import UnrecognizedImageError from docx.image.image import Image @@ -17,11 +19,29 @@ MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10 MiB +RFC_2397_BASE64 = ";base64" -def load_image(src: str) -> io.BytesIO: + +def make_image(data: Optional[bytes]) -> io.BytesIO: image_buffer = None + if data: + image_buffer = io.BytesIO(data) + try: + Image.from_blob(image_buffer.getbuffer()) + except UnrecognizedImageError: + image_buffer = None + + if not image_buffer: + broken_img_path = pathlib.Path(__file__).parent / "image-broken.png" + image_buffer = io.BytesIO(broken_img_path.read_bytes()) + + return image_buffer + + +def load_external_image(src: str) -> Optional[bytes]: + data = None retry = 3 - while retry and not image_buffer: + while retry and not data: try: with urllib.request.urlopen(src) as response: size = response.getheader("Content-Length") @@ -30,7 +50,7 @@ def load_image(src: str) -> io.BytesIO: # Read up to MAX_IMAGE_SIZE when response does not contain # the Content-Length header. The extra byte avoids an extra read to # check whether the EOF was reached. - data = response.read(MAX_IMAGE_SIZE + 1) + data = cast(bytes, response.read(MAX_IMAGE_SIZE + 1)) except (ValueError, http.client.HTTPException, urllib.error.HTTPError): # ValueError: Invalid URL or non-integer Content-Length. # HTTPException: Server does not speak HTTP properly. @@ -43,19 +63,29 @@ def load_image(src: str) -> io.BytesIO: time.sleep(1) else: if len(data) <= MAX_IMAGE_SIZE: - image_buffer = io.BytesIO(data) + return data + return None + - if image_buffer: +def load_inline_image(src: str) -> Optional[bytes]: + image_data = None + header_data = src.split(RFC_2397_BASE64 + ",", maxsplit=1) + if len(header_data) == 2: + data = header_data[1] try: - Image.from_blob(image_buffer.getbuffer()) - except UnrecognizedImageError: - image_buffer = None + image_data = base64.b64decode(data, validate=True) + except (binascii.Error, ValueError): + # binascii.Error: Character outside of base64 set. + # ValueError: Character outside of ASCII. + pass + return image_data - if not image_buffer: - broken_img_path = pathlib.Path(__file__).parent / "image-broken.png" - image_buffer = io.BytesIO(broken_img_path.read_bytes()) - return image_buffer +def load_image(src: str) -> io.BytesIO: + image_bytes = ( + load_inline_image(src) if src.startswith("data:") else load_external_image(src) + ) + return make_image(image_bytes) def image_size( diff --git a/tests/test_image_size.py b/tests/test_image_size.py index 54c9544..6f663fd 100644 --- a/tests/test_image_size.py +++ b/tests/test_image_size.py @@ -1,16 +1,13 @@ -from io import BytesIO from math import ceil from docx.shared import Inches -from PIL import Image from html2docx.image import USABLE_HEIGHT, USABLE_WIDTH, image_size -from .utils import PROJECT_DIR +from .utils import DPI, PROJECT_DIR, generate_image broken_image = PROJECT_DIR / "html2docx" / "image-broken.png" broken_image_bytes = broken_image.read_bytes() -DPI = 72 def inches_to_px(inches: int, dpi: int = DPI) -> int: @@ -21,13 +18,6 @@ def px_to_inches(px: int, dpi: int = DPI) -> int: return ceil(px * Inches(1) / dpi) -def generate_image(width: int, height: int, dpi=(DPI, DPI)) -> BytesIO: - data = BytesIO() - with Image.new("L", (width, height)) as image: - image.save(data, format="png", dpi=dpi) - return data - - def test_one_px(): image = generate_image(width=1, height=1) size = image_size(image, 1, 1) diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 7212cab..8782c54 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -1,10 +1,11 @@ +import base64 import urllib.error import urllib.request from unittest import mock from html2docx.image import load_image -from .utils import PROJECT_DIR, TEST_DIR +from .utils import PROJECT_DIR, TEST_DIR, generate_image broken_image = PROJECT_DIR / "html2docx" / "image-broken.png" broken_image_bytes = broken_image.read_bytes() @@ -58,3 +59,47 @@ def test_bad_content_length(bad_content_length_server): image_data = load_image(bad_content_length_server.base_url) assert image_data.getbuffer() == broken_image_bytes assert bad_content_length_server.httpd.request_count == 1 + + +def test_inline_base64(): + image = generate_image(width=1, height=1) + image_b64 = base64.b64encode(image.getbuffer()).decode() + src = f"data:image/png;base64,{image_b64}" + image_data = load_image(src) + assert image_data.getbuffer() == image.getbuffer() + + +def test_inline_non_ascii(): + src = "data:image/png;base64,🦝" + image_data = load_image(src) + assert image_data.getbuffer() == broken_image_bytes + + +def test_inline_non_base64(): + src = "://example.org/" + image_data = load_image(src) + assert image_data.getbuffer() == broken_image_bytes + + +def test_inline_unknown_encoding(): + src = "data:image/png;unknown,foobar" + image_data = load_image(src) + assert image_data.getbuffer() == broken_image_bytes + + +def test_inline_base64_marker_in_data(): + src = "data:text/plain,this is not ;base64, encoded." + image_data = load_image(src) + assert image_data.getbuffer() == broken_image_bytes + + +def test_inline_missing_comma(): + src = "data:image/png;base64https://example.org/" + image_data = load_image(src) + assert image_data.getbuffer() == broken_image_bytes + + +def test_unknown_scheme(): + src = "" + image_data = load_image(src) + assert image_data.getbuffer() == broken_image_bytes diff --git a/tests/utils.py b/tests/utils.py index 916f68e..13b9df8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,15 @@ import pathlib +from io import BytesIO + +from PIL import Image TEST_DIR = pathlib.Path(__file__).parent.resolve(strict=True) PROJECT_DIR = TEST_DIR.parent +DPI = 72 + + +def generate_image(width: int, height: int, dpi=(DPI, DPI)) -> BytesIO: + data = BytesIO() + with Image.new("L", (width, height)) as image: + image.save(data, format="png", dpi=dpi) + return data