Skip to content

Commit 91493d5

Browse files
authored
REF: refactor controlnet for image model (#2346)
1 parent 92fc84b commit 91493d5

File tree

4 files changed

+192
-74
lines changed

4 files changed

+192
-74
lines changed

examples/StableDiffusionControlNet.ipynb

+4-4
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@
9191
"from diffusers.utils import load_image\n",
9292
"\n",
9393
"mlsd = MLSDdetector.from_pretrained(\"lllyasviel/ControlNet\")\n",
94-
"image_path = os.path.expanduser(\"~/draft.png\")\n",
94+
"image_path = os.path.expanduser(\"draft.png\")\n",
9595
"image = load_image(image_path)\n",
9696
"image = mlsd(image)\n",
9797
"image"
@@ -181,7 +181,7 @@
181181
],
182182
"metadata": {
183183
"kernelspec": {
184-
"display_name": "Python 3",
184+
"display_name": "Python 3 (ipykernel)",
185185
"language": "python",
186186
"name": "python3"
187187
},
@@ -195,9 +195,9 @@
195195
"name": "python",
196196
"nbconvert_exporter": "python",
197197
"pygments_lexer": "ipython3",
198-
"version": "3.11.6"
198+
"version": "3.11.9"
199199
}
200200
},
201201
"nbformat": 4,
202-
"nbformat_minor": 2
202+
"nbformat_minor": 4
203203
}

xinference/model/image/core.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -210,18 +210,19 @@ def create_image_model_instance(
210210
for name in controlnet:
211211
for cn_model_spec in model_spec.controlnet:
212212
if cn_model_spec.model_name == name:
213-
if not model_path:
214-
model_path = cache(cn_model_spec)
215-
controlnet_model_paths.append(model_path)
213+
controlnet_model_path = cache(cn_model_spec)
214+
controlnet_model_paths.append(controlnet_model_path)
216215
break
217216
else:
218217
raise ValueError(
219218
f"controlnet `{name}` is not supported for model `{model_name}`."
220219
)
221220
if len(controlnet_model_paths) == 1:
222-
kwargs["controlnet"] = controlnet_model_paths[0]
221+
kwargs["controlnet"] = (controlnet[0], controlnet_model_paths[0])
223222
else:
224-
kwargs["controlnet"] = controlnet_model_paths
223+
kwargs["controlnet"] = [
224+
(n, path) for n, path in zip(controlnet, controlnet_model_paths)
225+
]
225226
if not model_path:
226227
model_path = cache(model_spec)
227228
if peft_model_config is not None:

xinference/model/image/sdapi.py

+33-4
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
import base64
1516
import io
1617
import warnings
1718

18-
from PIL import Image
19+
from PIL import Image, ImageOps
1920

2021

2122
class SDAPIToDiffusersConverter:
@@ -30,7 +31,7 @@ class SDAPIToDiffusersConverter:
3031
txt2img_arg_mapping = {
3132
"steps": "num_inference_steps",
3233
"cfg_scale": "guidance_scale",
33-
# "denoising_strength": "strength",
34+
"denoising_strength": "strength",
3435
}
3536
img2img_identical_args = {
3637
"prompt",
@@ -42,9 +43,11 @@ class SDAPIToDiffusersConverter:
4243
}
4344
img2img_arg_mapping = {
4445
"init_images": "image",
46+
"mask": "mask_image",
4547
"steps": "num_inference_steps",
4648
"cfg_scale": "guidance_scale",
4749
"denoising_strength": "strength",
50+
"inpaint_full_res_padding": "padding_mask_crop",
4851
}
4952

5053
@staticmethod
@@ -121,12 +124,38 @@ def _decode_b64_img(img_str: str) -> Image:
121124

122125
def img2img(self, **kwargs):
123126
init_images = kwargs.pop("init_images", [])
124-
kwargs["init_images"] = [self._decode_b64_img(i) for i in init_images]
127+
kwargs["init_images"] = init_images = [
128+
self._decode_b64_img(i) for i in init_images
129+
]
130+
if len(init_images) == 1:
131+
kwargs["init_images"] = init_images[0]
132+
mask_image = kwargs.pop("mask", None)
133+
if mask_image:
134+
if kwargs.pop("inpainting_mask_invert"):
135+
mask_image = ImageOps.invert(mask_image)
136+
137+
kwargs["mask"] = self._decode_b64_img(mask_image)
138+
139+
# process inpaint_full_res and inpaint_full_res_padding
140+
if kwargs.pop("inpaint_full_res", None):
141+
kwargs["inpaint_full_res_padding"] = kwargs.pop(
142+
"inpaint_full_res_padding", 0
143+
)
144+
else:
145+
# inpaint_full_res_padding is turned into `padding_mask_crop`
146+
# in diffusers, if padding_mask_crop is passed, it will do inpaint_full_res
147+
# so if not inpaint_full_res, we need to pop this option
148+
kwargs.pop("inpaint_full_res_padding", None)
149+
125150
clip_skip = kwargs.get("override_settings", {}).get("clip_skip")
126151
converted_kwargs = self._check_kwargs("img2img", kwargs)
127152
if clip_skip:
128153
converted_kwargs["clip_skip"] = clip_skip
129-
result = self.image_to_image(response_format="b64_json", **converted_kwargs) # type: ignore
154+
155+
if not converted_kwargs.get("mask_image"):
156+
result = self.image_to_image(response_format="b64_json", **converted_kwargs) # type: ignore
157+
else:
158+
result = self.inpainting(response_format="b64_json", **converted_kwargs) # type: ignore
130159

131160
# convert to SD API result
132161
return {

0 commit comments

Comments
 (0)