Skip to content

Commit 7cae6c9

Browse files
Backport PR #3331 on branch 1.3.x (update bug of installation with custom dataloder) (#3332)
Backport PR #3331: update bug of installation with custom dataloder Co-authored-by: Ori Kronfeld <[email protected]>
1 parent 22491cd commit 7cae6c9

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

docs/api/developer.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ Module classes in the external API with respective generative and inference proc
185185
external.methylvi.METHYLANVAE
186186
external.decipher.DecipherPyroModule
187187
external.resolvi.RESOLVAE
188+
external.totalanvi.TOTALANVAE
188189
external.sysvi.SysVAE
189190
190191
```

docs/api/user.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ import scvi
6262
external.METHYLVI
6363
external.METHYLANVI
6464
external.Decipher
65+
external.TOTALANVI
6566
external.RESOLVI
6667
external.SysVI
6768
```

src/scvi/dataloaders/_custom_dataloders.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
if TYPE_CHECKING:
1717
from typing import Any
1818

19+
import lamindb as ln
1920
import pandas as pd
21+
import tiledbsoma as soma
2022

2123

22-
@dependencies("lamindb")
2324
class MappedCollectionDataModule(LightningDataModule):
24-
import lamindb as ln
25-
25+
@dependencies("lamindb")
2626
def __init__(
2727
self,
2828
collection: ln.Collection,
@@ -353,18 +353,15 @@ def __len__(self):
353353
return len(self.dataloader)
354354

355355

356-
@dependencies("tiledbsoma")
357-
@dependencies("tiledbsoma_ml")
358356
class TileDBDataModule(LightningDataModule):
359-
import tiledbsoma as soma
360-
361357
"""PyTorch Lightning DataModule for training scVI models from SOMA data
362358
363359
Wraps a `tiledbsoma_ml.ExperimentDataset` to stream the results of a SOMA
364360
`ExperimentAxisQuery`, exposing a `DataLoader` to generate tensors ready for scVI model
365361
training. Also handles deriving the scVI batch label as a tuple of obs columns.
366362
"""
367363

364+
@dependencies("tiledbsoma")
368365
def __init__(
369366
self,
370367
query: soma.ExperimentAxisQuery,
@@ -503,6 +500,7 @@ def __init__(
503500
accelerator=accelerator, devices=device, return_device="torch"
504501
)
505502

503+
@dependencies("tiledbsoma_ml")
506504
def setup(self, stage: str | None = None) -> None:
507505
# Instantiate the ExperimentDataset with the provided args and kwargs.
508506
from tiledbsoma_ml import ExperimentDataset
@@ -539,6 +537,7 @@ def setup(self, stage: str | None = None) -> None:
539537
else:
540538
self.val_dataset = None
541539

540+
@dependencies("tiledbsoma_ml")
542541
def train_dataloader(self) -> DataLoader:
543542
from tiledbsoma_ml import experiment_dataloader
544543

@@ -547,6 +546,7 @@ def train_dataloader(self) -> DataLoader:
547546
**self.dataloader_kwargs,
548547
)
549548

549+
@dependencies("tiledbsoma_ml")
550550
def val_dataloader(self) -> DataLoader:
551551
from tiledbsoma_ml import experiment_dataloader
552552

0 commit comments

Comments
 (0)