Skip to content

Commit

Permalink
BUG: Improve PdfWriter handing of context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
pubpub-zz committed Oct 20, 2024
1 parent c9dda9a commit fb1ee44
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 30 deletions.
78 changes: 48 additions & 30 deletions pypdf/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,16 @@ class PdfWriter(PdfDocCommon):
Typically data is added from a :class:`PdfReader<pypdf.PdfReader>`.
Args:
* : 1st argument is assigned to fileobj or clone_from based on context:
assigned to clone_from if str/path to a non empty file or stream or PdfReader
else assigned to fileobj.
fileobj: output file/stream. To be used with context manager only.
clone_from: identical to fileobj (for compatibility)
incremental: If true, loads the document and set the PdfWriter in incremental mode.
When writing incrementally, the original document is written first and new/modified
content is appended. To be used for signed document/forms to keep signature valid.
Expand All @@ -166,6 +171,7 @@ class PdfWriter(PdfDocCommon):

def __init__(
self,
*args: Any,
fileobj: Union[None, PdfReader, StrByteType, Path] = "",
clone_from: Union[None, PdfReader, StrByteType, Path] = None,
incremental: bool = False,
Expand Down Expand Up @@ -202,50 +208,65 @@ def __init__(
self._ID: Union[ArrayObject, None] = None
self._info_obj: Optional[PdfObject]

if self.incremental:
if isinstance(fileobj, (str, Path)):
with open(fileobj, "rb") as f:
fileobj = BytesIO(f.read(-1))
if isinstance(fileobj, BytesIO):
fileobj = PdfReader(fileobj)
if not isinstance(fileobj, PdfReader):
raise PyPdfError("Invalid type for incremental mode")
self._reader = fileobj # prev content is in _reader.stream
self._header = fileobj.pdf_header.encode()
self._readonly = True # !!!TODO: to be analysed
else:
self._header = b"%PDF-1.3"
self._info_obj = self._add_object(
DictionaryObject(
{NameObject("/Producer"): create_string_object("pypdf")}
)
)
manualset_fileobj = True
if len(args) > 0:
if fileobj == "":
fileobj = args[0]
manualset_fileobj = False
elif clone_from is None:
clone_from = args[0]

def _get_clone_from(
fileobj: Union[None, PdfReader, str, Path, IO[Any], BytesIO],
clone_from: Union[None, PdfReader, str, Path, IO[Any], BytesIO],
) -> Union[None, PdfReader, str, Path, IO[Any], BytesIO]:
if isinstance(fileobj, (str, Path, IO, BytesIO)) and (
fileobj == "" or clone_from is not None
manualset_fileobj: bool,
) -> Tuple[
Union[None, PdfReader, str, Path, IO[Any], BytesIO],
Union[None, PdfReader, str, Path, IO[Any], BytesIO],
]:
if manualset_fileobj or (
isinstance(fileobj, (str, Path, IO, BytesIO))
and (fileobj in ("", None) or clone_from is not None)
):
return clone_from
return clone_from, fileobj
cloning = True
if isinstance(fileobj, (str, Path)) and (
not Path(str(fileobj)).exists()
or Path(str(fileobj)).stat().st_size == 0
):
cloning = False

if isinstance(fileobj, (IO, BytesIO)):
t = fileobj.tell()
fileobj.seek(-1, 2)
if fileobj.tell() == 0:
cloning = False
fileobj.seek(t, 0)
if cloning:
clone_from = fileobj
return clone_from
return fileobj, None
return clone_from, fileobj

clone_from, fileobj = _get_clone_from(fileobj, clone_from, manualset_fileobj)

if self.incremental:
if isinstance(clone_from, (str, Path)):
with open(clone_from, "rb") as f:
clone_from = BytesIO(f.read(-1))
if isinstance(clone_from, (IO, BytesIO)):
clone_from = PdfReader(clone_from)
if not isinstance(clone_from, PdfReader):
raise PyPdfError("Invalid type for incremental mode")
self._reader = clone_from # prev content is in _reader.stream
self._header = clone_from.pdf_header.encode()
self._readonly = True # !!!TODO: to be analysed
else:
self._header = b"%PDF-1.3"
self._info_obj = self._add_object(
DictionaryObject(
{NameObject("/Producer"): create_string_object("pypdf")}
)
)

clone_from = _get_clone_from(fileobj, clone_from)
# to prevent overwriting
self.temp_fileobj = fileobj
self.fileobj = ""
Expand Down Expand Up @@ -354,10 +375,7 @@ def xmp_metadata(self, value: Optional[XmpInformation]) -> None:

def __enter__(self) -> "PdfWriter":
"""Store that writer is initialized by 'with'."""
t = self.temp_fileobj
self.__init__() # type: ignore
self.with_as_usage = True
self.fileobj = t # type: ignore
return self

def __exit__(
Expand Down Expand Up @@ -1393,7 +1411,7 @@ def write(self, stream: Union[Path, StrByteType]) -> Tuple[bool, IO[Any]]:

self.write_stream(stream)

if self.with_as_usage:
if my_file:
stream.close()

return my_file, stream
Expand Down
60 changes: 60 additions & 0 deletions tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2480,3 +2480,63 @@ def test_append_pdf_with_dest_without_page(caplog):
writer.append(reader)
assert "/__WKANCHOR_8" not in writer.named_destinations
assert len(writer.named_destinations) == 3


def test_writer_contextmanager():
"""To test the writer with context manager, cf #2912"""
pdf_path = str(RESOURCE_ROOT / "crazyones.pdf")
with PdfWriter(pdf_path) as w:
assert len(w.pages) > 0
assert not w.fileobj
with open(pdf_path, "rb") as f, PdfWriter(f) as w:
assert len(w.pages) > 0
assert not w.fileobj
with open(pdf_path, "rb") as f, PdfWriter(BytesIO(f.read(-1))) as w:
assert len(w.pages) > 0
assert not w.fileobj

try:
with NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
tmp_file = Path(tmp.name)
with PdfWriter(tmp_file) as w:
assert len(w.pages) == 0

with open(tmp_file, "wb") as f1, open(pdf_path, "rb") as f:
f1.write(f.read(-1))
with PdfWriter(tmp_file) as w:
assert len(w.pages) > 0
assert tmp_file.stat().st_size > 0

with PdfWriter(tmp_file, incremental=True) as w:
assert w._reader
assert not w.fileobj
assert tmp_file.stat().st_size > 0

with PdfWriter(clone_from=tmp_file) as w:
assert len(w.pages) > 0
assert not w.fileobj
assert tmp_file.stat().st_size > 0

with PdfWriter(fileobj=tmp_file) as w:
assert len(w.pages) == 0
assert 8 <= tmp_file.stat().st_size <= 1024

b = BytesIO()
with PdfWriter(fileobj=b) as w:
assert len(w.pages) == 0
assert not b.closed
assert 8 <= len(b.getbuffer()) <= 1024

with NamedTemporaryFile(mode="wb", suffix=".pdf", delete=True) as tmp:
with PdfWriter(pdf_path, fileobj=tmp, incremental=True) as w:
assert w._reader
assert not tmp.closed
assert Path(tmp.name).stat().st_size == Path(pdf_path).stat().st_size

with PdfWriter(tmp_file) as w:
assert len(w.pages) == 0

except Exception as e:
raise e
finally:
tmp_file.unlink()

0 comments on commit fb1ee44

Please sign in to comment.