Commit b20cf83

Add txt2img example using the inference API
1 parent 96808a2 commit b20cf83

File tree

2 files changed: 174 additions & 0 deletions


README.md

Lines changed: 10 additions & 0 deletions
@@ -106,6 +106,16 @@ depending on your use case and PyTorch version, manually.
 
 ## Inference
 
+### Minimal txt2img demo
+
+There is a minimal text-to-image demo available as `txt2img.py`:
+
+```
+python txt2img.py --prompt "Big fluffy cat in a cereal bowl" --steps 25 --seed 1050
+```
+
+### Streamlit demo
+
 We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`.
 We provide file hashes for the complete file as well as for only the saved tensors in the file (see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).
 The following models are currently supported:
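Beyond the flags shown above, the script's argument parser (see `txt2img.py` below) also exposes resolution, guidance scale, and batch size. A fuller invocation, sketched from those flags with illustrative values, might look like:

```
python txt2img.py \
    --prompt "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" \
    --steps 20 --width 1024 --height 1024 \
    --cfg-scale 5.0 --num-samples 2 --seed 42
```

Images are written to `outputs/` with a timestamped filename prefix.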

txt2img.py

Lines changed: 164 additions & 0 deletions
```python
"""
This is a very minimal txt2img example using `sgm.inference.api`.
"""
import argparse
import dataclasses
import logging
import os
import time
from pathlib import Path
from unittest.mock import patch

import numpy as np
import torch
from PIL import Image
import einops
import omegaconf
import pytorch_lightning

from sgm import get_configs_path
from sgm.inference.api import (
    model_specs,
    ModelArchitecture,
    SamplingParams,
    SamplingSpec,
    get_sampler_config,
)
from sgm.inference.helpers import do_sample
from sgm.util import load_model_from_config, get_default_device_name

logger = logging.getLogger("txt2img")


def run_txt2img(
    *,
    model,
    spec: SamplingSpec,
    prompt: str,
    steps: int = 10,
    width: int = 1024,
    height: int = 1024,
    cfg_scale: float = 5.0,
    num_samples: int = 1,
    seed: int,
    device: str,
):
    params = SamplingParams(
        width=width,
        height=height,
        steps=steps,
        scale=cfg_scale,
    )
    with torch.no_grad(), model.ema_scope():
        pytorch_lightning.seed_everything(seed)
        sampler = get_sampler_config(params)
        value_dict = {
            **dataclasses.asdict(params),
            "prompt": prompt,
            "negative_prompt": "",
            "target_width": params.width,
            "target_height": params.height,
        }
        logger.info("Starting sampling")
        return do_sample(
            model,
            sampler,
            value_dict,
            num_samples,
            params.height,
            params.width,
            spec.channels,
            spec.factor,
            force_uc_zero_embeddings=["txt"] if not spec.is_legacy else [],
            return_latents=False,
            filter=None,
            device=device,
        )


@torch.no_grad()
def fast_load(*, config, ckpt, device):
    config = omegaconf.OmegaConf.load(config)
    logger.info("Loading model")
    # This patch is borrowed from AUTOMATIC1111's stable-diffusion-webui;
    # we don't need to initialize the weights just for them to be overwritten
    # by the checkpoint.
    with (
        patch.object(torch.nn.init, "kaiming_uniform_"),
        patch.object(torch.nn.init, "_no_grad_normal_"),
        patch.object(torch.nn.init, "_no_grad_uniform_"),
    ):
        model = load_model_from_config(
            config,
            ckpt=ckpt,
            device="cpu",
            freeze=True,
            verbose=False,
        )
    logger.info("Moving model to device")
    model.to(device)
    model.eval()
    return model


def main():
    logging.basicConfig(
        level=logging.INFO, format="[%(levelname)s] %(name)s: %(message)s"
    )
    # Quiesce some uninformative CLIP and attention logging.
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
    logging.getLogger("sgm.modules.attention").setLevel(logging.ERROR)

    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--spec",
        default=ModelArchitecture.SDXL_V1_BASE.value,
        choices=[s.value for s in ModelArchitecture],
    )
    ap.add_argument("--device", default=get_default_device_name())
    ap.add_argument(
        "--prompt",
        default="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    )
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--steps", type=int, default=20)
    ap.add_argument("--width", type=int, default=1024)
    ap.add_argument("--height", type=int, default=1024)
    ap.add_argument("--cfg-scale", type=float, default=5.0)
    ap.add_argument("--num-samples", type=int, default=1)
    args = ap.parse_args()
    spec = model_specs[ModelArchitecture(args.spec)]
    logger.info(f"Using model spec: {spec}")
    model = fast_load(
        config=os.path.join(get_configs_path(), "inference", spec.config),
        ckpt=os.path.join("checkpoints", spec.ckpt),
        device=args.device,
    )

    samples = run_txt2img(
        model=model,
        spec=spec,
        prompt=args.prompt,
        steps=args.steps,
        width=args.width,
        height=args.height,
        cfg_scale=args.cfg_scale,
        num_samples=args.num_samples,
        device=args.device,
        seed=args.seed,
    )

    out_path = Path("outputs")
    out_path.mkdir(exist_ok=True)

    prefix = int(time.time())

    for i, sample in enumerate(samples, 1):
        filename = out_path / f"{prefix}-{i:04}.png"
        print(f"Saving {i}/{len(samples)}: {filename}")
        sample = 255.0 * einops.rearrange(sample, "c h w -> h w c")
        Image.fromarray(sample.cpu().numpy().astype(np.uint8)).save(filename)


if __name__ == "__main__":
    main()
```
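Because loading and sampling live in separate functions, the new module can also be driven from Python. The following is a minimal sketch (not part of the commit) that reuses `fast_load` and `run_txt2img`, assuming `txt2img.py` is importable from the working directory and that the SDXL base checkpoint sits under `checkpoints/` exactly as `main()` expects:

```python
# Minimal sketch of programmatic use of the script above; all names are
# taken from txt2img.py and the sgm package it imports.
import os

import einops
import numpy as np
from PIL import Image

from sgm import get_configs_path
from sgm.inference.api import ModelArchitecture, model_specs
from sgm.util import get_default_device_name
from txt2img import fast_load, run_txt2img

device = get_default_device_name()
spec = model_specs[ModelArchitecture.SDXL_V1_BASE]

# Load once; fast_load skips weight initialization before the checkpoint load.
model = fast_load(
    config=os.path.join(get_configs_path(), "inference", spec.config),
    ckpt=os.path.join("checkpoints", spec.ckpt),
    device=device,
)

samples = run_txt2img(
    model=model,
    spec=spec,
    prompt="Big fluffy cat in a cereal bowl",
    steps=25,
    seed=1050,
    device=device,
)

# Samples come back as CHW tensors; convert to an image exactly as main() does.
first = 255.0 * einops.rearrange(samples[0], "c h w -> h w c")
Image.fromarray(first.cpu().numpy().astype(np.uint8)).save("cat.png")
```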

0 commit comments
