Skip to content

Commit 5722988

Browse files
authored
Merge pull request #44 from sail-sg/inv
support text invt for drag prompt
2 parents 2db58a9 + 554a6ed commit 5722988

File tree

6 files changed

+115
-54
lines changed

6 files changed

+115
-54
lines changed

editany_demo.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def create_demo_template(
8080
a_prompt = gr.Textbox(
8181
label="Positive Prompt",
8282
info="Text in the expected things of edited region",
83-
value="best quality, extremely detailed",
83+
value="best quality, extremely detailed,",
8484
)
8585
n_prompt = gr.Textbox(
8686
label="Negative Prompt",
@@ -177,7 +177,7 @@ def create_demo_template(
177177
label="SAM Control Scale",
178178
minimum=0,
179179
maximum=1.0,
180-
value=1.0,
180+
value=0.3,
181181
step=0.1,
182182
)
183183
ref_inpaint_scale = gr.Slider(
@@ -187,6 +187,15 @@ def create_demo_template(
187187
value=0.2,
188188
step=0.1,
189189
)
190+
with gr.Row():
191+
ref_textinv = gr.Checkbox(
192+
label="Use textual inversion token", value=False
193+
)
194+
ref_textinv_path = gr.Textbox(
195+
label="textual inversion token path",
196+
info="Text in the inversion token path",
197+
value=None,
198+
)
190199

191200
with gr.Accordion("Advanced options", open=False):
192201
mask_image = gr.Image(
@@ -277,6 +286,8 @@ def create_demo_template(
277286
ref_sam_scale,
278287
ref_inpaint_scale,
279288
ref_auto_prompt,
289+
ref_textinv,
290+
ref_textinv_path,
280291
]
281292
run_button.click(
282293
fn=process,
@@ -321,6 +332,8 @@ def create_demo_template(
321332
ref_sam_scale,
322333
ref_inpaint_scale,
323334
ref_auto_prompt,
335+
ref_textinv,
336+
ref_textinv_path,
324337
]
325338

326339
run_button_click.click(

editany_lora.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,8 @@ def process(
600600
ref_sam_scale=None,
601601
ref_inpaint_scale=None,
602602
ref_auto_prompt=False,
603+
ref_textinv=True,
604+
ref_textinv_path=None,
603605
):
604606

605607
if condition_model is None or condition_model == "EditAnything":
@@ -652,7 +654,7 @@ def process(
652654
)
653655
self.default_controlnet_path = this_controlnet_path
654656
torch.cuda.empty_cache()
655-
if self.last_ref_infer and ref_image is None:
657+
if self.last_ref_infer:
656658
print("Redefine the model to overwrite the ref mode")
657659
self.pipe = obtain_generation_model(
658660
self.base_model_path,
@@ -661,11 +663,12 @@ def process(
661663
enable_all_generate,
662664
self.extra_inpaint,
663665
)
666+
self.last_ref_infer = False
664667

665668
if ref_image is not None:
666669
ref_mask = ref_image["mask"]
667670
ref_image = ref_image["image"]
668-
if ref_auto_prompt:
671+
if ref_auto_prompt or ref_textinv:
669672
bbox = get_bounding_box(
670673
np.array(ref_mask) / 255
671674
) # reverse the mask to make 1 the choosen region
@@ -680,13 +683,27 @@ def process(
680683
cropped_ref_image = Image.fromarray(
681684
cropped_ref_image.astype("uint8"))
682685

686+
if ref_auto_prompt:
683687
generated_prompt = self.get_blip2_text(cropped_ref_image)
684688
ref_prompt += generated_prompt
685689
a_prompt += generated_prompt
686690
print("Generated ref text:", ref_prompt)
687691
print("Generated input text:", a_prompt)
692+
self.last_ref_infer = True
688693
# ref_image = cropped_ref_image
689694
# ref_mask = cropped_ref_mask
695+
if ref_textinv:
696+
try:
697+
self.pipe.load_textual_inversion(ref_textinv_path)
698+
print("Load textinv embedding from:", ref_textinv_path)
699+
except:
700+
print("No textinvert embeddings found.")
701+
ref_data_path = "./utils/tmp/textinv/img"
702+
if not os.path.exists(ref_data_path):
703+
os.makedirs(ref_data_path)
704+
cropped_ref_image.save(os.path.join(ref_data_path, 'ref.png'))
705+
print("Ref image region is save to:", ref_data_path)
706+
print("Plese finetune with run_texutal_inversion.sh in utils folder to get the textinvert embeddings.")
690707

691708
else:
692709
ref_mask = None

environment.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ dependencies:
1010
- torchvision=0.14.1
1111
- numpy=1.23.1
1212
- pip:
13-
- gradio==3.16.2
13+
- gradio==3.35.2
1414
- albumentations==1.3.0
1515
- opencv-contrib-python==4.3.0.36
1616
- imageio==2.9.0
@@ -32,6 +32,7 @@ dependencies:
3232
- prettytable==3.6.0
3333
- safetensors==0.2.7
3434
- basicsr==1.4.2
35-
- diffusers==0.14.0
35+
- diffusers==0.17.1
3636
- accelerate==0.17.0
37-
- transformers==4.27.4
37+
- transformers==4.30.2
38+
- xformers

utils/run_texutal_inversion.sh

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
2+
export DATA_DIR="./tmp/textinv/img"
3+
export OUTPUT_DIR="./tmp/textinv/model"
4+
5+
CUDA_VISIBLE_DEVICES=0 accelerate launch --main_process_port 1111 texutal_inversion.py \
6+
--pretrained_model_name_or_path=$MODEL_NAME \
7+
--train_data_dir=$DATA_DIR \
8+
--learnable_property="object" \
9+
--placeholder_token="<new-obj>" --initializer_token="mark" \
10+
--resolution=512 \
11+
--train_batch_size=4 \
12+
--gradient_accumulation_steps=1 \
13+
--max_train_steps=3000 \
14+
--learning_rate=5.0e-04 --scale_lr \
15+
--lr_scheduler="constant" \
16+
--lr_warmup_steps=0 \
17+
--output_dir=$OUTPUT_DIR \
18+
--num_vectors 10

utils/stable_diffusion_reference.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ def hack_CrossAttnDownBlock2D_forward(
468468
and len(self.mean_bank0) > 0
469469
and len(self.var_bank0) > 0
470470
):
471-
print("hacked_CrossAttnDownBlock2D_forward0")
471+
# print("hacked_CrossAttnDownBlock2D_forward0")
472472
scale_ratio = self.inpaint_mask.shape[2] / \
473473
hidden_states.shape[2]
474474
this_inpaint_mask = F.interpolate(
@@ -548,7 +548,7 @@ def hack_CrossAttnDownBlock2D_forward(
548548
and len(self.mean_bank) > 0
549549
and len(self.var_bank) > 0
550550
):
551-
print("hack_CrossAttnDownBlock2D_forward")
551+
# print("hack_CrossAttnDownBlock2D_forward")
552552
scale_ratio = self.inpaint_mask.shape[2] / \
553553
hidden_states.shape[2]
554554
this_inpaint_mask = F.interpolate(
@@ -645,7 +645,7 @@ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
645645
and len(self.mean_bank) > 0
646646
and len(self.var_bank) > 0
647647
):
648-
print("hacked_DownBlock2D_forward")
648+
# print("hacked_DownBlock2D_forward")
649649
scale_ratio = self.inpaint_mask.shape[2] / \
650650
hidden_states.shape[2]
651651
this_inpaint_mask = F.interpolate(
@@ -753,7 +753,7 @@ def hacked_CrossAttnUpBlock2D_forward(
753753
and len(self.mean_bank0) > 0
754754
and len(self.var_bank0) > 0
755755
):
756-
print("hacked_CrossAttnUpBlock2D_forward1")
756+
# print("hacked_CrossAttnUpBlock2D_forward1")
757757
scale_ratio = self.inpaint_mask.shape[2] / \
758758
hidden_states.shape[2]
759759
this_inpaint_mask = F.interpolate(
@@ -835,7 +835,7 @@ def hacked_CrossAttnUpBlock2D_forward(
835835
and len(self.mean_bank) > 0
836836
and len(self.var_bank) > 0
837837
):
838-
print("hacked_CrossAttnUpBlock2D_forward")
838+
# print("hacked_CrossAttnUpBlock2D_forward")
839839
scale_ratio = self.inpaint_mask.shape[2] / \
840840
hidden_states.shape[2]
841841
this_inpaint_mask = F.interpolate(
@@ -932,7 +932,7 @@ def hacked_UpBlock2D_forward(
932932
and len(self.mean_bank) > 0
933933
and len(self.var_bank) > 0
934934
):
935-
print("hacked_UpBlock2D_forward")
935+
# print("hacked_UpBlock2D_forward")
936936
scale_ratio = self.inpaint_mask.shape[2] / \
937937
hidden_states.shape[2]
938938
this_inpaint_mask = F.interpolate(

utils/texutal_inversion.py

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import math
1919
import os
2020
import random
21+
import shutil
2122
import warnings
2223
from pathlib import Path
2324

@@ -77,7 +78,7 @@
7778

7879

7980
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
80-
# check_min_version("0.17.0.dev0")
81+
# check_min_version("0.18.0.dev0")
8182

8283
logger = get_logger(__name__)
8384

@@ -394,11 +395,7 @@ def parse_args():
394395
"--checkpoints_total_limit",
395396
type=int,
396397
default=None,
397-
help=(
398-
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
399-
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
400-
" for more docs"
401-
),
398+
help=("Max number of checkpoints to store."),
402399
)
403400
parser.add_argument(
404401
"--resume_from_checkpoint",
@@ -423,38 +420,36 @@ def parse_args():
423420

424421
return args
425422

423+
426424
imagenet_templates_small = [
427425
"a photo of a {}",
426+
"a rendering of a {}",
427+
"a cropped photo of the {}",
428+
"the photo of a {}",
429+
"a photo of a clean {}",
430+
"a photo of a dirty {}",
431+
"a dark photo of the {}",
432+
"a photo of my {}",
433+
"a photo of the cool {}",
434+
"a close-up photo of a {}",
435+
"a bright photo of the {}",
436+
"a cropped photo of a {}",
437+
"a photo of the {}",
438+
"a good photo of the {}",
439+
"a photo of one {}",
440+
"a close-up photo of the {}",
441+
"a rendition of the {}",
442+
"a photo of the clean {}",
443+
"a rendition of a {}",
444+
"a photo of a nice {}",
445+
"a good photo of a {}",
446+
"a photo of the nice {}",
447+
"a photo of the small {}",
448+
"a photo of the weird {}",
449+
"a photo of the large {}",
450+
"a photo of a cool {}",
451+
"a photo of a small {}",
428452
]
429-
# imagenet_templates_small = [
430-
# "a photo of a {}",
431-
# "a rendering of a {}",
432-
# "a cropped photo of the {}",
433-
# "the photo of a {}",
434-
# "a photo of a clean {}",
435-
# "a photo of a dirty {}",
436-
# "a dark photo of the {}",
437-
# "a photo of my {}",
438-
# "a photo of the cool {}",
439-
# "a close-up photo of a {}",
440-
# "a bright photo of the {}",
441-
# "a cropped photo of a {}",
442-
# "a photo of the {}",
443-
# "a good photo of the {}",
444-
# "a photo of one {}",
445-
# "a close-up photo of the {}",
446-
# "a rendition of the {}",
447-
# "a photo of the clean {}",
448-
# "a rendition of a {}",
449-
# "a photo of a nice {}",
450-
# "a good photo of a {}",
451-
# "a photo of the nice {}",
452-
# "a photo of the small {}",
453-
# "a photo of the weird {}",
454-
# "a photo of the large {}",
455-
# "a photo of a cool {}",
456-
# "a photo of a small {}",
457-
# ]
458453

459454
imagenet_style_templates_small = [
460455
"a painting in the style of {}",
@@ -568,14 +563,11 @@ def __getitem__(self, i):
568563
def main():
569564
args = parse_args()
570565
logging_dir = os.path.join(args.output_dir, args.logging_dir)
571-
572-
accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
573-
566+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
574567
accelerator = Accelerator(
575568
gradient_accumulation_steps=args.gradient_accumulation_steps,
576569
mixed_precision=args.mixed_precision,
577570
log_with=args.report_to,
578-
logging_dir=logging_dir,
579571
project_config=accelerator_project_config,
580572
)
581573

@@ -755,8 +747,8 @@ def main():
755747
text_encoder, optimizer, train_dataloader, lr_scheduler
756748
)
757749

758-
# For mixed precision training we cast the unet and vae weights to half-precision
759-
# as these models are only used for inference, keeping weights in full precision is not required.
750+
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
751+
# as these weights are only used for inference, keeping weights in full precision is not required.
760752
weight_dtype = torch.float32
761753
if accelerator.mixed_precision == "fp16":
762754
weight_dtype = torch.float16
@@ -890,6 +882,26 @@ def main():
890882

891883
if accelerator.is_main_process:
892884
if global_step % args.checkpointing_steps == 0:
885+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
886+
if args.checkpoints_total_limit is not None:
887+
checkpoints = os.listdir(args.output_dir)
888+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
889+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
890+
891+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
892+
if len(checkpoints) >= args.checkpoints_total_limit:
893+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
894+
removing_checkpoints = checkpoints[0:num_to_remove]
895+
896+
logger.info(
897+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
898+
)
899+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
900+
901+
for removing_checkpoint in removing_checkpoints:
902+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
903+
shutil.rmtree(removing_checkpoint)
904+
893905
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
894906
accelerator.save_state(save_path)
895907
logger.info(f"Saved state to {save_path}")

0 commit comments

Comments
 (0)