# Support for SDXL refiner #227
### File 1 of 4: Python conversion script (`convert_vae_encoder`, `convert_unet`, `get_pipeline`, `parser_spec`)
```diff
@@ -13,7 +13,7 @@
 import coremltools as ct
 from diffusers import (
     StableDiffusionPipeline,
-    StableDiffusionXLPipeline,
+    DiffusionPipeline,
     ControlNetModel
 )
 import gc
```
```diff
@@ -568,15 +568,22 @@ def convert_vae_encoder(pipe, args):
     height = (args.latent_h or pipe.unet.config.sample_size) * 8
     width = (args.latent_w or pipe.unet.config.sample_size) * 8

-    z_shape = (
+    x_shape = (
         1,  # B
         3,  # C (RGB range from -1 to 1)
         height,  # H
         width,  # W
     )

+    if args.xl_version:
+        inputs_dtype = torch.float32
+        compute_precision = ct.precision.FLOAT32
+    else:
+        inputs_dtype = torch.float16
+        compute_precision = None
+
     sample_vae_encoder_inputs = {
-        "z": torch.rand(*z_shape, dtype=torch.float16)
+        "x": torch.rand(*x_shape, dtype=inputs_dtype)
     }

     class VAEEncoder(nn.Module):
```
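For context, a minimal sketch of where a precision override like this typically lands in coremltools. The body of `_convert_to_coreml` is not shown in this diff, so the helper below is illustrative, not the PR's actual implementation:

```python
import coremltools as ct

# Illustrative helper only. The point is the destination of the override:
# ct.convert's compute_precision argument. Passing None lets coremltools
# use its default (FLOAT16 for ML Program models), while
# ct.precision.FLOAT32 keeps the SDXL VAE encoder in full precision.
def convert_traced_model(traced_model, sample_inputs, precision=None):
    return ct.convert(
        traced_model,
        convert_to="mlprogram",
        inputs=[
            ct.TensorType(name=name, shape=tensor.shape)
            for name, tensor in sample_inputs.items()
        ],
        compute_precision=precision,
    )
```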
```diff
@@ -588,19 +595,19 @@ def __init__(self):
             self.quant_conv = pipe.vae.quant_conv.to(dtype=torch.float32)
             self.encoder = pipe.vae.encoder.to(dtype=torch.float32)

-        def forward(self, z):
-            return self.quant_conv(self.encoder(z))
+        def forward(self, x):
+            return self.quant_conv(self.encoder(x))

     baseline_encoder = VAEEncoder().eval()

     # No optimization needed for the VAE Encoder as it is a pure ConvNet
     traced_vae_encoder = torch.jit.trace(
-        baseline_encoder, (sample_vae_encoder_inputs["z"].to(torch.float32), ))
+        baseline_encoder, (sample_vae_encoder_inputs["x"].to(torch.float32), ))

     modify_coremltools_torch_frontend_badbmm()
     coreml_vae_encoder, out_path = _convert_to_coreml(
         "vae_encoder", traced_vae_encoder, sample_vae_encoder_inputs,
-        ["latent"], args)
+        ["latent"], args, precision=compute_precision)

     # Set model metadata
     coreml_vae_encoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
```
```diff
@@ -611,7 +618,7 @@ def forward(self, z):
         "Please refer to https://arxiv.org/abs/2112.10752 for details."

     # Set the input descriptions
-    coreml_vae_encoder.input_description["z"] = \
+    coreml_vae_encoder.input_description["x"] = \
         "The input image to base the initial latents on normalized to range [-1, 1]"

     # Set the output descriptions
```
```diff
@@ -624,7 +631,7 @@ def forward(self, z):
     # Parity check PyTorch vs CoreML
     if args.check_output_correctness:
         baseline_out = baseline_encoder(
-            z=sample_vae_encoder_inputs["z"].to(torch.float32)).numpy()
+            x=sample_vae_encoder_inputs["x"].to(torch.float32)).numpy()
         coreml_out = list(
             coreml_vae_encoder.predict(
                 {k: v.numpy()
```
```diff
@@ -673,12 +680,19 @@ def convert_unet(pipe, args):
         raise RuntimeError(
             "convert_text_encoder() deletes pipe.text_encoder to save RAM. "
             "Please use convert_unet() before convert_text_encoder()")

+    if hasattr(pipe, "text_encoder") and pipe.text_encoder is not None:
+        text_token_sequence_length = pipe.text_encoder.config.max_position_embeddings
+        hidden_size = pipe.text_encoder.config.hidden_size
+    elif hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2 is not None:
+        text_token_sequence_length = pipe.text_encoder_2.config.max_position_embeddings
+        hidden_size = pipe.text_encoder_2.config.hidden_size
+
     encoder_hidden_states_shape = (
         batch_size,
-        args.text_encoder_hidden_size or pipe.unet.cross_attention_dim or pipe.text_encoder.config.hidden_size,
+        args.text_encoder_hidden_size or pipe.unet.cross_attention_dim or hidden_size,
         1,
-        args.text_token_sequence_length or pipe.text_encoder.config.max_position_embeddings,
+        args.text_token_sequence_length or text_token_sequence_length,
     )

     # Create the scheduled timesteps for downstream use
```
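To make the fallback concrete, here is a small sketch with illustrative values; the 1280 and 77 figures are assumptions based on the public SDXL 1.0 refiner checkpoint, not something this diff asserts:

```python
# Hypothetical values for illustration: the refiner ships text_encoder_2 but
# no text_encoder, so the elif branch above runs.
batch_size = 2                   # negative + positive prompt (classifier-free guidance)
cross_attention_dim = 1280       # pipe.unet.cross_attention_dim on the refiner
text_token_sequence_length = 77  # max_position_embeddings of text_encoder_2

encoder_hidden_states_shape = (batch_size, cross_attention_dim, 1, text_token_sequence_length)
print(encoder_hidden_states_shape)  # (2, 1280, 1, 77)
```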
```diff
@@ -704,11 +718,28 @@ def convert_unet(pipe, args):
         unet_cls = unet.UNet2DConditionModelXL

-        # Sample time_ids
-        time_ids = torch.tensor([
-            pipe.vae.sample_size, pipe.vae.sample_size,  # output_resolution
-            0., 0.,  # topleft_crop_cond
-            pipe.vae.sample_size, pipe.vae.sample_size  # resolution_cond
-        ] * (batch_size)).to(torch.float32)
+        height = (args.latent_h or pipe.unet.config.sample_size) * 8
+        width = (args.latent_w or pipe.unet.config.sample_size) * 8
+
+        original_size = (height, width)  # output_resolution
+        crops_coords_top_left = (0, 0)  # topleft_crop_cond
+        target_size = (height, width)  # resolution_cond
+        if hasattr(pipe, "requires_aesthetics_score") and pipe.config.requires_aesthetics_score:
+            # Part of SDXL's micro-conditioning as explained in section 2.2 of
+            # https://huggingface.co/papers/2307.01952. Can be used to simulate an
+            # aesthetic score of the generated image by influencing the positive and negative text conditions.
+            aesthetic_score = 6.0  # default aesthetic_score
+            negative_aesthetic_score = 2.5  # default negative_aesthetic_score
+            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+            add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
+        else:
+            add_time_ids = list(original_size + crops_coords_top_left + target_size)
+            add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+        time_ids = [
+            add_neg_time_ids,
+            add_time_ids
+        ]

         # Pooled text embedding from text_encoder_2
         text_embeds_shape = (
```
```diff
@@ -717,7 +748,7 @@ def convert_unet(pipe, args):
         )

         additional_xl_inputs = OrderedDict([
-            ("time_ids", time_ids),
+            ("time_ids", torch.tensor(time_ids).to(torch.float32)),
             ("text_embeds", torch.rand(*text_embeds_shape)),
         ])
```
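As a sanity check on the resulting shapes, a short sketch of the stacked `time_ids` for a hypothetical 1024x1024 conversion, using the default scores from the diff above:

```python
import torch

height = width = 1024
original_size = (height, width)
crops_coords_top_left = (0, 0)
target_size = (height, width)

# Base model: (h_orig, w_orig, c_top, c_left, h_tgt, w_tgt)
base_ids = list(original_size + crops_coords_top_left + target_size)

# Refiner: the resolution condition is replaced by an aesthetic score
pos_ids = list(original_size + crops_coords_top_left + (6.0,))
neg_ids = list(original_size + crops_coords_top_left + (2.5,))

print(torch.tensor([base_ids, base_ids]).to(torch.float32).shape)  # torch.Size([2, 6])
print(torch.tensor([neg_ids, pos_ids]).to(torch.float32).shape)    # torch.Size([2, 5])
```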
```diff
@@ -796,9 +827,15 @@ def convert_unet(pipe, args):
         for k, v in sample_unet_inputs.items()
     }

+    if args.xl_version and pipe.config.requires_aesthetics_score and args.unet_support_cli:
+        # SDXL Refiner Unet does not support FP16 via CLI
+        compute_precision = ct.precision.FLOAT32
+    else:
+        compute_precision = None
+
     coreml_unet, out_path = _convert_to_coreml(unet_name, reference_unet,
                                                coreml_sample_unet_inputs,
-                                               ["noise_pred"], args)
+                                               ["noise_pred"], args, precision=compute_precision)
     del reference_unet
     gc.collect()
```
```diff
@@ -1219,15 +1256,11 @@ def convert_controlnet(pipe, args):
     gc.collect()

 def get_pipeline(args):
-    if 'xl' in args.model_version:
-        pipe = StableDiffusionXLPipeline.from_pretrained(args.model_version,
-                                                         torch_dtype=torch.float16,
-                                                         variant="fp16",
-                                                         use_safetensors=True,
-                                                         use_auth_token=True)
-    else:
-        pipe = StableDiffusionPipeline.from_pretrained(args.model_version,
-                                                       use_auth_token=True)
+    pipe = DiffusionPipeline.from_pretrained(args.model_version,
+                                             torch_dtype=torch.float16,
+                                             variant="fp16",
+                                             use_safetensors=True,
+                                             use_auth_token=True)
     return pipe

 def main(args):
```
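The switch to `DiffusionPipeline` lets diffusers pick the concrete pipeline class from each checkpoint's `model_index.json`, so the refiner no longer needs a special case. A hedged sketch (the model ID is the public Hugging Face refiner checkpoint, shown for illustration):

```python
import torch
from diffusers import DiffusionPipeline

# DiffusionPipeline inspects the checkpoint's model_index.json and returns
# the pipeline class the model declares, so SDXL base, SDXL refiner, and
# classic SD checkpoints all load through the same call.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",  # illustrative model ID
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
print(type(pipe).__name__)  # e.g. StableDiffusionXLImg2ImgPipeline
```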
```diff
@@ -1267,13 +1300,13 @@ def main(args):
         convert_unet(pipe, args)
         logger.info("Converted unet")

-    if args.convert_text_encoder:
+    if args.convert_text_encoder and hasattr(pipe, "text_encoder") and pipe.text_encoder is not None:
         logger.info("Converting text_encoder")
         convert_text_encoder(pipe.text_encoder, pipe.tokenizer, "text_encoder", args)
         del pipe.text_encoder
         logger.info("Converted text_encoder")

-    if args.convert_text_encoder and args.xl_version:
+    if args.convert_text_encoder and hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2 is not None:
         logger.info("Converting text_encoder_2")
         convert_text_encoder(pipe.text_encoder_2, pipe.tokenizer_2, "text_encoder_2", args)
         del pipe.text_encoder_2
```
```diff
@@ -1392,6 +1425,13 @@ def parser_spec():
         "If specified, enable unet to receive additional inputs from controlnet. "
         "Each input added to corresponding resnet output."
     )
+    parser.add_argument(
+        "--unet-support-cli",
+        action="store_true",
+        help=
+        "If specified, convert the unet with float32 precision. "
+        "This is only necessary if you plan to call the model from StableDiffusionCLI."
+    )

     # Swift CLI Resource Bundling
     parser.add_argument(
```

**Reviewer:** Can we rename this to …

**Author:** Yea, I'll rename it. The CLI thing is still a mystery to me; I'd love to hear if anyone can replicate it, or if it's just my machine. I was getting noise on the output when using fp16, which went away in fp32.

**Reviewer:** I am still traveling with bad internet, so I have yet to convert and test the refiner model (hence the slow review). Will confirm once I do!
### File 2 of 4: Swift `CGImage` extension (`plannerRGBShapedArray`, `toRGBShapedArray`)
```diff
@@ -4,6 +4,7 @@
 import Foundation
 import Accelerate
 import CoreML
+import CoreGraphics

 @available(iOS 16.0, macOS 13.0, *)
 extension CGImage {
```
```diff
@@ -62,7 +63,7 @@ extension CGImage {

         return cgImage
     }

     public func plannerRGBShapedArray(minValue: Float, maxValue: Float)
         throws -> MLShapedArray<Float32> {
         guard
```
```diff
@@ -77,34 +78,34 @@ extension CGImage {
         else {
             throw ShapedArrayError.incorrectFormatsConvertingToShapedArray
         }

         var sourceImageBuffer = try vImage_Buffer(cgImage: self)

         var mediumDesination = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: mediumFormat.bitsPerPixel)

         let converter = vImageConverter_CreateWithCGImageFormat(
             &sourceFormat,
             &mediumFormat,
             nil,
             vImage_Flags(kvImagePrintDiagnosticsToConsole),
             nil)

         guard let converter = converter?.takeRetainedValue() else {
             throw ShapedArrayError.vImageConverterNotInitialized
         }

         vImageConvert_AnyToAny(converter, &sourceImageBuffer, &mediumDesination, nil, vImage_Flags(kvImagePrintDiagnosticsToConsole))

         var destinationA = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
         var destinationR = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
         var destinationG = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
         var destinationB = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))

         var minFloat: [Float] = Array(repeating: minValue, count: 4)
         var maxFloat: [Float] = Array(repeating: maxValue, count: 4)

         vImageConvert_ARGB8888toPlanarF(&mediumDesination, &destinationA, &destinationR, &destinationG, &destinationB, &maxFloat, &minFloat, .zero)

         let destAPtr = destinationA.data.assumingMemoryBound(to: Float.self)
         let destRPtr = destinationR.data.assumingMemoryBound(to: Float.self)
         let destGPtr = destinationG.data.assumingMemoryBound(to: Float.self)
```
```diff
@@ -121,11 +122,60 @@ extension CGImage {
         let redData = Data(bytes: destinationR.data, count: Int(width) * Int(height) * MemoryLayout<Float>.size)
         let greenData = Data(bytes: destinationG.data, count: Int(width) * Int(height) * MemoryLayout<Float>.size)
         let blueData = Data(bytes: destinationB.data, count: Int(width) * Int(height) * MemoryLayout<Float>.size)

         let imageData = redData + greenData + blueData

         let shapedArray = MLShapedArray<Float32>(data: imageData, shape: [1, 3, self.height, self.width])

         return shapedArray
     }

+    private func normalizePixelValues(pixel: UInt8) -> Float {
+        return (Float(pixel) / 127.5) - 1.0
+    }
+
+    public func toRGBShapedArray(minValue: Float, maxValue: Float)
+        throws -> MLShapedArray<Float32> {
+        let image = self
+        let width = image.width
+        let height = image.height
+        let alphaMaskValue: Float = minValue
+
+        guard let colorSpace = CGColorSpace(name: CGColorSpace.sRGB),
+              let context = CGContext(data: nil, width: width, height: height, bitsPerComponent: 8, bytesPerRow: 4 * width, space: colorSpace, bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue),
+              let ptr = context.data?.bindMemory(to: UInt8.self, capacity: width * height * 4) else {
+            return []
+        }
+
+        context.draw(image, in: CGRect(x: 0, y: 0, width: width, height: height))
+
+        var redChannel = [Float](repeating: 0, count: width * height)
+        var greenChannel = [Float](repeating: 0, count: width * height)
+        var blueChannel = [Float](repeating: 0, count: width * height)
+
+        for y in 0..<height {
+            for x in 0..<width {
+                let i = 4 * (y * width + x)
+                if ptr[i+3] == 0 {
+                    // Alpha mask for controlnets
+                    redChannel[y * width + x] = alphaMaskValue
+                    greenChannel[y * width + x] = alphaMaskValue
+                    blueChannel[y * width + x] = alphaMaskValue
+                } else {
+                    redChannel[y * width + x] = normalizePixelValues(pixel: ptr[i])
+                    greenChannel[y * width + x] = normalizePixelValues(pixel: ptr[i+1])
+                    blueChannel[y * width + x] = normalizePixelValues(pixel: ptr[i+2])
+                }
+            }
+        }
+
+        let colorShape = [1, 1, height, width]
+        let redShapedArray = MLShapedArray<Float32>(scalars: redChannel, shape: colorShape)
+        let greenShapedArray = MLShapedArray<Float32>(scalars: greenChannel, shape: colorShape)
+        let blueShapedArray = MLShapedArray<Float32>(scalars: blueChannel, shape: colorShape)
+
+        let shapedArray = MLShapedArray<Float32>(concatenating: [redShapedArray, greenShapedArray, blueShapedArray], alongAxis: 1)
+
+        return shapedArray
+    }
 }
```

**Author** (inline comment on the alpha-mask lines): I attempted to replicate it with this new function, and the …
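For reference, `normalizePixelValues` applies the standard 8-bit to [-1, 1] mapping the VAE encoder expects:

$$f(p) = \frac{p}{127.5} - 1, \qquad p \in \{0, 1, \ldots, 255\}, \qquad f(0) = -1, \quad f(255) = 1$$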
### File 3 of 4: Swift `Encoder` (`ResourceManaging`)
```diff
@@ -50,7 +50,7 @@ public struct Encoder: ResourceManaging {
         scaleFactor: Float32,
         random: inout RandomSource
     ) throws -> MLShapedArray<Float32> {
-        let imageData = try image.plannerRGBShapedArray(minValue: -1.0, maxValue: 1.0)
+        let imageData = try image.toRGBShapedArray(minValue: -1.0, maxValue: 1.0)

         guard imageData.shape == inputShape else {
             // TODO: Consider auto resizing and cropping similar to how Vision or CoreML auto-generated Swift code can accomplish with `MLFeatureValue`
             throw Error.sampleInputShapeNotCorrect
```

**Reviewer:** Could you please summarize the reasons for moving away from `plannerRGBShapedArray`?

**Author:** This function is intended to fix a bug in `plannerRGBShapedArray`. Example of the bug (converted to …):

**Reviewer:** Wow, thanks @ZachNagengast! Is this for any 1024x1024 image or some pixel value-dependent edge case? I might want to file a ticket for this.

**Reviewer:** In any case, it makes sense to work around it by using your function at the moment.

**Author:** Yea, you can try it out; just do

```swift
let imageData = try image.plannerRGBShapedArray(minValue: -1.0, maxValue: 1.0)
let decodedImage = try CGImage.fromShapedArray(imageData)
```

and inspect the decoded image; it should look off like this for the SDXL pipeline.
```diff
@@ -93,7 +93,7 @@ public struct Encoder: ResourceManaging {

     var inputDescription: MLFeatureDescription {
         try! model.perform { model in
-            model.modelDescription.inputDescriptionsByName["z"]!
+            model.modelDescription.inputDescriptionsByName.first!.value
         }
     }
```

**Reviewer:** This preserves backwards compatibility, so I am not concerned about the rename; thanks for the semantically correct fix :)

**Reviewer:** No need to update, just appreciating the change.
### File 4 of 4: Swift `PipelineConfiguration`
```diff
@@ -44,7 +44,19 @@ public struct PipelineConfiguration: Hashable {
     public var encoderScaleFactor: Float32 = 0.18215
     /// Scale factor to use on the latent before decoding
     public var decoderScaleFactor: Float32 = 0.18215

+    /// If `originalSize` is not the same as `targetSize` the image will appear to be down- or upsampled.
+    /// Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952.
+    public var originalSize: Float32 = 1024
+    /// `cropsCoordsTopLeft` can be used to generate an image that appears to be "cropped" from the position `cropsCoordsTopLeft` downwards.
+    /// Favorable, well-centered images are usually achieved by setting `cropsCoordsTopLeft` to (0, 0).
+    public var cropsCoordsTopLeft: Float32 = 0
+    /// For most cases, `targetSize` should be set to the desired height and width of the generated image.
+    public var targetSize: Float32 = 1024
+    /// Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+    public var aestheticScore: Float32 = 6
+    /// Can be used to simulate an aesthetic score of the generated image by influencing the negative text condition.
+    public var negativeAestheticScore: Float32 = 2.5

     /// Given the configuration, what mode will be used for generation
     public var mode: PipelineMode {
         guard startingImage != nil else {
```

**Reviewer:** I am inclined to recommend a dedicated …

**Author:** Yea, I'd agree it probably shouldn't be a requirement on the caller to set the scale factor properly, especially when we already know the pipeline is SDXL and can set the default properly. My question would be whether we'd want a new class for this or just set up the defaults somewhere in the pipeline. Or, a broader option: include the config files with the converted models directly, alongside the merges.txt and vocab.json, and use those for the default config.

**Reviewer:** That would be nice, and we can even make that backward compatible. However, it still wouldn't address the argument set mismatch across XL and non-XL, right? IMO, the easiest is to create two (XL and non-XL) default configs and then add your proposal from above as an extension.

**Author:** Yea, that makes sense to me, and seems to be the approach …

**Reviewer:** I agree, we can refactor over time :)

**Reviewer:** I agree with the approach discussed; separate configs do seem necessary to make usage simpler.
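For orientation, these fields correspond one-to-one to the `time_ids` assembled on the Python side (see the conversion-script diff above). Writing $o$ for `originalSize`, $c$ for `cropsCoordsTopLeft`, $t$ for `targetSize`, and $s$ for the aesthetic score (`negativeAestheticScore` for the negative condition), the per-prompt conditioning vectors are:

$$\text{base: } (o, o, c, c, t, t), \qquad \text{refiner: } (o, o, c, c, s)$$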