diff --git a/html2docx/image.py b/html2docx/image.py
index 145808d..a8465aa 100644
--- a/html2docx/image.py
+++ b/html2docx/image.py
@@ -1,10 +1,12 @@
+import base64
+import binascii
import http
import io
import pathlib
import time
import urllib.error
import urllib.request
-from typing import Dict, Optional
+from typing import Dict, Optional, cast
from docx.image.exceptions import UnrecognizedImageError
from docx.image.image import Image
@@ -17,11 +19,29 @@
MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10 MiB
+RFC_2397_BASE64 = ";base64"
-def load_image(src: str) -> io.BytesIO:
+
+def make_image(data: Optional[bytes]) -> io.BytesIO:
+    """Return an in-memory image stream, falling back to a placeholder.
+
+    ``data`` is kept only when python-docx's ``Image.from_blob`` recognizes it.
+    Empty, missing or unrecognized data yields the bundled
+    ``image-broken.png`` stored next to this module.
+    """
     image_buffer = None
+    if data:
+        candidate = io.BytesIO(data)
+        try:
+            Image.from_blob(candidate.getbuffer())
+        except UnrecognizedImageError:
+            # Not an image format python-docx can parse; use the placeholder.
+            pass
+        else:
+            image_buffer = candidate
+
+    if image_buffer is None:
+        placeholder = pathlib.Path(__file__).parent / "image-broken.png"
+        image_buffer = io.BytesIO(placeholder.read_bytes())
+
+    return image_buffer
+
+
+def load_external_image(src: str) -> Optional[bytes]:
+ data = None
retry = 3
- while retry and not image_buffer:
+ while retry and not data:
try:
with urllib.request.urlopen(src) as response:
size = response.getheader("Content-Length")
@@ -30,7 +50,7 @@ def load_image(src: str) -> io.BytesIO:
# Read up to MAX_IMAGE_SIZE when response does not contain
# the Content-Length header. The extra byte avoids an extra read to
# check whether the EOF was reached.
- data = response.read(MAX_IMAGE_SIZE + 1)
+ data = cast(bytes, response.read(MAX_IMAGE_SIZE + 1))
except (ValueError, http.client.HTTPException, urllib.error.HTTPError):
# ValueError: Invalid URL or non-integer Content-Length.
# HTTPException: Server does not speak HTTP properly.
@@ -43,19 +63,29 @@ def load_image(src: str) -> io.BytesIO:
time.sleep(1)
else:
if len(data) <= MAX_IMAGE_SIZE:
- image_buffer = io.BytesIO(data)
+ return data
+ return None
+
- if image_buffer:
+def load_inline_image(src: str) -> Optional[bytes]:
+    """Decode the payload of a base64 ``data:`` URI.
+
+    The URI is split at its first comma, which separates the header (media
+    type and parameters) from the payload. The payload is decoded only when
+    the header ends with the ``;base64`` marker, so a marker that merely
+    appears inside the payload of a non-base64 URI is never mistaken for it.
+
+    Returns the decoded bytes, or None when the URI has no comma, is not
+    flagged as base64, or carries an invalid base64 payload.
+    """
+    image_data = None
+    header_data = src.split(",", maxsplit=1)
+    if len(header_data) == 2 and header_data[0].endswith(RFC_2397_BASE64):
+        data = header_data[1]
         try:
