Skip to content

Commit 485148b

Browse files
committed
fixed sdmain compiling
1 parent 2975ccd commit 485148b

File tree

1 file changed

+52
-85
lines changed

1 file changed

+52
-85
lines changed

otherarch/sdcpp/main.cpp

Lines changed: 52 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -54,15 +54,13 @@ const char* modes_str[] = {
5454
"txt2img",
5555
"img2img",
5656
"img2vid",
57-
"edit",
5857
"convert",
5958
};
6059

6160
enum SDMode {
6261
TXT2IMG,
6362
IMG2IMG,
6463
IMG2VID,
65-
EDIT,
6664
CONVERT,
6765
MODE_COUNT
6866
};
@@ -88,7 +86,8 @@ struct SDParams {
8886
std::string input_path;
8987
std::string mask_path;
9088
std::string control_image_path;
91-
std::vector<std::string> ref_image_paths;
89+
90+
std::vector<std::string> kontext_image_paths;
9291

9392
std::string prompt;
9493
std::string negative_prompt;
@@ -154,10 +153,6 @@ void print_params(SDParams params) {
154153
printf(" init_img: %s\n", params.input_path.c_str());
155154
printf(" mask_img: %s\n", params.mask_path.c_str());
156155
printf(" control_image: %s\n", params.control_image_path.c_str());
157-
printf(" ref_images_paths:\n");
158-
for (auto& path : params.ref_image_paths) {
159-
printf(" %s\n", path.c_str());
160-
};
161156
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
162157
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
163158
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
@@ -212,7 +207,6 @@ void print_usage(int argc, const char* argv[]) {
212207
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
213208
printf(" --mask [MASK] path to the mask image, required by img2img with mask\n");
214209
printf(" --control-image [IMAGE] path to image condition, control net\n");
215-
printf(" -r, --ref_image [PATH] reference image for Flux Kontext models (can be used multiple times) \n");
216210
printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n");
217211
printf(" -p, --prompt [PROMPT] the prompt to render\n");
218212
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
@@ -248,8 +242,9 @@ void print_usage(int argc, const char* argv[]) {
248242
printf(" This might crash if it is not supported by the backend.\n");
249243
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
250244
printf(" --canny apply canny preprocessor (edge detection)\n");
251-
printf(" --color colors the logging tags according to level\n");
245+
printf(" --color Colors the logging tags according to level\n");
252246
printf(" -v, --verbose print extra info\n");
247+
printf(" -ki, --kontext_img [PATH] Reference image for Flux Kontext models (can be used multiple times) \n");
253248
}
254249

255250
void parse_args(int argc, const char** argv, SDParams& params) {
@@ -634,12 +629,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
634629
break;
635630
}
636631
params.skip_layer_end = std::stof(argv[i]);
637-
} else if (arg == "-r" || arg == "--ref-image") {
632+
} else if (arg == "-ki" || arg == "--kontext-img") {
638633
if (++i >= argc) {
639634
invalid_arg = true;
640635
break;
641636
}
642-
params.ref_image_paths.push_back(argv[i]);
637+
params.kontext_image_paths.push_back(argv[i]);
643638
} else {
644639
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
645640
print_usage(argc, argv);
@@ -668,13 +663,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
668663
}
669664

670665
if ((params.mode == IMG2IMG || params.mode == IMG2VID) && params.input_path.length() == 0) {
671-
fprintf(stderr, "error: when using the img2img/img2vid mode, the following arguments are required: init-img\n");
672-
print_usage(argc, argv);
673-
exit(1);
674-
}
675-
676-
if (params.mode == EDIT && params.ref_image_paths.size() == 0) {
677-
fprintf(stderr, "error: when using the edit mode, the following arguments are required: ref-image\n");
666+
fprintf(stderr, "error: when using the img2img mode, the following arguments are required: init-img\n");
678667
print_usage(argc, argv);
679668
exit(1);
680669
}
@@ -838,12 +827,43 @@ int main(int argc, const char* argv[]) {
838827
fprintf(stderr, "SVD support is broken, do not use it!!!\n");
839828
return 1;
840829
}
830+
bool vae_decode_only = true;
831+
832+
std::vector<sd_image_t> kontext_imgs;
833+
for (auto& path : params.kontext_image_paths) {
834+
vae_decode_only = false;
835+
int c = 0;
836+
int width = 0;
837+
int height = 0;
838+
uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3);
839+
if (image_buffer == NULL) {
840+
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
841+
return 1;
842+
}
843+
if (c < 3) {
844+
fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
845+
free(image_buffer);
846+
return 1;
847+
}
848+
if (width <= 0) {
849+
fprintf(stderr, "error: the width of image must be greater than 0\n");
850+
free(image_buffer);
851+
return 1;
852+
}
853+
if (height <= 0) {
854+
fprintf(stderr, "error: the height of image must be greater than 0\n");
855+
free(image_buffer);
856+
return 1;
857+
}
858+
kontext_imgs.push_back({(uint32_t)width,
859+
(uint32_t)height,
860+
3,
861+
image_buffer});
862+
}
841863

842-
bool vae_decode_only = true;
843864
uint8_t* input_image_buffer = NULL;
844865
uint8_t* control_image_buffer = NULL;
845866
uint8_t* mask_image_buffer = NULL;
846-
std::vector<sd_image_t> ref_images;
847867

848868
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
849869
vae_decode_only = false;
@@ -895,37 +915,6 @@ int main(int argc, const char* argv[]) {
895915
free(input_image_buffer);
896916
input_image_buffer = resized_image_buffer;
897917
}
898-
} else if (params.mode == EDIT) {
899-
vae_decode_only = false;
900-
for (auto& path : params.ref_image_paths) {
901-
int c = 0;
902-
int width = 0;
903-
int height = 0;
904-
uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3);
905-
if (image_buffer == NULL) {
906-
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
907-
return 1;
908-
}
909-
if (c < 3) {
910-
fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
911-
free(image_buffer);
912-
return 1;
913-
}
914-
if (width <= 0) {
915-
fprintf(stderr, "error: the width of image must be greater than 0\n");
916-
free(image_buffer);
917-
return 1;
918-
}
919-
if (height <= 0) {
920-
fprintf(stderr, "error: the height of image must be greater than 0\n");
921-
free(image_buffer);
922-
return 1;
923-
}
924-
ref_images.push_back({(uint32_t)width,
925-
(uint32_t)height,
926-
3,
927-
image_buffer});
928-
}
929918
}
930919

