Skip to content

Commit a58203c

Browse files
committed
Add txt2img example using the inference API
1 parent 96808a2 commit a58203c

File tree

2 files changed

+181
-0
lines changed

2 files changed

+181
-0
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,16 @@ depending on your use case and PyTorch version, manually.
106106

107107
## Inference
108108

109+
### Minimal txt2img demo
110+
111+
There is a minimal text-to-image demo available as `txt2img.py`:
112+
113+
```
114+
python txt2img.py --prompt "Big fluffy cat in a cereal bowl" --steps 25 --seed 1050
115+
```
116+
117+
### Streamlit demo
118+
109119
We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`.
110120
We provide file hashes for the complete file as well as for only the saved tensors in the file (see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).
111121
The following models are currently supported:

txt2img.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
"""
2+
This is a very minimal txt2img example using `sgm.inference.api`.
3+
"""
4+
from __future__ import annotations
5+
6+
import argparse
7+
import dataclasses
8+
import logging
9+
import os
10+
import time
11+
from pathlib import Path
12+
from unittest.mock import patch
13+
14+
import numpy as np
15+
import torch
16+
from PIL import Image
17+
import einops
18+
import omegaconf
19+
import pytorch_lightning
20+
21+
from sgm import get_configs_path
22+
from sgm.inference.api import (
23+
model_specs,
24+
ModelArchitecture,
25+
SamplingParams,
26+
SamplingSpec,
27+
get_sampler_config,
28+
Discretization,
29+
)
30+
from sgm.inference.helpers import do_sample
31+
from sgm.util import load_model_from_config, get_default_device_name
32+
33+
logger = logging.getLogger("txt2img")
34+
35+
36+
def run_txt2img(
    *,
    model,
    spec: SamplingSpec,
    prompt: str,
    steps: int,
    width: int | None,
    height: int | None,
    cfg_scale: float | None,
    num_samples=1,
    seed: int,
    device: str,
):
    """Sample images for a text prompt via `sgm.inference.helpers.do_sample`.

    All arguments are keyword-only.

    Args:
        model: Loaded model; must provide `ema_scope()`. It is passed
            through to `do_sample` unchanged.
        spec: Model spec supplying the default `height`/`width` as well as
            `channels`, `factor` and `is_legacy`.
        prompt: Positive text prompt. The negative prompt is always empty.
        steps: Number of sampling steps.
        width: Output width; a falsy value (None or 0) falls back to
            `spec.width`.
        height: Output height; a falsy value (None or 0) falls back to
            `spec.height`.
        cfg_scale: Guidance scale override. None keeps the `SamplingParams`
            default; any other value, including 0.0, is applied.
        num_samples: Number of samples requested from `do_sample`.
        seed: Seed handed to `pytorch_lightning.seed_everything` right
            before sampling.
        device: Device name forwarded to `do_sample`.

    Returns:
        Whatever `do_sample` returns when called with
        `return_latents=False` and `filter=None`.
    """
    params = SamplingParams(
        discretization=Discretization.EDM,
        height=(height or spec.height),
        rho=7,
        steps=steps,
        width=(width or spec.width),
    )
    # Compare against None rather than truthiness so that an explicit
    # scale of 0.0 is honoured instead of being silently ignored.
    if cfg_scale is not None:
        params.cfg_scale = cfg_scale

    with torch.no_grad(), model.ema_scope():
        # Seed immediately before building the sampler and sampling so the
        # run is reproducible for a given seed.
        pytorch_lightning.seed_everything(seed)
        sampler = get_sampler_config(params)
        value_dict = {
            **dataclasses.asdict(params),
            "prompt": prompt,
            "negative_prompt": "",
            "target_width": params.width,
            "target_height": params.height,
        }
        logger.info("Starting sampling with %s", params)
        return do_sample(
            model,
            sampler,
            value_dict,
            num_samples,
            params.height,
            params.width,
            spec.channels,
            spec.factor,
            # Non-legacy specs zero out the unconditional "txt" embedding.
            force_uc_zero_embeddings=["txt"] if not spec.is_legacy else [],
            return_latents=False,
            filter=None,
            device=device,
        )
84+
85+
86+
@torch.no_grad()
def fast_load(*, config, ckpt, device):
    """Load a model from a config file and checkpoint, skipping weight init.

    Args:
        config: Path of a config file readable by `omegaconf.OmegaConf.load`.
        ckpt: Checkpoint path passed to `load_model_from_config`.
        device: Device the model is moved to after it has been loaded on
            the CPU.

    Returns:
        The loaded model, moved to `device` and switched to eval mode.
    """
    config = omegaconf.OmegaConf.load(config)
    logger.info("Loading model")
    # This patch is borrowed from AUTOMATIC1111's stable-diffusion-webui;
    # we don't need to initialize the weights just for them to be overwritten
    # by the checkpoint.
    #
    # Line continuations are used instead of a parenthesized `with (...)`,
    # which is only official syntax from Python 3.10 onwards.
    with patch.object(torch.nn.init, "kaiming_uniform_"), \
            patch.object(torch.nn.init, "_no_grad_normal_"), \
            patch.object(torch.nn.init, "_no_grad_uniform_"):
        model = load_model_from_config(
            config,
            ckpt=ckpt,
            device="cpu",
            freeze=True,
            verbose=False,
        )
    logger.info("Moving model to device")
    model.to(device)
    model.eval()
    return model
109+
110+
111+
def main():
    """Command-line entry point: parse options, load the model, sample, save.

    Images are written as PNG files into the `outputs` directory (created
    if missing), named `<unix-timestamp>-<index>.png` with a 1-based,
    zero-padded index.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="[%(levelname)s] %(name)s: %(message)s",
    )
    # Quiesce some uninformative CLIP and attention logging.
    for noisy_name in ("transformers.modeling_utils", "sgm.modules.attention"):
        logging.getLogger(noisy_name).setLevel(logging.ERROR)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--spec",
        default=ModelArchitecture.SDXL_V1_BASE.value,
        choices=[arch.value for arch in ModelArchitecture],
    )
    parser.add_argument("--device", default=get_default_device_name())
    parser.add_argument(
        "--prompt",
        default="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--steps", type=int, default=20)
    parser.add_argument("--width", type=int)
    parser.add_argument("--height", type=int)
    parser.add_argument("--cfg-scale", type=float)
    parser.add_argument("--num-samples", type=int, default=1)
    args = parser.parse_args()

    spec = model_specs[ModelArchitecture(args.spec)]
    logger.info(f"Using model spec: {spec}")

    config_file = os.path.join(get_configs_path(), "inference", spec.config)
    checkpoint_file = os.path.join("checkpoints", spec.ckpt)
    model = fast_load(
        config=config_file,
        ckpt=checkpoint_file,
        device=args.device,
    )

    samples = run_txt2img(
        model=model,
        spec=spec,
        prompt=args.prompt,
        steps=args.steps,
        width=args.width,
        height=args.height,
        cfg_scale=args.cfg_scale,
        num_samples=args.num_samples,
        device=args.device,
        seed=args.seed,
    )

    out_path = Path("outputs")
    out_path.mkdir(exist_ok=True)

    # One timestamp per run keeps all images of a batch grouped together.
    prefix = int(time.time())

    for index, sample in enumerate(samples, start=1):
        filename = out_path / f"{prefix}-{index:04}.png"
        print(f"Saving {index}/{len(samples)}: {filename}")
        # Channels-first -> channels-last, then scale to the 0..255 range
        # before converting to 8-bit pixels for PIL.
        channels_last = einops.rearrange(sample, "c h w -> h w c")
        pixels = (255.0 * channels_last).cpu().numpy().astype(np.uint8)
        Image.fromarray(pixels).save(filename)
168+
169+
170+
if __name__ == "__main__":
171+
main()

0 commit comments

Comments
 (0)