-            Image.from_blob(image_buffer.getbuffer())
-        except UnrecognizedImageError:
-            image_buffer = None
+            image_data = base64.b64decode(data, validate=True)
+        except (binascii.Error, ValueError):
+            # binascii.Error: Character outside of base64 set.
+            # ValueError: Character outside of ASCII.
+            pass
+    return image_data
- if not image_buffer:
- broken_img_path = pathlib.Path(__file__).parent / "image-broken.png"
- image_buffer = io.BytesIO(broken_img_path.read_bytes())
- return image_buffer
+def load_image(src: str) -> io.BytesIO:
+    """Load the image referenced by ``src`` into an in-memory stream.
+
+    A ``src`` starting with ``data:`` is decoded inline; anything else is
+    handed to ``load_external_image``. The raw bytes (or None) then go
+    through ``make_image``, which always returns a usable stream.
+    """
+    if src.startswith("data:"):
+        raw = load_inline_image(src)
+    else:
+        raw = load_external_image(src)
+    return make_image(raw)
def image_size(
diff --git a/tests/test_image_size.py b/tests/test_image_size.py
index 54c9544..6f663fd 100644
--- a/tests/test_image_size.py
+++ b/tests/test_image_size.py
@@ -1,16 +1,13 @@
-from io import BytesIO
from math import ceil
from docx.shared import Inches
-from PIL import Image
from html2docx.image import USABLE_HEIGHT, USABLE_WIDTH, image_size
-from .utils import PROJECT_DIR
+from .utils import DPI, PROJECT_DIR, generate_image
broken_image = PROJECT_DIR / "html2docx" / "image-broken.png"
broken_image_bytes = broken_image.read_bytes()
-DPI = 72
def inches_to_px(inches: int, dpi: int = DPI) -> int:
@@ -21,13 +18,6 @@ def px_to_inches(px: int, dpi: int = DPI) -> int:
return ceil(px * Inches(1) / dpi)
-def generate_image(width: int, height: int, dpi=(DPI, DPI)) -> BytesIO:
- data = BytesIO()
- with Image.new("L", (width, height)) as image:
- image.save(data, format="png", dpi=dpi)
- return data
-
-
def test_one_px():
image = generate_image(width=1, height=1)
size = image_size(image, 1, 1)
diff --git a/tests/test_load_image.py b/tests/test_load_image.py
index 7212cab..8782c54 100644
--- a/tests/test_load_image.py
+++ b/tests/test_load_image.py
@@ -1,10 +1,11 @@
+import base64
import urllib.error
import urllib.request
from unittest import mock
from html2docx.image import load_image
-from .utils import PROJECT_DIR, TEST_DIR
+from .utils import PROJECT_DIR, TEST_DIR, generate_image
broken_image = PROJECT_DIR / "html2docx" / "image-broken.png"
broken_image_bytes = broken_image.read_bytes()
@@ -58,3 +59,47 @@ def test_bad_content_length(bad_content_length_server):
image_data = load_image(bad_content_length_server.base_url)
assert image_data.getbuffer() == broken_image_bytes
assert bad_content_length_server.httpd.request_count == 1
+
+
+def test_inline_base64():
+    """A base64 PNG data URI is decoded back to the original bytes."""
+    png = generate_image(width=1, height=1)
+    image_b64 = base64.b64encode(png.getbuffer()).decode()
+    loaded = load_image(f"data:image/png;base64,{image_b64}")
+    assert loaded.getbuffer() == png.getbuffer()
+
+
+def test_inline_non_ascii():
+    """A payload with a non-ASCII character yields the broken image."""
+    loaded = load_image("data:image/png;base64,🦝")
+    assert loaded.getbuffer() == broken_image_bytes
+
+
+def test_inline_non_base64():
+    """An inline payload that is not valid base64 yields the broken image."""
+    # The "data:" prefix keeps this on the inline code path; the ":" and "."
+    # in the payload are outside the base64 alphabet, so decoding must fail.
+    src = "data:image/png;base64,://example.org/"
+    image_data = load_image(src)
+    assert image_data.getbuffer() == broken_image_bytes
+
+
+def test_inline_unknown_encoding():
+    """A data URI without the ``;base64`` marker yields the broken image."""
+    loaded = load_image("data:image/png;unknown,foobar")
+    assert loaded.getbuffer() == broken_image_bytes
+
+
+def test_inline_base64_marker_in_data():
+    """A ``;base64,`` sequence inside the payload must not produce an image."""
+    loaded = load_image("data:text/plain,this is not ;base64, encoded.")
+    assert loaded.getbuffer() == broken_image_bytes
+
+
+def test_inline_missing_comma():
+    """A ``;base64`` marker not followed by a comma yields the broken image."""
+    loaded = load_image("data:image/png;base64https://example.org/")
+    assert loaded.getbuffer() == broken_image_bytes
+
+
+def test_unknown_scheme():
+    """An empty ``src`` falls back to the broken image."""
+    loaded = load_image("")
+    assert loaded.getbuffer() == broken_image_bytes
diff --git a/tests/utils.py b/tests/utils.py
index 916f68e..13b9df8 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,4 +1,15 @@
import pathlib
+from io import BytesIO
+
+from PIL import Image
TEST_DIR = pathlib.Path(__file__).parent.resolve(strict=True)
PROJECT_DIR = TEST_DIR.parent
+DPI = 72
+
+
+def generate_image(width: int, height: int, dpi=(DPI, DPI)) -> BytesIO:
+    """Create a ``width`` x ``height`` "L"-mode image and return it as PNG bytes.
+
+    ``dpi`` is forwarded to Pillow's ``save``. The returned buffer is not
+    rewound by this function.
+    """
+    buffer = BytesIO()
+    blank = Image.new("L", (width, height))
+    with blank:
+        blank.save(buffer, format="png", dpi=dpi)
+    return buffer