Skip to content

Be more flexible when saving / loading prov and libs data #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/anisette/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from importlib.metadata import version

from ._device import AnisetteDeviceConfig
from .anisette import Anisette
from .anisette import Anisette, AnisetteHeaders

__version__ = version("anisette")

__all__ = ("Anisette", "AnisetteDeviceConfig")
__all__ = ("Anisette", "AnisetteDeviceConfig", "AnisetteHeaders")
27 changes: 19 additions & 8 deletions src/anisette/_ani_provider.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
from __future__ import annotations

import logging
from typing import BinaryIO
from typing import BinaryIO, Callable

from typing_extensions import Self

from ._adi import ADI
from ._device import AnisetteDeviceConfig, Device
from ._fs import FSCollection
from ._fs import FSCollection, VirtualFileSystem
from ._library import LibraryStore
from ._session import ProvisioningSession


class AnisetteProvider:
def __init__(self, fs_collection: FSCollection, default_device_config: AnisetteDeviceConfig | None) -> None:
def __init__(
self,
fs_collection: FSCollection,
fs_fallback: Callable[[], VirtualFileSystem],
default_device_config: AnisetteDeviceConfig | None,
) -> None:
self._fs_collection = fs_collection
self._fs_fallback = fs_fallback
self._default_device_config = default_device_config or AnisetteDeviceConfig.default()

self._lib_store: LibraryStore | None = None
Expand All @@ -23,8 +29,13 @@ def __init__(self, fs_collection: FSCollection, default_device_config: AnisetteD
self._provisioning_session: ProvisioningSession | None = None

@classmethod
def load(cls, *files: BinaryIO, default_device_config: AnisetteDeviceConfig | None = None) -> Self:
provider = cls(FSCollection.load(*files), default_device_config)
def load(
cls,
*files: BinaryIO,
fs_fallback: Callable[[], VirtualFileSystem],
default_device_config: AnisetteDeviceConfig | None = None,
) -> Self:
provider = cls(FSCollection.load(*files), fs_fallback, default_device_config)
assert provider.library_store is not None # verify that library store exists
return provider

Expand All @@ -34,10 +45,10 @@ def save(self, file: BinaryIO, include: list[str] | None = None, exclude: list[s
@property
def library_store(self) -> LibraryStore:
if self._lib_store is None:
lib_fs = self._fs_collection.get("libs", False)
lib_fs = self._fs_collection.get("libs", create_if_missing=False)
if lib_fs is None:
msg = "Library filesystem missing"
raise RuntimeError(msg)
lib_fs = self._fs_fallback()
self._fs_collection.add("libs", lib_fs)
self._lib_store = LibraryStore.from_virtfs(lib_fs)
return self._lib_store

Expand Down
3 changes: 3 additions & 0 deletions src/anisette/_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ def load(cls, *files: BinaryIO) -> Self:
filesystems[name] = VirtualFileSystem(fs)
return cls(**filesystems)

def add(self, name: str, fs: VirtualFileSystem) -> None:
self._filesystems[name] = fs

def save(self, file: BinaryIO, include: list[str] | None = None, exclude: list[str] | None = None) -> None:
to_save = set(self._filesystems.keys()) if include is None else set(include)
if exclude is not None:
Expand Down
38 changes: 28 additions & 10 deletions src/anisette/anisette.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@
)


def _get_libs(file: BinaryIO | str | Path | None = None) -> LibraryStore:
file = file or DEFAULT_LIBS_URL

with open_file(file, "rb") as f:
return LibraryStore.from_file(f)


class Anisette:
"""
The main Anisette provider class.
Expand All @@ -64,6 +71,11 @@ def __init__(self, ani_provider: AnisetteProvider) -> None:

self._ds_id = c_ulonglong(-2).value

@property
def is_provisioned(self) -> bool:
"""Whether this Anisette session has been provisioned yet or not."""
return self._ani_provider.adi.is_machine_provisioned(self._ds_id)

@classmethod
def init(
cls,
Expand All @@ -81,14 +93,11 @@ def init(
:return: An instance of :class:`Anisette`.
:rtype: :class:`Anisette`
"""
file = file or DEFAULT_LIBS_URL

with open_file(file, "rb") as f:
library_store = LibraryStore.from_file(f)

fs_collection = FSCollection(libs=library_store)
ani_provider = AnisetteProvider(fs_collection, default_device_config)

ani_provider = AnisetteProvider(
FSCollection(),
lambda: _get_libs(file),
default_device_config,
)
return cls(ani_provider)

@classmethod
Expand All @@ -106,7 +115,11 @@ def load(cls, *files: BinaryIO | str | Path, default_device_config: AnisetteDevi
"""
with ExitStack() as stack:
file_objs = [stack.enter_context(open_file(f, "rb")) for f in files]
ani_provider = AnisetteProvider.load(*file_objs, default_device_config=default_device_config)
ani_provider = AnisetteProvider.load(
*file_objs,
fs_fallback=lambda: _get_libs(),
default_device_config=default_device_config,
)

return cls(ani_provider)

Expand All @@ -126,6 +139,8 @@ def save_provisioning(self, file: BinaryIO | str | Path) -> None:
:param file: The file or path to save provisioning data to.
:type file: BinaryIO, str, Path
"""
self.provision()

with open_file(file, "wb+") as f:
self._ani_provider.save(f, exclude=["libs"])

Expand All @@ -144,6 +159,9 @@ def save_libs(self, file: BinaryIO | str | Path) -> None:
:param file: The file or path to save library data to.
:type file: BinaryIO, str, Path
"""
# force fetch of library store to make sure it exists when saving
_ = self._ani_provider.library_store

with open_file(file, "wb+") as f:
self._ani_provider.save(f, include=["libs"])

Expand Down Expand Up @@ -173,7 +191,7 @@ def provision(self) -> None:
In most cases it is not necessary to manually use this method, since :meth:`Anisette.get_data`
will call it implicitly.
"""
if not self._ani_provider.adi.is_machine_provisioned(self._ds_id):
if not self.is_provisioned:
logging.info("Provisioning...")
self._ani_provider.provisioning_session.provision(self._ds_id)

Expand Down