Skip to content

Commit ce4b043

Browse files
committed
Verify that lib store exists when loading
1 parent 57b9bec commit ce4b043

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

anisette/_anisette.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,20 @@ def __init__(self, fs_collection: FSCollection, default_device_config: AnisetteD
2424

2525
@classmethod
2626
def load(cls, *files: BinaryIO, default_device_config: AnisetteDeviceConfig | None = None) -> Self:
27-
return cls(FSCollection.load(*files), default_device_config)
27+
provider = cls(FSCollection.load(*files), default_device_config)
28+
assert provider.library_store is not None # verify that library store exists
29+
return provider
2830

2931
def save(self, file: BinaryIO, include: list[str] | None = None, exclude: list[str] | None = None) -> None:
3032
return self._fs_collection.save(file, include, exclude)
3133

3234
@property
3335
def library_store(self) -> LibraryStore:
3436
if self._lib_store is None:
35-
lib_fs = self._fs_collection.get("libs")
37+
lib_fs = self._fs_collection.get("libs", False)
38+
if lib_fs is None:
39+
msg = "Library filesystem missing"
40+
raise RuntimeError(msg)
3641
self._lib_store = LibraryStore.from_virtfs(lib_fs)
3742
return self._lib_store
3843

anisette/_fs.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from contextlib import ExitStack
88
from dataclasses import dataclass
99
from pathlib import Path
10-
from typing import IO, TYPE_CHECKING, BinaryIO, Union
10+
from typing import IO, TYPE_CHECKING, BinaryIO, Literal, Union, overload
1111

1212
from fs import open_fs
1313
from fs.copy import copy_dir, copy_file, copy_fs
@@ -185,11 +185,23 @@ def save(self, file: BinaryIO, include: list[str] | None = None, exclude: list[s
185185

186186
tar_fs.writetext("fs.json", json.dumps(fs_index))
187187

188-
def get(self, fs_name: str) -> VirtualFileSystem:
188+
@overload
189+
def get(self, fs_name: str) -> VirtualFileSystem: ...
190+
191+
@overload
192+
def get(self, fs_name: str, create_if_missing: Literal[True]) -> VirtualFileSystem: ...
193+
194+
@overload
195+
def get(self, fs_name: str, create_if_missing: Literal[False]) -> VirtualFileSystem | None: ...
196+
197+
def get(self, fs_name: str, create_if_missing: bool = True) -> VirtualFileSystem | None:
189198
if fs_name in self._filesystems:
190199
logging.debug("Get FS from collection: %s", fs_name)
191200
return self._filesystems[fs_name]
192201

202+
if not create_if_missing:
203+
return None
204+
193205
logging.debug("Create new VFS: %s", fs_name)
194206
fs = VirtualFileSystem()
195207
self._filesystems[fs_name] = fs

0 commit comments

Comments
 (0)