Skip to content

Commit cd183d9

Browse files
xiaomilezengyh1900liuwenran
authored
[MMSIG] Add new configuration files for StyleGAN2 (#2057)
* 1st * debug * 20230710 调整 * 调整代码,整合模型,避免editors import 过多class * 调整代码,整合模型,避免editors import 过多class * 支持 DeblurGANv2 inference demo 示例 python mmagic\demo\mmagic_inference_demo.py --model-name deblurganv2 --model-comfig ../configs/deblurganv2/deblurganv2_fpn_inception.py --model-ckpt 权重文件路径 --img 测试图片路径 --device cpu --result-out-dir ./out.png * 支持 DeblurGANv2 inference fix CI test * 支持 DeblurGANv2 inference fix CI test * 支持 DeblurGANv2 inference fix CI test * 支持 DeblurGANv2 inference update model-index * 支持 DeblurGANv2 inference Fix CI Test * 支持 DeblurGANv2 inference CI test fix and update readme.md * 支持 DeblurGANv2 inference fix CI test and update readme.md * 支持 DeblurGANv2 inference Fix CI test * 支持 DeblurGANv2 inference yapf 修正 * 支持 DeblurGANv2 inference 代码调整,参数名保持一致 * 支持 DeblurGANv2 inference doc string coverage * 支持 DeblurGANv2 inference add some doc string * 支持 DeblurGANv2 inference add some doc string * 支持 DeblurGANv2 inference * 支持 DeblurGANv2 inference * 支持 DeblurGANv2 inference * 支持 DeblurGANv2 inference * 支持 DeblurGANv2 inference add unit test * 支持 DeblurGANv2 inference add unit test * 支持 DeblurGANv2 inference add unit test * 支持 DeblurGANv2 inference fix unit test * Update .gitignore Co-authored-by: Yanhong Zeng <[email protected]> * Update .gitignore Co-authored-by: Yanhong Zeng <[email protected]> * Update .gitignore Co-authored-by: Yanhong Zeng <[email protected]> * Update .gitignore Co-authored-by: Yanhong Zeng <[email protected]> * Update .gitignore Co-authored-by: Yanhong Zeng <[email protected]> * Update .gitignore Co-authored-by: Yanhong Zeng <[email protected]> * Update .gitignore Co-authored-by: Yanhong Zeng <[email protected]> * Update configs/deblurganv2/README.md Co-authored-by: Yanhong Zeng <[email protected]> * 支持 DeblurGANv2 inference move the implementation of loss function to mmagic/models/losses add quick start to readme * 支持 DeblurGANv2 inference fix unit test * 支持 DeblurGANv2 inference re run unit test * Update configs/deblurganv2/deblurganv2_fpn-inception_1xb1_gopro.py Co-authored-by: Yanhong Zeng <[email protected]> * Update configs/deblurganv2/deblurganv2_fpn-inception_1xb1_gopro.py Co-authored-by: Yanhong Zeng <[email protected]> * Update configs/deblurganv2/deblurganv2_fpn-inception_1xb1_gopro.py Co-authored-by: Yanhong Zeng <[email protected]> * Update configs/deblurganv2/deblurganv2_fpn-inception_1xb1_gopro.py Co-authored-by: Yanhong Zeng <[email protected]> * Update configs/deblurganv2/deblurganv2_fpn-mobilenet_1xb1_gopro.py Co-authored-by: Yanhong Zeng <[email protected]> * Update configs/deblurganv2/deblurganv2_fpn-mobilenet_1xb1_gopro.py Co-authored-by: Yanhong Zeng <[email protected]> * Update configs/deblurganv2/deblurganv2_fpn-mobilenet_1xb1_gopro.py Co-authored-by: Yanhong Zeng <[email protected]> * Update configs/deblurganv2/deblurganv2_fpn-mobilenet_1xb1_gopro.py Co-authored-by: Yanhong Zeng <[email protected]> * 支持 DeblurGANv2 inference fix some url and path add README_zh-CN add Deblurring task into mmagic/apis/inferencers/__init__.py * Adding support for FastComposer 支持 FastComposer * Adding support for FastComposer Add some doc string and fix bugs * Adding support for FastComposer Fixed a bug * Adding support for FastComposer fix a bug * Adding support for FastComposer fix some bugs * Adding support for FastComposer fix some bugs * Adding support for FastComposer fix a bug * Adding support for FastComposer fix a bug * Adding support for FastComposer change for minimum version cpu check * Adding support for FastComposer avoid a windows CI bug which complains not enough memory. * Adding support for FastComposer rerun circleci check * Adding support for FastComposer add example code which run without gradio to readme add config of clip for running unittest without using "clip_vit_url = 'openai/clip-vit-large-patch14' " * Adding support for FastComposer rerun checks of build cu102 * Adding support for FastComposer a small change * Adding support for FastComposer some small changes * Adding support for FastComposer add some simple instructions to demo/README.md * Adding support for FastComposer resolve conflicts * Adding support for FastComposer rerun checks * Adding support for FastComposer add device for running with cuda by default * Adding support for Consistency Models * Adding support for Consistency Models * Update README.md * Adding support for Consistency Models * Adding support for Consistency Models mdformat debug * Adding support for Consistency Models * Adding support for Consistency Models add some doc string * Adding support for Consistency Models re run circle check * Adding support for Consistency Models rerun circle check * [FIX] Check circle ci memory add function teardown_module to test_fastcomposer * Adding support for Consistency Models rerun ci check * Add new configuration files for StyleGAN2 * Revert "Add new configuration files for StyleGAN2" This reverts commit a043b22. * Add new configuration files for StyleGAN2 * fix config-vaildate error * fix a bug * delete code of consistency model * delete code which in another pr * delete code which in another pr * Add new configuration files for StyleGAN2 rerun ci check * ci check memory --------- Co-authored-by: Yanhong Zeng <[email protected]> Co-authored-by: rangoliu <[email protected]>
1 parent 73d7b01 commit cd183d9

10 files changed

+619
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from mmengine.config import read_base
3+
from torch.optim import Adam
4+
5+
from mmagic.engine import VisualizationHook
6+
from mmagic.evaluation import (FrechetInceptionDistance, PerceptualPathLength,
7+
PrecisionAndRecall)
8+
from mmagic.models import BaseGAN
9+
10+
with read_base():
11+
from .._base_.datasets.ffhq_flip import * # noqa: F403,F405
12+
from .._base_.gen_default_runtime import * # noqa: F403,F405
13+
from .._base_.models.base_styleganv2 import * # noqa: F403,F405
14+
15+
# reg params
16+
d_reg_interval = 16
17+
g_reg_interval = 4
18+
19+
g_reg_ratio = g_reg_interval / (g_reg_interval + 1)
20+
d_reg_ratio = d_reg_interval / (d_reg_interval + 1)
21+
22+
ema_half_life = 10. # G_smoothing_kimg
23+
24+
model.update(
25+
generator=dict(out_size=256),
26+
discriminator=dict(in_size=256),
27+
ema_config=dict(
28+
type=ExponentialMovingAverage,
29+
interval=1,
30+
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
31+
loss_config=dict(
32+
r1_loss_weight=10. / 2. * d_reg_interval,
33+
r1_interval=d_reg_interval,
34+
norm_mode='HWC',
35+
g_reg_interval=g_reg_interval,
36+
g_reg_weight=2. * g_reg_interval,
37+
pl_batch_shrink=2))
38+
39+
train_cfg.update(max_iters=800002)
40+
41+
optim_wrapper.update(
42+
generator=dict(
43+
optimizer=dict(
44+
type=Adam, lr=0.002 * g_reg_ratio, betas=(0, 0.99**g_reg_ratio))),
45+
discriminator=dict(
46+
optimizer=dict(
47+
type=Adam, lr=0.002 * d_reg_ratio, betas=(0, 0.99**d_reg_ratio))))
48+
49+
batch_size = 4
50+
data_root = './data/ffhq/ffhq_imgs/ffhq_256'
51+
52+
train_dataloader.update(
53+
batch_size=batch_size, dataset=dict(data_root=data_root))
54+
55+
val_dataloader.update(batch_size=batch_size, dataset=dict(data_root=data_root))
56+
57+
test_dataloader.update(
58+
batch_size=batch_size, dataset=dict(data_root=data_root))
59+
60+
# VIS_HOOK
61+
custom_hooks = [
62+
dict(
63+
type=VisualizationHook,
64+
interval=5000,
65+
fixed_input=True,
66+
vis_kwargs_list=dict(type=BaseGAN, name='fake_img'))
67+
]
68+
69+
# METRICS
70+
metrics = [
71+
dict(
72+
type=FrechetInceptionDistance,
73+
prefix='FID-50k',
74+
fake_nums=50000,
75+
real_nums=50000,
76+
inception_style='StyleGAN',
77+
sample_model='ema'),
78+
dict(type=PrecisionAndRecall, fake_nums=50000, prefix='PR-50K'),
79+
dict(type=PerceptualPathLength, fake_nums=50000, prefix='ppl-w')
80+
]
81+
# NOTE: config for save multi best checkpoints
82+
# default_hooks.update(
83+
# checkpoint=dict(
84+
# save_best=['FID-Full-50k/fid', 'IS-50k/is'],
85+
# rule=['less', 'greater']))
86+
default_hooks.update(checkpoint=dict(save_best='FID-50k/fid'))
87+
88+
val_evaluator.update(metrics=metrics)
89+
test_evaluator.update(metrics=metrics)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from mmengine.config import read_base
3+
from torch.optim import Adam
4+
5+
from mmagic.engine import VisualizationHook
6+
from mmagic.evaluation import (FrechetInceptionDistance, PerceptualPathLength,
7+
PrecisionAndRecall)
8+
from mmagic.models import BaseGAN
9+
10+
with read_base():
11+
from .._base_.datasets.lsun_stylegan import * # noqa: F403,F405
12+
from .._base_.gen_default_runtime import * # noqa: F403,F405
13+
from .._base_.models.base_styleganv2 import * # noqa: F403,F405
14+
15+
# reg params
16+
d_reg_interval = 16
17+
g_reg_interval = 4
18+
19+
g_reg_ratio = g_reg_interval / (g_reg_interval + 1)
20+
d_reg_ratio = d_reg_interval / (d_reg_interval + 1)
21+
22+
ema_half_life = 10. # G_smoothing_kimg
23+
24+
model.update(
25+
generator=dict(out_size=256),
26+
discriminator=dict(in_size=256),
27+
ema_config=dict(
28+
type=ExponentialMovingAverage,
29+
interval=1,
30+
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
31+
loss_config=dict(
32+
r1_loss_weight=10. / 2. * d_reg_interval,
33+
r1_interval=d_reg_interval,
34+
norm_mode='HWC',
35+
g_reg_interval=g_reg_interval,
36+
g_reg_weight=2. * g_reg_interval,
37+
pl_batch_shrink=2))
38+
39+
train_cfg.update(max_iters=800002)
40+
41+
optim_wrapper.update(
42+
generator=dict(
43+
optimizer=dict(
44+
type=Adam, lr=0.002 * g_reg_ratio, betas=(0, 0.99**g_reg_ratio))),
45+
discriminator=dict(
46+
optimizer=dict(
47+
type=Adam, lr=0.002 * d_reg_ratio, betas=(0, 0.99**d_reg_ratio))))
48+
49+
batch_size = 4
50+
data_root = './data/lsun-cat'
51+
52+
train_dataloader.update(
53+
batch_size=batch_size, dataset=dict(data_root=data_root))
54+
55+
val_dataloader.update(batch_size=batch_size, dataset=dict(data_root=data_root))
56+
57+
test_dataloader.update(
58+
batch_size=batch_size, dataset=dict(data_root=data_root))
59+
60+
# VIS_HOOK
61+
custom_hooks = [
62+
dict(
63+
type=VisualizationHook,
64+
interval=5000,
65+
fixed_input=True,
66+
vis_kwargs_list=dict(type=BaseGAN, name='fake_img'))
67+
]
68+
69+
# METRICS
70+
metrics = [
71+
dict(
72+
type=FrechetInceptionDistance,
73+
prefix='FID-Full-50k',
74+
fake_nums=50000,
75+
inception_style='StyleGAN',
76+
sample_model='ema'),
77+
dict(type=PrecisionAndRecall, fake_nums=50000, prefix='PR-50K'),
78+
dict(type=PerceptualPathLength, fake_nums=50000, prefix='ppl-w')
79+
]
80+
# NOTE: config for save multi best checkpoints
81+
# default_hooks.update(
82+
# checkpoint=dict(
83+
# save_best=['FID-Full-50k/fid', 'IS-50k/is'],
84+
# rule=['less', 'greater']))
85+
default_hooks.update(checkpoint=dict(save_best='FID-Full-50k/fid'))
86+
87+
val_evaluator.update(metrics=metrics)
88+
test_evaluator.update(metrics=metrics)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from mmengine.config import read_base
3+
from torch.optim import Adam
4+
5+
from mmagic.engine import VisualizationHook
6+
from mmagic.evaluation import (FrechetInceptionDistance, PerceptualPathLength,
7+
PrecisionAndRecall)
8+
from mmagic.models import BaseGAN
9+
10+
with read_base():
11+
from .._base_.datasets.lsun_stylegan import * # noqa: F403,F405
12+
from .._base_.gen_default_runtime import * # noqa: F403,F405
13+
from .._base_.models.base_styleganv2 import * # noqa: F403,F405
14+
15+
# reg params
16+
d_reg_interval = 16
17+
g_reg_interval = 4
18+
19+
g_reg_ratio = g_reg_interval / (g_reg_interval + 1)
20+
d_reg_ratio = d_reg_interval / (d_reg_interval + 1)
21+
22+
ema_half_life = 10. # G_smoothing_kimg
23+
24+
model.update(
25+
generator=dict(out_size=256),
26+
discriminator=dict(in_size=256),
27+
ema_config=dict(
28+
type=ExponentialMovingAverage,
29+
interval=1,
30+
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
31+
loss_config=dict(
32+
r1_loss_weight=10. / 2. * d_reg_interval,
33+
r1_interval=d_reg_interval,
34+
norm_mode='HWC',
35+
g_reg_interval=g_reg_interval,
36+
g_reg_weight=2. * g_reg_interval,
37+
pl_batch_shrink=2))
38+
39+
train_cfg.update(max_iters=800002)
40+
41+
optim_wrapper.update(
42+
generator=dict(
43+
optimizer=dict(
44+
type=Adam, lr=0.002 * g_reg_ratio, betas=(0, 0.99**g_reg_ratio))),
45+
discriminator=dict(
46+
optimizer=dict(
47+
type=Adam, lr=0.002 * d_reg_ratio, betas=(0, 0.99**d_reg_ratio))))
48+
49+
batch_size = 4
50+
data_root = './data/lsun-church'
51+
52+
train_dataloader.update(
53+
batch_size=batch_size, dataset=dict(data_root=data_root))
54+
55+
val_dataloader.update(batch_size=batch_size, dataset=dict(data_root=data_root))
56+
57+
test_dataloader.update(
58+
batch_size=batch_size, dataset=dict(data_root=data_root))
59+
60+
# VIS_HOOK
61+
custom_hooks = [
62+
dict(
63+
type=VisualizationHook,
64+
interval=5000,
65+
fixed_input=True,
66+
vis_kwargs_list=dict(type=BaseGAN, name='fake_img'))
67+
]
68+
69+
# METRICS
70+
metrics = [
71+
dict(
72+
type=FrechetInceptionDistance,
73+
prefix='FID-Full-50k',
74+
fake_nums=50000,
75+
inception_style='StyleGAN',
76+
sample_model='ema'),
77+
dict(type=PrecisionAndRecall, fake_nums=50000, prefix='PR-50K'),
78+
dict(type=PerceptualPathLength, fake_nums=50000, prefix='ppl-w')
79+
]
80+
# NOTE: config for save multi best checkpoints
81+
# default_hooks.update(
82+
# checkpoint=dict(
83+
# save_best=['FID-Full-50k/fid', 'IS-50k/is'],
84+
# rule=['less', 'greater']))
85+
default_hooks.update(checkpoint=dict(save_best='FID-Full-50k/fid'))
86+
87+
val_evaluator.update(metrics=metrics)
88+
test_evaluator.update(metrics=metrics)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from mmengine.config import read_base
3+
from torch.optim import Adam
4+
5+
from mmagic.engine import VisualizationHook
6+
from mmagic.evaluation import (FrechetInceptionDistance, PerceptualPathLength,
7+
PrecisionAndRecall)
8+
from mmagic.models import BaseGAN
9+
10+
with read_base():
11+
from .._base_.datasets.lsun_stylegan import * # noqa: F403,F405
12+
from .._base_.gen_default_runtime import * # noqa: F403,F405
13+
from .._base_.models.base_styleganv2 import * # noqa: F403,F405
14+
15+
# reg params
16+
d_reg_interval = 16
17+
g_reg_interval = 4
18+
19+
g_reg_ratio = g_reg_interval / (g_reg_interval + 1)
20+
d_reg_ratio = d_reg_interval / (d_reg_interval + 1)
21+
22+
ema_half_life = 10. # G_smoothing_kimg
23+
24+
model.update(
25+
generator=dict(out_size=256),
26+
discriminator=dict(in_size=256),
27+
ema_config=dict(
28+
type=ExponentialMovingAverage,
29+
interval=1,
30+
momentum=1. - (0.5**(32. / (ema_half_life * 1000.)))),
31+
loss_config=dict(
32+
r1_loss_weight=10. / 2. * d_reg_interval,
33+
r1_interval=d_reg_interval,
34+
norm_mode='HWC',
35+
g_reg_interval=g_reg_interval,
36+
g_reg_weight=2. * g_reg_interval,
37+
pl_batch_shrink=2))
38+
39+
train_cfg.update(max_iters=800002)
40+
41+
optim_wrapper.update(
42+
generator=dict(
43+
optimizer=dict(
44+
type=Adam, lr=0.002 * g_reg_ratio, betas=(0, 0.99**g_reg_ratio))),
45+
discriminator=dict(
46+
optimizer=dict(
47+
type=Adam, lr=0.002 * d_reg_ratio, betas=(0, 0.99**d_reg_ratio))))
48+
49+
batch_size = 4
50+
data_root = './data/lsun-horse'
51+
52+
train_dataloader.update(
53+
batch_size=batch_size, dataset=dict(data_root=data_root))
54+
55+
val_dataloader.update(batch_size=batch_size, dataset=dict(data_root=data_root))
56+
57+
test_dataloader.update(
58+
batch_size=batch_size, dataset=dict(data_root=data_root))
59+
60+
# VIS_HOOK
61+
custom_hooks = [
62+
dict(
63+
type=VisualizationHook,
64+
interval=5000,
65+
fixed_input=True,
66+
vis_kwargs_list=dict(type=BaseGAN, name='fake_img'))
67+
]
68+
69+
# METRICS
70+
metrics = [
71+
dict(
72+
type=FrechetInceptionDistance,
73+
prefix='FID-Full-50k',
74+
fake_nums=50000,
75+
inception_style='StyleGAN',
76+
sample_model='ema'),
77+
dict(type=PrecisionAndRecall, fake_nums=50000, prefix='PR-50K'),
78+
dict(type=PerceptualPathLength, fake_nums=50000, prefix='ppl-w')
79+
]
80+
# NOTE: config for save multi best checkpoints
81+
# default_hooks.update(
82+
# checkpoint=dict(
83+
# save_best=['FID-Full-50k/fid', 'IS-50k/is'],
84+
# rule=['less', 'greater']))
85+
default_hooks.update(checkpoint=dict(save_best='FID-Full-50k/fid'))
86+
87+
val_evaluator.update(metrics=metrics)
88+
test_evaluator.update(metrics=metrics)

0 commit comments

Comments
 (0)