931920
sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
@@ -1012,12 +1001,14 @@ int main(int argc, const char* argv[]) {
10121001
params.style_ratio,
10131002
params.normalize_input,
10141003
params.input_id_images_path.c_str(),
1004+
kontext_imgs.data(), kontext_imgs.size(),
10151005
params.skip_layers.data(),
10161006
params.skip_layers.size(),
10171007
params.slg_scale,
10181008
params.skip_layer_start,
1019-
params.skip_layer_end);
1020-
} else if (params.mode == IMG2IMG || params.mode == IMG2VID) {
1009+
params.skip_layer_end,
1010+
std::vector<sd_image_t*>());
1011+
} else {
10211012
sd_image_t input_image = {(uint32_t)params.width,
10221013
(uint32_t)params.height,
10231014
3,
@@ -1081,38 +1072,14 @@ int main(int argc, const char* argv[]) {
10811072
params.style_ratio,
10821073
params.normalize_input,
10831074
params.input_id_images_path.c_str(),
1075+
kontext_imgs.data(), kontext_imgs.size(),
10841076
params.skip_layers.data(),
10851077
params.skip_layers.size(),
10861078
params.slg_scale,
10871079
params.skip_layer_start,
1088-
params.skip_layer_end);
1080+
params.skip_layer_end,
1081+
std::vector<sd_image_t*>());
10891082
}
1090-
} else { // EDIT
1091-
results = edit(sd_ctx,
1092-
ref_images.data(),
1093-
ref_images.size(),
1094-
params.prompt.c_str(),
1095-
params.negative_prompt.c_str(),
1096-
params.clip_skip,
1097-
params.cfg_scale,
1098-
params.guidance,
1099-
params.eta,
1100-
params.width,
1101-
params.height,
1102-
params.sample_method,
1103-
params.sample_steps,
1104-
params.strength,
1105-
params.seed,
1106-
params.batch_count,
1107-
control_image,
1108-
params.control_strength,
1109-
params.style_ratio,
1110-
params.normalize_input,
1111-
params.skip_layers.data(),
1112-
params.skip_layers.size(),
1113-
params.slg_scale,
1114-
params.skip_layer_start,
1115-
params.skip_layer_end);
11161083
}
11171084

11181085
if (results == NULL) {
@@ -1150,19 +1117,19 @@ int main(int argc, const char* argv[]) {
11501117

11511118
std::string dummy_name, ext, lc_ext;
11521119
bool is_jpg;
1153-
size_t last = params.output_path.find_last_of(".");
1120+
size_t last = params.output_path.find_last_of(".");
11541121
size_t last_path = std::min(params.output_path.find_last_of("/"),
11551122
params.output_path.find_last_of("\\"));
1156-
if (last != std::string::npos // filename has extension
1157-
&& (last_path == std::string::npos || last > last_path)) {
1123+
if (last != std::string::npos // filename has extension
1124+
&& (last_path == std::string::npos || last > last_path)) {
11581125
dummy_name = params.output_path.substr(0, last);
11591126
ext = lc_ext = params.output_path.substr(last);
11601127
std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower);
11611128
is_jpg = lc_ext == ".jpg" || lc_ext == ".jpeg" || lc_ext == ".jpe";
11621129
} else {
11631130
dummy_name = params.output_path;
11641131
ext = lc_ext = "";
1165-
is_jpg = false;
1132+
is_jpg = false;
11661133
}
11671134
// appending ".png" to absent or unknown extension
11681135
if (!is_jpg && lc_ext != ".png") {
@@ -1174,7 +1141,7 @@ int main(int argc, const char* argv[]) {
11741141
continue;
11751142
}
11761143
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext;
1177-
if(is_jpg) {
1144+
if (is_jpg) {
11781145
stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
11791146
results[i].data, 90);
11801147
printf("save result JPEG image to '%s'\n", final_image_path.c_str());

0 commit comments

Comments
 (0)