Skip to content

Commit 305c131

Browse files
committed
Improve robustness and clarity in TzstArchive
Refactored error handling, added a default behavior for unknown conflict resolutions, and improved the documentation for unsupported modes. Also enhanced memory efficiency and compatibility in archive operations, streamlined the extraction logic, and made compression-level validation more effective.
1 parent efbcbbc commit 305c131

File tree

1 file changed

+110
-111
lines changed

1 file changed

+110
-111
lines changed

src/tzst/core.py

Lines changed: 110 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _handle_file_conflict(
105105
# Convert string resolution to enum if needed
106106
if isinstance(resolution, str):
107107
try:
108-
resolution = ConflictResolution(resolution.lower())
108+
resolution = ConflictResolution(resolution)
109109
except ValueError:
110110
# Invalid string, fallback to ASK for interactive handling
111111
resolution = ConflictResolution.ASK
@@ -130,6 +130,7 @@ def _handle_file_conflict(
130130
elif resolution == ConflictResolution.EXIT:
131131
return resolution, None
132132
else:
133+
# Unknown resolution, default to REPLACE for robustness
133134
return ConflictResolution.REPLACE, target_path
134135

135136

@@ -175,27 +176,19 @@ def __init__(
175176
)
176177

177178
# Validate compression level
178-
try:
179-
# Try to get max_level from a constant in the zstandard module
180-
max_level = zstd.MAX_COMPRESSION_LEVEL
181-
except AttributeError:
182-
# Fallback to a hardcoded default if the constant is not found
183-
max_level = 22 # Common max level for zstd
184-
185-
if not 1 <= compression_level <= max_level:
179+
if not 1 <= compression_level <= 22:
186180
raise ValueError(
187-
f"Invalid compression level '{compression_level}'. "
188-
f"Must be between 1 and {max_level}."
181+
f"Invalid compression level '{compression_level}'. Must be between 1 and 22."
189182
)
190183

191-
# Check for unsupported modes immediately
184+
# Check for unsupported modes immediately - provide clear documentation
192185
if mode.startswith("a"):
193186
raise NotImplementedError(
194187
"Append mode is not currently supported for .tzst/.tar.zst archives. "
195-
"This involves complex operations (decompress, add, recompress) "
196-
"and is slow for large archives. Consider alternatives like creating "
197-
"multiple archives, recreating the archive, or using standard tar "
198-
"for appends and compressing separately."
188+
"This would require decompressing the entire archive, adding new files, "
189+
"and recompressing, which is complex and potentially slow for large archives. "
190+
"Alternatives: 1) Create multiple archives, 2) Recreate the archive with all files, "
191+
"3) Use standard tar format for append operations, then compress separately."
199192
)
200193

201194
def __enter__(self):
@@ -219,28 +212,24 @@ def open(self):
219212
"""
220213
try:
221214
if self.mode.startswith("r"):
215+
# Read mode
222216
self._fileobj = open(self.filename, "rb")
223217
dctx = zstd.ZstdDecompressor()
224218

225219
if self.streaming:
220+
# Streaming mode - use stream reader directly (memory efficient)
221+
# Note: This may limit some tarfile operations that require seeking
226222
self._compressed_stream = dctx.stream_reader(self._fileobj)
227-
if not isinstance(self._compressed_stream, io.IOBase):
228-
raise TzstArchiveError(
229-
"Failed to create a valid compressed stream for reading."
230-
)
231223
self._tarfile = tarfile.open(
232-
fileobj=self._compressed_stream,
233-
mode="r|", # type: ignore[arg-type]
224+
fileobj=self._compressed_stream, mode="r|"
234225
)
235226
else:
236-
if self._fileobj is None:
237-
raise TzstArchiveError(
238-
"File object not initialized for reading."
239-
)
227+
# Buffer mode - decompress to memory buffer for random access
228+
# Better compatibility but higher memory usage for large archives
240229
decompressed_chunks = []
241230
with dctx.stream_reader(self._fileobj) as reader:
242231
while True:
243-
chunk = reader.read(io.DEFAULT_BUFFER_SIZE)
232+
chunk = reader.read(8192)
244233
if not chunk:
245234
break
246235
decompressed_chunks.append(chunk)
@@ -251,20 +240,25 @@ def open(self):
251240
)
252241

253242
elif self.mode.startswith("w"):
243+
# Write mode - use streaming compression
254244
self._fileobj = open(self.filename, "wb")
255245
cctx = zstd.ZstdCompressor(
256246
level=self.compression_level, write_content_size=True
257247
)
258-
if self._fileobj is None:
259-
raise TzstArchiveError("File object not initialized for writing.")
260248
self._compressed_stream = cctx.stream_writer(self._fileobj)
261-
if not isinstance(self._compressed_stream, io.IOBase):
262-
raise TzstArchiveError(
263-
"Failed to create a valid compressed stream for writing."
264-
)
265-
self._tarfile = tarfile.open(fileobj=self._compressed_stream, mode="w|") # type: ignore[arg-type]
266-
# 'a' mode already handled by raising NotImplementedError earlier
267-
# else case for invalid mode also handled earlier
249+
self._tarfile = tarfile.open(fileobj=self._compressed_stream, mode="w|")
250+
elif self.mode.startswith("a"):
251+
# Append mode - for tar.zst, this is complex as we need to decompress,
252+
# add files, and recompress. For simplicity, we'll raise an error for now.
253+
raise NotImplementedError(
254+
"Append mode is not currently supported for .tzst/.tar.zst archives. "
255+
"This would require decompressing the entire archive, adding new files, "
256+
"and recompressing, which is complex and potentially slow for large archives. "
257+
"Alternatives: 1) Create multiple archives, 2) Recreate the archive with all files, "
258+
"3) Use standard tar format for append operations, then compress separately."
259+
)
260+
else:
261+
raise ValueError(f"Invalid mode: {self.mode}")
268262
except Exception as e:
269263
self.close()
270264
if "zstd" in str(e).lower():
@@ -346,7 +340,7 @@ def extract(
346340
- 'data': Safe filter for cross-platform data archives (default, recommended)
347341
- 'tar': Honor most tar features but block dangerous ones
348342
- 'fully_trusted': Honor all metadata (use only for trusted archives)
349-
- None: Use default behavior (Python 3.12+ may warn about this)
343+
- None: Use default behavior (may show deprecation warning in Python 3.12+)
350344
- callable: Custom filter function
351345
352346
Warning:
@@ -374,34 +368,24 @@ def extract(
374368
# Specific member extraction not supported in streaming mode
375369
raise RuntimeError(
376370
"Extracting specific members is not supported in streaming mode. "
377-
"Use non-streaming mode or extract all files."
378-
)
371+
"Please use non-streaming mode for selective extraction, or extract all files."
372+
) # Prepare extraction arguments - different parameters for extract vs extractall
379373
try:
380374
if member:
381-
# Ensure correct types for extract_kwargs
382-
_filter_arg: (
383-
str
384-
| Callable[[tarfile.TarInfo, str], tarfile.TarInfo | None]
385-
| None
386-
) = filter
387-
self._tarfile.extract( # type: ignore[call-arg]
388-
member,
389-
path=extract_path,
390-
set_attrs=set_attrs, # type: ignore[call-arg]
391-
numeric_owner=numeric_owner, # type: ignore[call-arg]
392-
filter=_filter_arg, # type: ignore[call-arg]
393-
)
375+
# extract() accepts set_attrs, numeric_owner, and filter
376+
extract_kwargs = {
377+
"set_attrs": set_attrs,
378+
"numeric_owner": numeric_owner,
379+
"filter": filter,
380+
}
381+
self._tarfile.extract(member, path=extract_path, **extract_kwargs)
394382
else:
395-
_filter_arg_all: (
396-
str
397-
| Callable[[tarfile.TarInfo, str], tarfile.TarInfo | None]
398-
| None
399-
) = filter
400-
self._tarfile.extractall( # type: ignore[call-arg]
401-
path=extract_path,
402-
numeric_owner=numeric_owner, # type: ignore[call-arg]
403-
filter=_filter_arg_all, # type: ignore[call-arg]
404-
)
383+
# extractall() only accepts numeric_owner and filter (no set_attrs)
384+
extractall_kwargs = {
385+
"numeric_owner": numeric_owner,
386+
"filter": filter,
387+
}
388+
self._tarfile.extractall(path=extract_path, **extractall_kwargs)
405389
except (tarfile.StreamError, OSError) as e:
406390
if self.streaming and (
407391
"seeking" in str(e).lower() or "stream" in str(e).lower()
@@ -437,9 +421,7 @@ def extractall(
437421
members: list[tarfile.TarInfo] | None = None,
438422
*,
439423
numeric_owner: bool = False,
440-
filter: str
441-
| Callable[[tarfile.TarInfo, str], tarfile.TarInfo | None]
442-
| None = "data",
424+
filter: str | Callable | None = "data",
443425
):
444426
"""
445427
Extract all members from the archive.
@@ -486,17 +468,15 @@ def extractall(
486468
extract_path.mkdir(parents=True, exist_ok=True)
487469

488470
try:
489-
_filter_arg: (
490-
str | Callable[[tarfile.TarInfo, str], tarfile.TarInfo | None] | None
491-
) = filter
492-
_members_arg: list[tarfile.TarInfo] | None = members
493-
494-
self._tarfile.extractall(
495-
path=extract_path,
496-
members=_members_arg,
497-
numeric_owner=numeric_owner,
498-
filter=_filter_arg, # type: ignore[call-arg]
499-
)
471+
# extractall() accepts numeric_owner, filter, and members parameters
472+
extractall_kwargs = {
473+
"numeric_owner": numeric_owner,
474+
"filter": filter,
475+
}
476+
if members is not None:
477+
extractall_kwargs["members"] = members
478+
479+
self._tarfile.extractall(path=extract_path, **extractall_kwargs)
500480
except (tarfile.StreamError, OSError) as e:
501481
if self.streaming and (
502482
"seeking" in str(e).lower() or "stream" in str(e).lower()
@@ -605,7 +585,7 @@ def test(self) -> bool:
605585
if fileobj:
606586
# Read the entire file to verify decompression
607587
while True:
608-
chunk = fileobj.read(io.DEFAULT_BUFFER_SIZE)
588+
chunk = fileobj.read(8192)
609589
if not chunk:
610590
break
611591
return True
@@ -623,37 +603,62 @@ def create_archive(
623603
use_temp_file: bool = True,
624604
) -> None:
625605
"""
626-
Create a .tzst/.tar.zst archive.
606+
Create a new .tzst archive with atomic file operations.
627607
628608
Args:
629-
archive_path: Path to the archive file to be created
630-
files: List of files or directories to add to the archive
609+
archive_path: Path for the new archive
610+
files: List of files/directories to add
631611
compression_level: Zstandard compression level (1-22)
632-
use_temp_file: If True, create archive in a temporary file first,
633-
then move for atomicity (recommended).
612+
use_temp_file: If True, create archive in temporary file first, then move
613+
to final location for atomic operation
614+
615+
See Also:
616+
:meth:`TzstArchive.add`: Method for adding files to an open archive
634617
"""
635-
archive_path_obj = Path(archive_path)
618+
# Validate compression level
619+
if not 1 <= compression_level <= 22:
620+
raise ValueError(
621+
f"Invalid compression level '{compression_level}'. Must be between 1 and 22."
622+
)
636623

637-
if use_temp_file:
638-
temp_dir = archive_path_obj.parent
639-
temp_dir.mkdir(parents=True, exist_ok=True)
624+
archive_path = Path(archive_path)
640625

641-
fd, temp_archive_name = tempfile.mkstemp(
642-
suffix=".tzst.tmp", prefix=archive_path_obj.name + "_", dir=str(temp_dir)
643-
)
644-
os.close(fd)
645-
temp_archive_path = Path(temp_archive_name)
626+
# Ensure archive has correct extension
627+
if archive_path.suffix.lower() not in [".tzst", ".zst"]:
628+
if archive_path.suffix.lower() == ".tar":
629+
archive_path = archive_path.with_suffix(".tar.zst")
630+
else:
631+
archive_path = archive_path.with_suffix(archive_path.suffix + ".tzst")
646632

633+
# Use temporary file for atomic operation if requested
634+
if use_temp_file:
635+
temp_fd = None
636+
temp_path = None
647637
try:
648-
_create_archive_impl(temp_archive_path, files, compression_level)
649-
archive_path_obj.parent.mkdir(parents=True, exist_ok=True)
650-
temp_archive_path.rename(archive_path_obj)
638+
# Create temporary file in same directory as target for atomic move
639+
temp_fd, temp_path_str = tempfile.mkstemp(
640+
suffix=".tmp", prefix=f".{archive_path.name}.", dir=archive_path.parent
641+
)
642+
os.close(temp_fd) # Close file descriptor, we'll open with TzstArchive
643+
temp_path = Path(temp_path_str)
644+
645+
# Create archive in temporary location
646+
_create_archive_impl(temp_path, files, compression_level)
647+
648+
# Atomic move to final location
649+
temp_path.replace(archive_path)
650+
651651
except Exception:
652-
if temp_archive_path.exists():
653-
temp_archive_path.unlink(missing_ok=True)
652+
# Clean up temporary file on error
653+
if temp_path and temp_path.exists():
654+
try:
655+
temp_path.unlink()
656+
except Exception:
657+
pass
654658
raise
655659
else:
656-
_create_archive_impl(archive_path_obj, files, compression_level)
660+
# Direct creation (non-atomic)
661+
_create_archive_impl(archive_path, files, compression_level)
657662

658663

659664
def _create_archive_impl(
@@ -813,10 +818,8 @@ def extract_archive(
813818

814819
fileobj = archive.extractfile(member)
815820
if fileobj:
816-
# Ensure target_path is not None before opening
817-
if target_path:
818-
with open(target_path, "wb") as f:
819-
f.write(fileobj.read())
821+
with open(target_path, "wb") as f:
822+
f.write(fileobj.read())
820823
else:
821824
# Extract with full directory structure
822825
if members:
@@ -845,14 +848,10 @@ def extract_archive(
845848
break
846849

847850
# For AUTO_RENAME, we need to adjust the member path
848-
if (
849-
actual_resolution
850-
in (
851-
ConflictResolution.AUTO_RENAME,
852-
ConflictResolution.AUTO_RENAME_ALL,
853-
)
854-
and final_path
855-
): # Ensure final_path is not None
851+
if actual_resolution in (
852+
ConflictResolution.AUTO_RENAME,
853+
ConflictResolution.AUTO_RENAME_ALL,
854+
):
856855
# Create parent directories for renamed file
857856
final_path.parent.mkdir(parents=True, exist_ok=True)
858857
# Extract to temporary location, then move
@@ -920,7 +919,7 @@ def extract_archive(
920919
):
921920
target_path.unlink() # Remove existing file
922921

923-
if target_path: # Ensure target_path is not None
922+
if target_path:
924923
temp_file.rename(target_path)
925924
finally:
926925
# Clean up temp directory
@@ -982,7 +981,7 @@ def test_archive(archive_path: str | Path, streaming: bool = False) -> bool:
982981
if fileobj:
983982
# Read the entire file to verify decompression
984983
while True:
985-
chunk = fileobj.read(io.DEFAULT_BUFFER_SIZE)
984+
chunk = fileobj.read(8192)
986985
if not chunk:
987986
break
988987
return True

0 commit comments

Comments (0)