@@ -49,6 +49,9 @@ class CleanFIDEvaluator:
49
49
precision (str): The precision to use for evaluation. Default: ``'amp_fp16'``.
50
50
prompts (List[str], optional): The prompts to use for image visualtization.
51
51
Default: ``["A shiba inu wearing a blue sweater]``.
52
+ default_prompt (Optional[str]): An optional default prompt to add before each eval prompt. Default: ``None``.
53
+ default_negative_prompt (Optional[str]): An optional default negative prompt to add before each
54
+ negative prompt. Default: ``None``.
52
55
additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method.
53
56
54
57
"""
@@ -70,6 +73,8 @@ def __init__(self,
70
73
num_samples : Optional [int ] = None ,
71
74
precision : str = 'amp_fp16' ,
72
75
prompts : Optional [List [str ]] = None ,
76
+ default_prompt : Optional [str ] = None ,
77
+ default_negative_prompt : Optional [str ] = None ,
73
78
additional_generate_kwargs : Optional [Dict ] = None ):
74
79
self .model = model
75
80
self .tokenizer : PreTrainedTokenizerBase = model .tokenizer
@@ -87,6 +92,8 @@ def __init__(self,
87
92
self .num_samples = num_samples if num_samples is not None else float ('inf' )
88
93
self .precision = precision
89
94
self .prompts = prompts if prompts is not None else ['A shiba inu wearing a blue sweater' ]
95
+ self .default_prompt = default_prompt
96
+ self .default_negative_prompt = default_negative_prompt
90
97
self .additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
91
98
self .sdxl = model .sdxl
92
99
@@ -141,7 +148,17 @@ def _generate_images(self, guidance_scale: float):
141
148
break
142
149
143
150
real_images = batch [self .image_key ]
144
- captions = batch [self .caption_key ]
151
+ tokenized_captions = batch [self .caption_key ]
152
+ # Get the prompts from the tokens
153
+ text_captions = self .tokenizer .batch_decode (tokenized_captions , skip_special_tokens = True )
154
+ # Add default prompts if specified
155
+ augmented_captions = text_captions
156
+ augmented_negative_prompt = None
157
+ if self .default_prompt :
158
+ augmented_captions = [f'{ self .default_prompt } { caption } ' for caption in text_captions ]
159
+ if self .default_negative_prompt :
160
+ augmented_negative_prompt = [f'{ self .default_negative_prompt } ' for _ in text_captions ]
161
+
145
162
if self .sdxl :
146
163
crop_params = batch ['cond_crops_coords_top_left' ]
147
164
input_size_params = batch ['cond_original_size' ]
@@ -153,7 +170,8 @@ def _generate_images(self, guidance_scale: float):
153
170
seed = starting_seed + batch_id
154
171
# Generate images from the captions
155
172
with get_precision_context (self .precision ):
156
- generated_images = self .model .generate (tokenized_prompts = captions ,
173
+ generated_images = self .model .generate (prompt = augmented_captions ,
174
+ negative_prompt = augmented_negative_prompt ,
157
175
height = self .size ,
158
176
width = self .size ,
159
177
guidance_scale = guidance_scale ,
@@ -162,8 +180,6 @@ def _generate_images(self, guidance_scale: float):
162
180
input_size_params = input_size_params ,
163
181
progress_bar = False ,
164
182
** self .additional_generate_kwargs ) # type: ignore
165
- # Get the prompts from the tokens
166
- text_captions = self .tokenizer .batch_decode (captions , skip_special_tokens = True )
167
183
self .clip_metric .update ((generated_images * 255 ).to (torch .uint8 ), text_captions )
168
184
# Save the real images
169
185
# Verify that the real images are in the proper range
@@ -233,8 +249,23 @@ def _compute_metrics(self, guidance_scale: float):
233
249
def _generate_images_from_prompts (self , guidance_scale : float ):
234
250
"""Generate images from prompts for visualization."""
235
251
if self .prompts :
252
+ # Augment the prompt
253
+ augmented_prompts = self .prompts
254
+ if self .default_prompt :
255
+ augmented_prompts = [f'{ self .default_prompt } { prompt } ' for prompt in self .prompts ]
256
+ # Augment the negative prompt
257
+ augmented_negative_prompts = None
258
+ if 'negative prompt' in self .additional_generate_kwargs :
259
+ negative_prompts = self .additional_generate_kwargs ['negative prompt' ]
260
+ augmented_negative_prompts = [
261
+ f'{ self .default_negative_prompt } { neg_prompt } ' for neg_prompt in negative_prompts
262
+ ]
263
+ if self .default_negative_prompt and augmented_negative_prompts is None :
264
+ augmented_negative_prompts = [f'{ self .default_negative_prompt } ' for _ in self .prompts ]
265
+
236
266
with get_precision_context (self .precision ):
237
- generated_images = self .model .generate (prompt = self .prompts ,
267
+ generated_images = self .model .generate (prompt = augmented_prompts ,
268
+ negative_prompt = augmented_negative_prompts ,
238
269
height = self .size ,
239
270
width = self .size ,
240
271
guidance_scale = guidance_scale ,
0 commit comments