Skip to content

Commit eb6b21c

Browse files
authored
ENH: Allow to pass input file without named argument (#2576)
1 parent ced67e1 commit eb6b21c

File tree

2 files changed

+56
-3
lines changed

2 files changed

+56
-3
lines changed

pypdf/_writer.py

+35-3
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def is_encrypted(self) -> bool:
168168

169169
def __init__(
170170
self,
171-
fileobj: StrByteType = "",
171+
fileobj: Union[None, PdfReader, StrByteType, Path] = "",
172172
clone_from: Union[None, PdfReader, StrByteType, Path] = None,
173173
) -> None:
174174
self._header = b"%PDF-1.3"
@@ -213,12 +213,41 @@ def __init__(
213213
)
214214
self._root = self._add_object(self._root_object)
215215

216+
def _get_clone_from(
217+
fileobj: Union[None, PdfReader, str, Path, IO[Any], BytesIO],
218+
clone_from: Union[None, PdfReader, str, Path, IO[Any], BytesIO],
219+
) -> Union[None, PdfReader, str, Path, IO[Any], BytesIO]:
220+
if not isinstance(fileobj, (str, Path, IO, BytesIO)) or (
221+
fileobj != "" and clone_from is None
222+
):
223+
cloning = True
224+
if not (
225+
not isinstance(fileobj, (str, Path))
226+
or (
227+
Path(str(fileobj)).exists()
228+
and Path(str(fileobj)).stat().st_size > 0
229+
)
230+
):
231+
cloning = False
232+
if isinstance(fileobj, (IO, BytesIO)):
233+
t = fileobj.tell()
234+
fileobj.seek(-1, 2)
235+
if fileobj.tell() == 0:
236+
cloning = False
237+
fileobj.seek(t, 0)
238+
if cloning:
239+
clone_from = fileobj
240+
return clone_from
241+
242+
clone_from = _get_clone_from(fileobj, clone_from)
243+
# to prevent overwriting
244+
self.temp_fileobj = fileobj
245+
self.fileobj = ""
246+
self.with_as_usage = False
216247
if clone_from is not None:
217248
if not isinstance(clone_from, PdfReader):
218249
clone_from = PdfReader(clone_from)
219250
self.clone_document_from_reader(clone_from)
220-
self.fileobj = fileobj
221-
self.with_as_usage = False
222251

223252
self._encryption: Optional[Encryption] = None
224253
self._encrypt_entry: Optional[DictionaryObject] = None
@@ -268,7 +297,10 @@ def xmp_metadata(self, value: Optional[XmpInformation]) -> None:
268297

269298
def __enter__(self) -> "PdfWriter":
270299
"""Store that writer is initialized by 'with'."""
300+
t = self.temp_fileobj
301+
self.__init__() # type: ignore
271302
self.with_as_usage = True
303+
self.fileobj = t # type: ignore
272304
return self
273305

274306
def __exit__(

tests/test_writer.py

+21
Original file line numberDiff line numberDiff line change
@@ -2196,3 +2196,24 @@ def test_mime_jupyter():
21962196
writer = PdfWriter(clone_from=reader)
21972197
assert reader._repr_mimebundle_(("include",), ("exclude",)) == {}
21982198
assert writer._repr_mimebundle_(("include",), ("exclude",)) == {}
2199+
2200+
2201+
def test_init_without_named_arg():
2202+
"""Test to use file_obj argument and not clone_from"""
2203+
pdf_path = RESOURCE_ROOT / "crazyones.pdf"
2204+
reader = PdfReader(pdf_path)
2205+
writer = PdfWriter(clone_from=reader)
2206+
nb = len(writer._objects)
2207+
writer = PdfWriter(reader)
2208+
assert len(writer._objects) == nb
2209+
with open(pdf_path, "rb") as f:
2210+
writer = PdfWriter(f)
2211+
f.seek(0, 0)
2212+
by = BytesIO(f.read())
2213+
assert len(writer._objects) == nb
2214+
writer = PdfWriter(pdf_path)
2215+
assert len(writer._objects) == nb
2216+
writer = PdfWriter(str(pdf_path))
2217+
assert len(writer._objects) == nb
2218+
writer = PdfWriter(by)
2219+
assert len(writer._objects) == nb

0 commit comments

Comments
 (0)