Skip to content

Commit ab5a2f0

Browse files
authored
Add option for default prompts/negative prompts in eval (#148)
1 parent 6acffcd commit ab5a2f0

File tree

1 file changed

+36
-5
lines changed

1 file changed

+36
-5
lines changed

diffusion/evaluation/clean_fid_eval.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ class CleanFIDEvaluator:
4949
precision (str): The precision to use for evaluation. Default: ``'amp_fp16'``.
5050
prompts (List[str], optional): The prompts to use for image visualization.
5151
Default: ``["A shiba inu wearing a blue sweater"]``.
52+
default_prompt (Optional[str]): An optional default prompt to add before each eval prompt. Default: ``None``.
53+
default_negative_prompt (Optional[str]): An optional default negative prompt to add before each
54+
negative prompt. Default: ``None``.
5255
additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method.
5356
5457
"""
@@ -70,6 +73,8 @@ def __init__(self,
7073
num_samples: Optional[int] = None,
7174
precision: str = 'amp_fp16',
7275
prompts: Optional[List[str]] = None,
76+
default_prompt: Optional[str] = None,
77+
default_negative_prompt: Optional[str] = None,
7378
additional_generate_kwargs: Optional[Dict] = None):
7479
self.model = model
7580
self.tokenizer: PreTrainedTokenizerBase = model.tokenizer
@@ -87,6 +92,8 @@ def __init__(self,
8792
self.num_samples = num_samples if num_samples is not None else float('inf')
8893
self.precision = precision
8994
self.prompts = prompts if prompts is not None else ['A shiba inu wearing a blue sweater']
95+
self.default_prompt = default_prompt
96+
self.default_negative_prompt = default_negative_prompt
9097
self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
9198
self.sdxl = model.sdxl
9299

@@ -141,7 +148,17 @@ def _generate_images(self, guidance_scale: float):
141148
break
142149

143150
real_images = batch[self.image_key]
144-
captions = batch[self.caption_key]
151+
tokenized_captions = batch[self.caption_key]
152+
# Get the prompts from the tokens
153+
text_captions = self.tokenizer.batch_decode(tokenized_captions, skip_special_tokens=True)
154+
# Add default prompts if specified
155+
augmented_captions = text_captions
156+
augmented_negative_prompt = None
157+
if self.default_prompt:
158+
augmented_captions = [f'{self.default_prompt} {caption}' for caption in text_captions]
159+
if self.default_negative_prompt:
160+
augmented_negative_prompt = [f'{self.default_negative_prompt}' for _ in text_captions]
161+
145162
if self.sdxl:
146163
crop_params = batch['cond_crops_coords_top_left']
147164
input_size_params = batch['cond_original_size']
@@ -153,7 +170,8 @@ def _generate_images(self, guidance_scale: float):
153170
seed = starting_seed + batch_id
154171
# Generate images from the captions
155172
with get_precision_context(self.precision):
156-
generated_images = self.model.generate(tokenized_prompts=captions,
173+
generated_images = self.model.generate(prompt=augmented_captions,
174+
negative_prompt=augmented_negative_prompt,
157175
height=self.size,
158176
width=self.size,
159177
guidance_scale=guidance_scale,
@@ -162,8 +180,6 @@ def _generate_images(self, guidance_scale: float):
162180
input_size_params=input_size_params,
163181
progress_bar=False,
164182
**self.additional_generate_kwargs) # type: ignore
165-
# Get the prompts from the tokens
166-
text_captions = self.tokenizer.batch_decode(captions, skip_special_tokens=True)
167183
self.clip_metric.update((generated_images * 255).to(torch.uint8), text_captions)
168184
# Save the real images
169185
# Verify that the real images are in the proper range
@@ -233,8 +249,23 @@ def _compute_metrics(self, guidance_scale: float):
233249
def _generate_images_from_prompts(self, guidance_scale: float):
234250
"""Generate images from prompts for visualization."""
235251
if self.prompts:
252+
# Augment the prompt
253+
augmented_prompts = self.prompts
254+
if self.default_prompt:
255+
augmented_prompts = [f'{self.default_prompt} {prompt}' for prompt in self.prompts]
256+
# Augment the negative prompt
257+
augmented_negative_prompts = None
258+
if 'negative_prompt' in self.additional_generate_kwargs:
259+
negative_prompts = self.additional_generate_kwargs['negative_prompt']
260+
augmented_negative_prompts = [
261+
f'{self.default_negative_prompt} {neg_prompt}' for neg_prompt in negative_prompts
262+
]
263+
if self.default_negative_prompt and augmented_negative_prompts is None:
264+
augmented_negative_prompts = [f'{self.default_negative_prompt}' for _ in self.prompts]
265+
236266
with get_precision_context(self.precision):
237-
generated_images = self.model.generate(prompt=self.prompts,
267+
generated_images = self.model.generate(prompt=augmented_prompts,
268+
negative_prompt=augmented_negative_prompts,
238269
height=self.size,
239270
width=self.size,
240271
guidance_scale=guidance_scale,

0 commit comments

Comments
 (0)