@@ -293,14 +293,59 @@ def sample_iter(self, mol_ids=False):
293
293
yield data
294
294
295
295
296
+ def download_gitea_dataset (path , tmpdir ):
297
+ try :
298
+ from git import Repo
299
+ except ImportError :
300
+ raise ImportError (
301
+ "Could not import GitPython library. Please install it first with `pip install GitPython`"
302
+ )
303
+
304
+ assert path .startswith ("ssh://" )
305
+
306
+ # Parse the gitea URL
307
+ pieces = path .split ("/" )
308
+ repo_url = "/" .join (pieces [:5 ])
309
+ user = pieces [3 ]
310
+ repo_name = pieces [4 ]
311
+ file_name = pieces [- 1 ]
312
+ branch = "main"
313
+ commit = None
314
+ if "branch" in pieces :
315
+ branch = pieces [pieces .index ("branch" ) + 1 ]
316
+ if "commit" in pieces :
317
+ commit = pieces [pieces .index ("commit" ) + 1 ]
318
+
319
+ outdir = os .path .join (tmpdir , f"{ user } _{ repo_name } " )
320
+ if not os .path .exists (outdir ):
321
+ repo = Repo .clone_from (repo_url , outdir , no_checkout = True )
322
+ else :
323
+ repo = Repo (outdir )
324
+
325
+ origin = repo .remotes .origin
326
+ origin .pull ()
327
+ if commit is not None :
328
+ repo .git .checkout (commit )
329
+ else :
330
+ repo .git .checkout (branch )
331
+
332
+ return os .path .join (outdir , file_name )
333
+
334
+
296
335
class AceHF (Dataset ):
297
336
def __init__ (
298
337
self , root = "parquet" , paths = None , split = "train" , max_gradient = None
299
338
) -> None :
300
339
from datasets import load_dataset
301
340
import numpy as np
302
341
303
- self .dataset = load_dataset (root , data_files = paths , split = split )
342
+ # Handle gitea parquet datasets
343
+ newpaths = paths .copy ()
344
+ for i , path in enumerate (paths ):
345
+ if "gitea" in path :
346
+ newpaths [i ] = download_gitea_dataset (path , "/tmp" )
347
+
348
+ self .dataset = load_dataset (root , data_files = newpaths , split = split )
304
349
if max_gradient is not None :
305
350
306
351
def _filter (x ):
0 commit comments