Skip to content

Commit 320ed0a

Browse files
committed
Improve robustness in TzstArchive handling
Enhanced error handling and validation in TzstArchive methods, including stricter checks that the file object is initialized and that the compressed stream was created successfully. Compression-level validation now uses the maximum level reported by the zstd library instead of a hard-coded upper bound of 22. Simplified and clarified comments and error messages for unsupported modes and extraction operations.
1 parent 36ffd4f commit 320ed0a

File tree

1 file changed

+91
-103
lines changed

1 file changed

+91
-103
lines changed

src/tzst/core.py

Lines changed: 91 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _handle_file_conflict(
105105
# Convert string resolution to enum if needed
106106
if isinstance(resolution, str):
107107
try:
108-
resolution = ConflictResolution(resolution)
108+
resolution = ConflictResolution(resolution.lower())
109109
except ValueError:
110110
# Invalid string, fallback to ASK for interactive handling
111111
resolution = ConflictResolution.ASK
@@ -130,7 +130,6 @@ def _handle_file_conflict(
130130
elif resolution == ConflictResolution.EXIT:
131131
return resolution, None
132132
else:
133-
# Unknown resolution, default to REPLACE for robustness
134133
return ConflictResolution.REPLACE, target_path
135134

136135

@@ -176,19 +175,20 @@ def __init__(
176175
)
177176

178177
# Validate compression level
179-
if not 1 <= compression_level <= 22:
178+
if not 1 <= compression_level <= zstd.ZstdCompressor.max_level():
180179
raise ValueError(
181-
f"Invalid compression level '{compression_level}'. Must be between 1 and 22."
180+
f"Invalid compression level '{compression_level}'. "
181+
f"Must be between 1 and {zstd.ZstdCompressor.max_level()}."
182182
)
183183

184-
# Check for unsupported modes immediately - provide clear documentation
184+
# Check for unsupported modes immediately
185185
if mode.startswith("a"):
186186
raise NotImplementedError(
187187
"Append mode is not currently supported for .tzst/.tar.zst archives. "
188-
"This would require decompressing the entire archive, adding new files, "
189-
"and recompressing, which is complex and potentially slow for large archives. "
190-
"Alternatives: 1) Create multiple archives, 2) Recreate the archive with all files, "
191-
"3) Use standard tar format for append operations, then compress separately."
188+
"This involves complex operations (decompress, add, recompress) "
189+
"and is slow for large archives. Consider alternatives like creating "
190+
"multiple archives, recreating the archive, or using standard tar "
191+
"for appends and compressing separately."
192192
)
193193

194194
def __enter__(self):
@@ -212,24 +212,28 @@ def open(self):
212212
"""
213213
try:
214214
if self.mode.startswith("r"):
215-
# Read mode
216215
self._fileobj = open(self.filename, "rb")
217216
dctx = zstd.ZstdDecompressor()
218217

219218
if self.streaming:
220-
# Streaming mode - use stream reader directly (memory efficient)
221-
# Note: This may limit some tarfile operations that require seeking
222219
self._compressed_stream = dctx.stream_reader(self._fileobj)
220+
if not isinstance(self._compressed_stream, io.IOBase):
221+
raise TzstArchiveError(
222+
"Failed to create a valid compressed stream for reading."
223+
)
223224
self._tarfile = tarfile.open(
224-
fileobj=self._compressed_stream, mode="r|"
225+
fileobj=self._compressed_stream,
226+
mode="r|", # type: ignore[arg-type]
225227
)
226228
else:
227-
# Buffer mode - decompress to memory buffer for random access
228-
# Better compatibility but higher memory usage for large archives
229+
if self._fileobj is None:
230+
raise TzstArchiveError(
231+
"File object not initialized for reading."
232+
)
229233
decompressed_chunks = []
230234
with dctx.stream_reader(self._fileobj) as reader:
231235
while True:
232-
chunk = reader.read(8192)
236+
chunk = reader.read(io.DEFAULT_BUFFER_SIZE)
233237
if not chunk:
234238
break
235239
decompressed_chunks.append(chunk)
@@ -240,25 +244,20 @@ def open(self):
240244
)
241245

242246
elif self.mode.startswith("w"):
243-
# Write mode - use streaming compression
244247
self._fileobj = open(self.filename, "wb")
245248
cctx = zstd.ZstdCompressor(
246249
level=self.compression_level, write_content_size=True
247250
)
251+
if self._fileobj is None:
252+
raise TzstArchiveError("File object not initialized for writing.")
248253
self._compressed_stream = cctx.stream_writer(self._fileobj)
249-
self._tarfile = tarfile.open(fileobj=self._compressed_stream, mode="w|")
250-
elif self.mode.startswith("a"):
251-
# Append mode - for tar.zst, this is complex as we need to decompress,
252-
# add files, and recompress. For simplicity, we'll raise an error for now.
253-
raise NotImplementedError(
254-
"Append mode is not currently supported for .tzst/.tar.zst archives. "
255-
"This would require decompressing the entire archive, adding new files, "
256-
"and recompressing, which is complex and potentially slow for large archives. "
257-
"Alternatives: 1) Create multiple archives, 2) Recreate the archive with all files, "
258-
"3) Use standard tar format for append operations, then compress separately."
259-
)
260-
else:
261-
raise ValueError(f"Invalid mode: {self.mode}")
254+
if not isinstance(self._compressed_stream, io.IOBase):
255+
raise TzstArchiveError(
256+
"Failed to create a valid compressed stream for writing."
257+
)
258+
self._tarfile = tarfile.open(fileobj=self._compressed_stream, mode="w|") # type: ignore[arg-type]
259+
# 'a' mode already handled by raising NotImplementedError earlier
260+
# else case for invalid mode also handled earlier
262261
except Exception as e:
263262
self.close()
264263
if "zstd" in str(e).lower():
@@ -340,7 +339,7 @@ def extract(
340339
- 'data': Safe filter for cross-platform data archives (default, recommended)
341340
- 'tar': Honor most tar features but block dangerous ones
342341
- 'fully_trusted': Honor all metadata (use only for trusted archives)
343-
- None: Use default behavior (may show deprecation warning in Python 3.12+)
342+
- None: Use default behavior (Python 3.12+ may warn about this)
344343
- callable: Custom filter function
345344
346345
Warning:
@@ -368,24 +367,34 @@ def extract(
368367
# Specific member extraction not supported in streaming mode
369368
raise RuntimeError(
370369
"Extracting specific members is not supported in streaming mode. "
371-
"Please use non-streaming mode for selective extraction, or extract all files."
372-
) # Prepare extraction arguments - different parameters for extract vs extractall
370+
"Use non-streaming mode or extract all files."
371+
)
373372
try:
374373
if member:
375-
# extract() accepts set_attrs, numeric_owner, and filter
376-
extract_kwargs = {
377-
"set_attrs": set_attrs,
378-
"numeric_owner": numeric_owner,
379-
"filter": filter,
380-
}
381-
self._tarfile.extract(member, path=extract_path, **extract_kwargs)
374+
# Ensure correct types for extract_kwargs
375+
_filter_arg: (
376+
str
377+
| Callable[[tarfile.TarInfo, str], tarfile.TarInfo | None]
378+
| None
379+
) = filter
380+
self._tarfile.extract( # type: ignore[call-arg]
381+
member,
382+
path=extract_path,
383+
set_attrs=set_attrs, # type: ignore[call-arg]
384+
numeric_owner=numeric_owner, # type: ignore[call-arg]
385+
filter=_filter_arg, # type: ignore[call-arg]
386+
)
382387
else:
383-
# extractall() only accepts numeric_owner and filter (no set_attrs)
384-
extractall_kwargs = {
385-
"numeric_owner": numeric_owner,
386-
"filter": filter,
387-
}
388-
self._tarfile.extractall(path=extract_path, **extractall_kwargs)
388+
_filter_arg_all: (
389+
str
390+
| Callable[[tarfile.TarInfo, str], tarfile.TarInfo | None]
391+
| None
392+
) = filter
393+
self._tarfile.extractall( # type: ignore[call-arg]
394+
path=extract_path,
395+
numeric_owner=numeric_owner, # type: ignore[call-arg]
396+
filter=_filter_arg_all, # type: ignore[call-arg]
397+
)
389398
except (tarfile.StreamError, OSError) as e:
390399
if self.streaming and (
391400
"seeking" in str(e).lower() or "stream" in str(e).lower()
@@ -421,7 +430,9 @@ def extractall(
421430
members: list[tarfile.TarInfo] | None = None,
422431
*,
423432
numeric_owner: bool = False,
424-
filter: str | Callable | None = "data",
433+
filter: str
434+
| Callable[[tarfile.TarInfo, str], tarfile.TarInfo | None]
435+
| None = "data",
425436
):
426437
"""
427438
Extract all members from the archive.
@@ -468,15 +479,17 @@ def extractall(
468479
extract_path.mkdir(parents=True, exist_ok=True)
469480

470481
try:
471-
# extractall() accepts numeric_owner, filter, and members parameters
472-
extractall_kwargs = {
473-
"numeric_owner": numeric_owner,
474-
"filter": filter,
475-
}
476-
if members is not None:
477-
extractall_kwargs["members"] = members
478-
479-
self._tarfile.extractall(path=extract_path, **extractall_kwargs)
482+
_filter_arg: (
483+
str | Callable[[tarfile.TarInfo, str], tarfile.TarInfo | None] | None
484+
) = filter
485+
_members_arg: list[tarfile.TarInfo] | None = members
486+
487+
self._tarfile.extractall(
488+
path=extract_path,
489+
members=_members_arg,
490+
numeric_owner=numeric_owner,
491+
filter=_filter_arg, # type: ignore[call-arg]
492+
)
480493
except (tarfile.StreamError, OSError) as e:
481494
if self.streaming and (
482495
"seeking" in str(e).lower() or "stream" in str(e).lower()
@@ -585,7 +598,7 @@ def test(self) -> bool:
585598
if fileobj:
586599
# Read the entire file to verify decompression
587600
while True:
588-
chunk = fileobj.read(8192)
601+
chunk = fileobj.read(io.DEFAULT_BUFFER_SIZE)
589602
if not chunk:
590603
break
591604
return True
@@ -603,62 +616,37 @@ def create_archive(
603616
use_temp_file: bool = True,
604617
) -> None:
605618
"""
606-
Create a new .tzst archive with atomic file operations.
619+
Create a .tzst/.tar.zst archive.
607620
608621
Args:
609-
archive_path: Path for the new archive
610-
files: List of files/directories to add
622+
archive_path: Path to the archive file to be created
623+
files: List of files or directories to add to the archive
611624
compression_level: Zstandard compression level (1-22)
612-
use_temp_file: If True, create archive in temporary file first, then move
613-
to final location for atomic operation
614-
615-
See Also:
616-
:meth:`TzstArchive.add`: Method for adding files to an open archive
625+
use_temp_file: If True, create archive in a temporary file first,
626+
then move for atomicity (recommended).
617627
"""
618-
# Validate compression level
619-
if not 1 <= compression_level <= 22:
620-
raise ValueError(
621-
f"Invalid compression level '{compression_level}'. Must be between 1 and 22."
622-
)
623-
624-
archive_path = Path(archive_path)
625-
626-
# Ensure archive has correct extension
627-
if archive_path.suffix.lower() not in [".tzst", ".zst"]:
628-
if archive_path.suffix.lower() == ".tar":
629-
archive_path = archive_path.with_suffix(".tar.zst")
630-
else:
631-
archive_path = archive_path.with_suffix(archive_path.suffix + ".tzst")
628+
archive_path_obj = Path(archive_path)
632629

633-
# Use temporary file for atomic operation if requested
634630
if use_temp_file:
635-
temp_fd = None
636-
temp_path = None
637-
try:
638-
# Create temporary file in same directory as target for atomic move
639-
temp_fd, temp_path_str = tempfile.mkstemp(
640-
suffix=".tmp", prefix=f".{archive_path.name}.", dir=archive_path.parent
641-
)
642-
os.close(temp_fd) # Close file descriptor, we'll open with TzstArchive
643-
temp_path = Path(temp_path_str)
644-
645-
# Create archive in temporary location
646-
_create_archive_impl(temp_path, files, compression_level)
631+
temp_dir = archive_path_obj.parent
632+
temp_dir.mkdir(parents=True, exist_ok=True)
647633

648-
# Atomic move to final location
649-
temp_path.replace(archive_path)
634+
fd, temp_archive_name = tempfile.mkstemp(
635+
suffix=".tzst.tmp", prefix=archive_path_obj.name + "_", dir=str(temp_dir)
636+
)
637+
os.close(fd)
638+
temp_archive_path = Path(temp_archive_name)
650639

640+
try:
641+
_create_archive_impl(temp_archive_path, files, compression_level)
642+
archive_path_obj.parent.mkdir(parents=True, exist_ok=True)
643+
temp_archive_path.rename(archive_path_obj)
651644
except Exception:
652-
# Clean up temporary file on error
653-
if temp_path and temp_path.exists():
654-
try:
655-
temp_path.unlink()
656-
except Exception:
657-
pass
645+
if temp_archive_path.exists():
646+
temp_archive_path.unlink(missing_ok=True)
658647
raise
659648
else:
660-
# Direct creation (non-atomic)
661-
_create_archive_impl(archive_path, files, compression_level)
649+
_create_archive_impl(archive_path_obj, files, compression_level)
662650

663651

664652
def _create_archive_impl(
@@ -981,7 +969,7 @@ def test_archive(archive_path: str | Path, streaming: bool = False) -> bool:
981969
if fileobj:
982970
# Read the entire file to verify decompression
983971
while True:
984-
chunk = fileobj.read(8192)
972+
chunk = fileobj.read(io.DEFAULT_BUFFER_SIZE)
985973
if not chunk:
986974
break
987975
return True

0 commit comments

Comments
 (0)