Skip to content

Commit 72d30dd

Browse files
committed
basic handling for training on gitea datasets
1 parent 11eebc0 commit 72d30dd

File tree

1 file changed

+46
-1
lines changed

1 file changed

+46
-1
lines changed

torchmdnet/datasets/ace.py

+46-1
Original file line number | Diff line number | Diff line change
@@ -293,14 +293,59 @@ def sample_iter(self, mol_ids=False):
293293
yield data
294294

295295

296+
def _parse_gitea_url(path):
    """Split a gitea ``ssh://`` dataset URL into its components.

    The URL is split on ``/``; the first five pieces (scheme, empty piece,
    host, user, repository) form the clone URL and the last piece is the
    file name. Optional ``branch/<name>`` and ``commit/<sha>`` pairs may
    appear after the repository component.

    Returns:
        Tuple ``(repo_url, user, repo_name, file_name, branch, commit)``.
        ``branch`` defaults to ``"main"`` and ``commit`` to ``None``.

    Raises:
        ValueError: If the URL does not start with ``ssh://``, is too short
            to contain host, user, repository and file name, or has a
            ``branch``/``commit`` marker without a value after it.
    """
    # Explicit raise instead of `assert`: asserts vanish under `python -O`.
    if not path.startswith("ssh://"):
        raise ValueError(
            f"Gitea dataset path must start with 'ssh://', got: {path!r}"
        )

    # pieces == ["ssh:", "", host, user, repo, ..., file_name]
    pieces = path.split("/")
    if len(pieces) < 6 or not all(pieces[2:5]) or not pieces[-1]:
        raise ValueError(
            "Gitea dataset path must contain host, user, repository and "
            f"file name, got: {path!r}"
        )

    repo_url = "/".join(pieces[:5])
    user = pieces[3]
    repo_name = pieces[4]
    file_name = pieces[-1]

    # Only look for markers after the repository component, so a host, user
    # or repository that happens to be called "branch"/"commit" is not
    # mistaken for a marker.
    selectors = {"branch": "main", "commit": None}
    tail = pieces[5:]
    for marker in selectors:
        if marker in tail:
            value_pos = tail.index(marker) + 1
            if value_pos >= len(tail):
                raise ValueError(
                    f"'{marker}' in gitea dataset path has no value: {path!r}"
                )
            selectors[marker] = tail[value_pos]

    return (
        repo_url,
        user,
        repo_name,
        file_name,
        selectors["branch"],
        selectors["commit"],
    )


def download_gitea_dataset(path, tmpdir):
    """Clone or update a gitea repository and return the local file path.

    The repository is cloned into ``<tmpdir>/<user>_<repo_name>`` on first
    use and reused afterwards. The requested commit is checked out if the
    URL contains ``commit/<sha>``; otherwise the requested branch
    (``"main"`` when none is given) is checked out and updated from
    ``origin``.

    Args:
        path: ``ssh://`` URL pointing at a file inside a gitea repository.
        tmpdir: Directory under which the repository clone is stored.

    Returns:
        ``os.path.join(<clone directory>, <last component of path>)``.

    Raises:
        ValueError: If ``path`` is not a well-formed ``ssh://`` URL.
        ImportError: If GitPython is not installed.
    """
    # Validate first so malformed input is reported even without GitPython.
    repo_url, user, repo_name, file_name, branch, commit = _parse_gitea_url(path)

    try:
        from git import Repo
    except ImportError as err:
        raise ImportError(
            "Could not import GitPython library. Please install it first with `pip install GitPython`"
        ) from err

    outdir = os.path.join(tmpdir, f"{user}_{repo_name}")
    if os.path.exists(outdir):
        repo = Repo(outdir)
    else:
        repo = Repo.clone_from(repo_url, outdir, no_checkout=True)

    origin = repo.remotes.origin
    # Fetch rather than pull up front: a bare `git pull` fails when a cached
    # clone was left in detached HEAD by an earlier commit checkout, and it
    # would update whichever branch was checked out last, not the requested one.
    origin.fetch()
    if commit is not None:
        repo.git.checkout(commit)
    else:
        repo.git.checkout(branch)
        # Bring the requested branch up to date with the remote.
        origin.pull(branch)

    return os.path.join(outdir, file_name)
333+
334+
296335
class AceHF(Dataset):
297336
def __init__(
298337
self, root="parquet", paths=None, split="train", max_gradient=None
299338
) -> None:
300339
from datasets import load_dataset
301340
import numpy as np
302341

303-
self.dataset = load_dataset(root, data_files=paths, split=split)
342+
# Handle gitea parquet datasets
343+
newpaths = paths.copy()
344+
for i, path in enumerate(paths):
345+
if "gitea" in path:
346+
newpaths[i] = download_gitea_dataset(path, "/tmp")
347+
348+
self.dataset = load_dataset(root, data_files=newpaths, split=split)
304349
if max_gradient is not None:
305350

306351
def _filter(x):

0 commit comments

Comments
 (0)