22
22
('DPM++ 2M' , 'sample_dpmpp_2m' , ['k_dpmpp_2m' ], {}),
23
23
('DPM++ SDE' , 'sample_dpmpp_sde' , ['k_dpmpp_sde' ], {"second_order" : True , "brownian_noise" : True }),
24
24
('DPM++ 2M SDE' , 'sample_dpmpp_2m_sde' , ['k_dpmpp_2m_sde_ka' ], {"brownian_noise" : True }),
25
+ ('DPM++ 2M SDE Heun' , 'sample_dpmpp_2m_sde' , ['k_dpmpp_2m_sde_heun' ], {"brownian_noise" : True , "solver_type" : "heun" }),
26
+ ('DPM++ 2M SDE Heun Karras' , 'sample_dpmpp_2m_sde' , ['k_dpmpp_2m_sde_heun_ka' ], {'scheduler' : 'karras' , "brownian_noise" : True , "solver_type" : "heun" }),
27
+ ('DPM++ 2M SDE Heun Exponential' , 'sample_dpmpp_2m_sde' , ['k_dpmpp_2m_sde_heun_exp' ], {'scheduler' : 'exponential' , "brownian_noise" : True , "solver_type" : "heun" }),
25
28
('DPM++ 3M SDE' , 'sample_dpmpp_3m_sde' , ['k_dpmpp_3m_sde' ], {'discard_next_to_last_sigma' : True , "brownian_noise" : True }),
26
29
('DPM++ 3M SDE Karras' , 'sample_dpmpp_3m_sde' , ['k_dpmpp_3m_sde_ka' ], {'scheduler' : 'karras' , 'discard_next_to_last_sigma' : True , "brownian_noise" : True }),
27
30
('DPM++ 3M SDE Exponential' , 'sample_dpmpp_3m_sde' , ['k_dpmpp_3m_sde_exp' ], {'scheduler' : 'exponential' , 'discard_next_to_last_sigma' : True , "brownian_noise" : True }),
@@ -161,6 +164,9 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,
161
164
noise_sampler = self .create_noise_sampler (x , sigmas , p )
162
165
extra_params_kwargs ['noise_sampler' ] = noise_sampler
163
166
167
+ if self .config .options .get ('solver_type' , None ) == 'heun' :
168
+ extra_params_kwargs ['solver_type' ] = 'heun'
169
+
164
170
self .model_wrap_cfg .init_latent = x
165
171
self .last_latent = x
166
172
self .sampler_extra_args = {
@@ -202,6 +208,9 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima
202
208
noise_sampler = self .create_noise_sampler (x , sigmas , p )
203
209
extra_params_kwargs ['noise_sampler' ] = noise_sampler
204
210
211
+ if self .config .options .get ('solver_type' , None ) == 'heun' :
212
+ extra_params_kwargs ['solver_type' ] = 'heun'
213
+
205
214
self .last_latent = x
206
215
self .sampler_extra_args = {
207
216
'cond' : conditioning ,
@@ -210,6 +219,7 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima
210
219
'cond_scale' : p .cfg_scale ,
211
220
's_min_uncond' : self .s_min_uncond
212
221
}
222
+
213
223
samples = self .launch_sampling (steps , lambda : self .func (self .model_wrap_cfg , x , extra_args = self .sampler_extra_args , disable = False , callback = self .callback_state , ** extra_params_kwargs ))
214
224
215
225
if self .model_wrap_cfg .padded_cond_uncond :
0 commit comments