-
Notifications
You must be signed in to change notification settings - Fork 1k
Support for SDXL refiner #227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 23 commits
4cd95e1
7c03b1d
db154eb
1d8bfff
d3be8b9
5f9d50c
fc204ef
f4142cf
3aa0e23
58b4f62
137d0c8
e28f812
2543c1f
b211d2d
fdc0185
81dd25b
d589ab8
43619e0
352f349
e5724db
90864bc
e2e2b16
7cb53e8
84450eb
1dae882
ae516e7
efda893
a5f0280
17dec17
580bcbd
c69deb7
68c39b3
afdfbcd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
import Foundation | ||
import Accelerate | ||
import CoreML | ||
import CoreGraphics | ||
|
||
@available(iOS 16.0, macOS 13.0, *) | ||
extension CGImage { | ||
|
@@ -77,7 +78,7 @@ extension CGImage { | |
else { | ||
throw ShapedArrayError.incorrectFormatsConvertingToShapedArray | ||
} | ||
|
||
var sourceImageBuffer = try vImage_Buffer(cgImage: self) | ||
|
||
var mediumDestination = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: mediumFormat.bitsPerPixel) | ||
|
@@ -88,7 +89,7 @@ extension CGImage { | |
nil, | ||
vImage_Flags(kvImagePrintDiagnosticsToConsole), | ||
nil) | ||
|
||
guard let converter = converter?.takeRetainedValue() else { | ||
throw ShapedArrayError.vImageConverterNotInitialized | ||
} | ||
|
@@ -99,7 +100,7 @@ extension CGImage { | |
var destinationR = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size)) | ||
var destinationG = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size)) | ||
var destinationB = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size)) | ||
|
||
var minFloat: [Float] = Array(repeating: minValue, count: 4) | ||
var maxFloat: [Float] = Array(repeating: maxValue, count: 4) | ||
|
||
|
@@ -125,7 +126,56 @@ extension CGImage { | |
let imageData = redData + greenData + blueData | ||
|
||
let shapedArray = MLShapedArray<Float32>(data: imageData, shape: [1, 3, self.height, self.width]) | ||
|
||
|
||
return shapedArray | ||
} | ||
|
||
/// Linearly rescales an 8-bit pixel channel value into `minValue...maxValue`.
///
/// With the default `-1...1` range this reproduces the previous
/// `(pixel / 127.5) - 1.0` mapping bit-for-bit (`x / 127.5 == (x / 255) * 2`,
/// and multiplying by 2 is exact in floating point), so existing call sites
/// that pass only `pixel:` are unaffected.
///
/// - Parameters:
///   - pixel: Source channel value in `0...255`.
///   - minValue: Lower bound of the output range. Defaults to `-1.0`.
///   - maxValue: Upper bound of the output range. Defaults to `1.0`.
/// - Returns: The rescaled value in `minValue...maxValue`.
private func normalizePixelValues(pixel: UInt8, minValue: Float = -1.0, maxValue: Float = 1.0) -> Float {
    return minValue + (Float(pixel) / 255.0) * (maxValue - minValue)
}
|
||
/// Converts this image to a `[1, 3, height, width]` RGB `MLShapedArray`.
///
/// Pixel channel values are linearly rescaled from `0...255` into
/// `minValue...maxValue`. Fully transparent pixels (`alpha == 0`) are written
/// as `minValue`, which serves as an alpha mask for controlnets.
///
/// - Parameters:
///   - minValue: Lower bound of the output range; also the alpha-mask value.
///   - maxValue: Upper bound of the output range.
/// - Returns: A `[1, 3, height, width]` shaped array in R, G, B channel order.
/// - Throws: `ShapedArrayError.incorrectFormatsConvertingToShapedArray` if a
///   bitmap context could not be created for this image.
public func toRGBShapedArray(minValue: Float, maxValue: Float)
throws -> MLShapedArray<Float32> {
    let width = self.width
    let height = self.height
    let pixelCount = width * height
    // Transparent pixels become the mask value expected by controlnets.
    let alphaMaskValue: Float = minValue

    // Rescale a 0...255 byte into minValue...maxValue. For the common
    // (-1, 1) range this matches the previous `(x / 127.5) - 1` formula
    // bit-for-bit (the *2 step is an exact power-of-two multiply), but it
    // now honors the caller-supplied bounds instead of assuming -1...1
    // while the alpha mask used `minValue` — the two were inconsistent.
    let range = maxValue - minValue
    func rescale(_ byte: UInt8) -> Float {
        minValue + (Float(byte) / 255.0) * range
    }

    guard let colorSpace = CGColorSpace(name: CGColorSpace.sRGB),
          let context = CGContext(data: nil,
                                  width: width,
                                  height: height,
                                  bitsPerComponent: 8,
                                  bytesPerRow: 4 * width,
                                  space: colorSpace,
                                  bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue),
          let ptr = context.data?.bindMemory(to: UInt8.self, capacity: pixelCount * 4) else {
        // Previously returned an empty shaped array here, silently hiding
        // the failure from callers of a `throws` function; surface it.
        throw ShapedArrayError.incorrectFormatsConvertingToShapedArray
    }

    context.draw(self, in: CGRect(x: 0, y: 0, width: width, height: height))

    var redChannel = [Float](repeating: 0, count: pixelCount)
    var greenChannel = [Float](repeating: 0, count: pixelCount)
    var blueChannel = [Float](repeating: 0, count: pixelCount)

    // NOTE(review): the context stores premultiplied alpha, so partially
    // transparent pixels carry alpha-premultiplied RGB. That is fine for
    // binary masks; confirm if exact colors under partial transparency matter.
    for y in 0..<height {
        for x in 0..<width {
            let pixelIndex = y * width + x
            let i = 4 * pixelIndex
            if ptr[i + 3] == 0 {
                // Alpha mask for controlnets
                redChannel[pixelIndex] = alphaMaskValue
                greenChannel[pixelIndex] = alphaMaskValue
                blueChannel[pixelIndex] = alphaMaskValue
            } else {
                redChannel[pixelIndex] = rescale(ptr[i])
                greenChannel[pixelIndex] = rescale(ptr[i + 1])
                blueChannel[pixelIndex] = rescale(ptr[i + 2])
            }
        }
    }

    let colorShape = [1, 1, height, width]
    let redShapedArray = MLShapedArray<Float32>(scalars: redChannel, shape: colorShape)
    let greenShapedArray = MLShapedArray<Float32>(scalars: greenChannel, shape: colorShape)
    let blueShapedArray = MLShapedArray<Float32>(scalars: blueChannel, shape: colorShape)

    // Stack R, G, B along the channel axis -> [1, 3, height, width].
    return MLShapedArray<Float32>(concatenating: [redShapedArray, greenShapedArray, blueShapedArray], alongAxis: 1)
}
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -93,7 +93,7 @@ public struct Encoder: ResourceManaging { | |
|
||
var inputDescription: MLFeatureDescription { | ||
try! model.perform { model in | ||
model.modelDescription.inputDescriptionsByName["z"]! | ||
model.modelDescription.inputDescriptionsByName.first!.value | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This preserves backwards compatibility so I am not concerned about the rename and thanks for the semantically correct fix :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to update, just appreciating the change |
||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,8 +20,14 @@ public struct PipelineConfiguration: Hashable { | |
public var negativePrompt: String = "" | ||
/// Starting image for image2image or in-painting | ||
public var startingImage: CGImage? = nil | ||
//public var maskImage: CGImage? = nil | ||
/// Fraction of inference steps to be used in `.imageToImage` pipeline mode | ||
/// Must be between 0 and 1 | ||
/// Higher values will result in greater transformation of the `startingImage` | ||
public var strength: Float = 1.0 | ||
/// Fraction of inference steps at which to start using the refiner unet, if present, in `textToImage` mode | ||
/// Must be between 0 and 1 | ||
/// Higher values will result in fewer refiner steps | ||
public var refinerStart: Float = 0.8 | ||
/// Number of images to generate | ||
public var imageCount: Int = 1 | ||
/// Number of inference steps to perform | ||
|
@@ -44,7 +50,19 @@ public struct PipelineConfiguration: Hashable { | |
public var encoderScaleFactor: Float32 = 0.18215 | ||
/// Scale factor to use on the latent before decoding | ||
public var decoderScaleFactor: Float32 = 0.18215 | ||
|
||
/// If `originalSize` is not the same as `targetSize` the image will appear to be down- or upsampled. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am inclined to recommend a dedicated There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea, I'd agree it probably shouldn't be a requirement on the caller to set the scale factor properly, especially when we already know the pipeline is SDXL and can set the default properly. My question would be whether we'd want a new class for this or just setup the defaults somewhere in the pipeline. Or potentially a broader option of just including the config files with the converted models directly, alongside the merges.txt and vocab.json, and using those for the default config. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
That would be nice and we can even make that backward compatible. However, it still wouldn't address the argument set mismatch across XL and non-XL right? IMO, the easiest is to create 2 (XL and non-XL) default configs and then add your proposal from above as an extension. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea that makes sense to me, and seems to be the approach There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, we can refactor over time :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with the approach discussed, separate configs do seem necessary to make usage simpler. |
||
/// Part of SDXL’s micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. | ||
public var originalSize: Float32 = 1024 | ||
/// `cropsCoordsTopLeft` can be used to generate an image that appears to be “cropped” from the position `cropsCoordsTopLeft` downwards. | ||
/// Favorable, well-centered images are usually achieved by setting `cropsCoordsTopLeft` to (0, 0). | ||
public var cropsCoordsTopLeft: Float32 = 0 | ||
/// For most cases, `target_size` should be set to the desired height and width of the generated image. | ||
public var targetSize: Float32 = 1024 | ||
/// Used to simulate an aesthetic score of the generated image by influencing the positive text condition. | ||
public var aestheticScore: Float32 = 6 | ||
/// Can be used to simulate an aesthetic score of the generated image by influencing the negative text condition. | ||
public var negativeAestheticScore: Float32 = 2.5 | ||
|
||
/// Given the configuration, what mode will be used for generation | ||
public var mode: PipelineMode { | ||
guard startingImage != nil else { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is --xl-version used for the base model only, or do you need to use it along with --refiner-version for the refiner?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jrittvo That's a good question, I suppose it would work for a non-xl model in the same way. All it is doing is converting the unet for the provided model and renaming it as the refiner model for that base model. The only real requirement is that they are both the same kind of model so the latents match up. That's for the conversion - but the pipeline for the swift side of things has nothing to handle a
UnetRefiner.mlmodelc
for non-XL models, so it would just be ignored.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To be more specific though: yes, the --xl-version should be used along with the --refiner-version in order to convert the refiner in a similar way to the
--model-version
input.