Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Hardlink or copy #5502

Merged
merged 3 commits into from
Dec 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed the docstring information for the `FBetaMultiLabelMeasure` metric.
- Various fixes for Python 3.9
- Fixed the name that the `push-to-hf` command uses to store weights.
- Support for inferior operating systems when making hardlinks
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shots fired 🤣

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


### Removed

Expand Down
11 changes: 11 additions & 0 deletions allennlp/common/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
find_latest_cached as _find_latest_cached,
)
from cached_path.cache_file import CacheFile
from cached_path.common import PathOrStr
from cached_path.file_lock import FileLock
from cached_path.meta import Meta as _Meta
import torch
Expand Down Expand Up @@ -606,3 +607,13 @@ def inspect_cache(patterns: List[str] = None, cache_dir: Union[str, Path] = None

def filename_is_safe(filename: str) -> bool:
    """
    Return `True` when every character of `filename` is a member of
    `SAFE_FILENAME_CHARS`. An empty `filename` is considered safe.
    """
    for char in filename:
        if char not in SAFE_FILENAME_CHARS:
            # One disallowed character is enough to reject the name.
            return False
    return True


def hardlink_or_copy(source: PathOrStr, dest: PathOrStr):
    """
    Create `dest` as a hard link to `source`, falling back to a regular copy
    when the hard link cannot be made.

    # Parameters

    source : `PathOrStr`
        Path of the existing file.
    dest : `PathOrStr`
        Path of the link (or copy) to create.

    # Raises

    `OSError`
        Re-raised unchanged when `os.link` fails for any reason other than a
        cross-device link or an unsupported operation (for example a missing
        `source` or an already existing `dest`).
    """
    # Imported locally so this function does not depend on the module's
    # top-level import list.
    import errno

    # Use the symbolic constants rather than the Linux numeric values
    # (18 and 95): ENOTSUP / EOPNOTSUPP have different numbers on other
    # platforms, so hard-coded integers only match on Linux.
    fallback_errnos = {errno.EXDEV, errno.ENOTSUP, errno.EOPNOTSUPP}

    try:
        os.link(source, dest)
    except OSError as e:
        if e.errno not in fallback_errnos:
            raise
        # Hard links cannot span devices and are not supported by every
        # filesystem; in those cases a plain copy is used instead.
        shutil.copy(source, dest)
11 changes: 2 additions & 9 deletions allennlp/common/sqlite_sparse_sequence.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import os
import shutil
from os import PathLike
from typing import MutableSequence, Any, Union, Iterable
from sqlitedict import SqliteDict

from allennlp.common.file_utils import hardlink_or_copy
from allennlp.common.sequences import SlicedSequence


Expand Down Expand Up @@ -95,10 +94,4 @@ def close(self) -> None:
self.table = None

def copy_to(self, target: Union[str, PathLike]):
try:
os.link(self.table.filename, target)
except OSError as e:
if e.errno == 18: # Cross-device link
shutil.copy(self.table.filename, target)
else:
raise
hardlink_or_copy(self.table.filename, target)
3 changes: 2 additions & 1 deletion allennlp/training/gradient_descent_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from allennlp.common.checks import ConfigurationError, check_for_gpu
from allennlp.common import util as common_util, Tqdm, Lazy
from allennlp.common.file_utils import hardlink_or_copy
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict
from allennlp.models.model import Model
from allennlp.nn.parallel import DdpAccelerator, DdpWrappedModel, TorchDdpAccelerator
Expand Down Expand Up @@ -913,7 +914,7 @@ def _try_train(self) -> Tuple[Dict[str, Any], int]:
model_state_file, _ = last_checkpoint
if os.path.exists(self._best_model_filename):
os.remove(self._best_model_filename)
os.link(model_state_file, self._best_model_filename)
hardlink_or_copy(model_state_file, self._best_model_filename)
else:
self._save_model_state(self._best_model_filename)
else:
Expand Down