@@ -5,20 +5,30 @@ use crate::engine::ai::providers::ProviderTrait;
5
5
use crate :: error:: AIProxyError ;
6
6
use ahnlich_types:: ai:: ExecutionProvider ;
7
7
use ahnlich_types:: { ai:: AIStoreInputType , keyval:: StoreKey } ;
8
- use image:: { DynamicImage , GenericImageView , ImageFormat , ImageReader } ;
8
+ use fast_image_resize:: images:: Image ;
9
+ use fast_image_resize:: images:: ImageRef ;
10
+ use fast_image_resize:: FilterType ;
11
+ use fast_image_resize:: PixelType ;
12
+ use fast_image_resize:: ResizeAlg ;
13
+ use fast_image_resize:: ResizeOptions ;
14
+ use fast_image_resize:: Resizer ;
15
+ use image:: imageops;
16
+ use image:: ImageReader ;
17
+ use image:: RgbImage ;
9
18
use ndarray:: { Array , Ix3 } ;
10
19
use ndarray:: { ArrayView , Ix4 } ;
11
20
use nonzero_ext:: nonzero;
12
- use serde:: de:: Error as DeError ;
13
- use serde:: ser:: Error as SerError ;
14
- use serde:: { Deserialize , Deserializer , Serialize , Serializer } ;
21
+ use once_cell:: sync:: Lazy ;
22
+ use serde:: { Deserialize , Serialize } ;
15
23
use std:: fmt;
16
24
use std:: io:: Cursor ;
17
25
use std:: num:: NonZeroUsize ;
18
26
use std:: path:: PathBuf ;
19
27
use strum:: Display ;
20
28
use tokenizers:: Encoding ;
21
29
30
+ static CHANNELS : Lazy < u8 > = Lazy :: new ( || image:: ColorType :: Rgb8 . channel_count ( ) ) ;
31
+
22
32
#[ derive( Display , Debug , Serialize , Deserialize ) ]
23
33
pub enum ModelType {
24
34
Text {
@@ -246,31 +256,71 @@ pub enum ModelInput {
246
256
Images ( Array < f32 , Ix4 > ) ,
247
257
}
248
258
249
- #[ derive( Debug , Clone ) ]
250
- pub struct ImageArray {
259
+ #[ derive( Debug ) ]
260
+ pub struct OnnxTransformResult {
251
261
array : Array < f32 , Ix3 > ,
252
- image : DynamicImage ,
253
- image_format : ImageFormat ,
254
- onnx_transformed : bool ,
255
262
}
256
263
257
- impl ImageArray {
258
- pub fn try_new ( bytes : Vec < u8 > ) -> Result < Self , AIProxyError > {
259
- let img_reader = ImageReader :: new ( Cursor :: new ( & bytes) )
260
- . with_guessed_format ( )
261
- . map_err ( |_| AIProxyError :: ImageBytesDecodeError ) ?;
264
+ impl OnnxTransformResult {
265
+ pub fn view ( & self ) -> ArrayView < f32 , Ix3 > {
266
+ self . array . view ( )
267
+ }
262
268
263
- let image_format = & img_reader
264
- . format ( )
265
- . ok_or ( AIProxyError :: ImageBytesDecodeError ) ?;
269
+ pub fn image_dim ( & self ) -> ( NonZeroUsize , NonZeroUsize ) {
270
+ let shape = self . array . shape ( ) ;
271
+ (
272
+ NonZeroUsize :: new ( shape[ 2 ] ) . expect ( "Array columns should be non zero" ) ,
273
+ NonZeroUsize :: new ( shape[ 1 ] ) . expect ( "Array channels should be non zero" ) ,
274
+ )
275
+ }
276
+ }
266
277
267
- let image = img_reader
268
- . decode ( )
278
+ impl TryFrom < ImageArray > for OnnxTransformResult {
279
+ type Error = AIProxyError ;
280
+
281
+ // Swapping axes from [rows, columns, channels] to [channels, rows, columns] for ONNX
282
+ #[ tracing:: instrument( skip_all) ]
283
+ fn try_from ( value : ImageArray ) -> Result < Self , Self :: Error > {
284
+ let image = value. image ;
285
+ let mut array = Array :: from_shape_vec (
286
+ (
287
+ image. height ( ) as usize ,
288
+ image. width ( ) as usize ,
289
+ * CHANNELS as usize ,
290
+ ) ,
291
+ image. into_raw ( ) ,
292
+ )
293
+ . map_err ( |e| AIProxyError :: ImageArrayToNdArrayError {
294
+ message : format ! ( "Error running onnx transform {e}" ) ,
295
+ } ) ?
296
+ . mapv ( f32:: from) ;
297
+ array. swap_axes ( 1 , 2 ) ;
298
+ array. swap_axes ( 0 , 1 ) ;
299
+ Ok ( Self { array } )
300
+ }
301
+ }
302
+
303
+ #[ derive( Debug ) ]
304
+ pub struct ImageArray {
305
+ image : RgbImage ,
306
+ }
307
+
308
+ impl TryFrom < & [ u8 ] > for ImageArray {
309
+ type Error = AIProxyError ;
310
+
311
+ #[ tracing:: instrument( skip_all) ]
312
+ fn try_from ( value : & [ u8 ] ) -> Result < Self , Self :: Error > {
313
+ let img_reader = ImageReader :: new ( Cursor :: new ( value) )
314
+ . with_guessed_format ( )
269
315
. map_err ( |_| AIProxyError :: ImageBytesDecodeError ) ?;
270
316
271
317
// Always convert to RGB8 format
272
318
// https://github.com/Anush008/fastembed-rs/blob/cea92b6c8b877efda762393848d1c449a4eea126/src/image_embedding/utils.rs#L198
273
- let image: DynamicImage = image. to_owned ( ) . into_rgb8 ( ) . into ( ) ;
319
+ let image = img_reader
320
+ . decode ( )
321
+ . map_err ( |_| AIProxyError :: ImageBytesDecodeError ) ?
322
+ . into_rgb8 ( ) ;
323
+
274
324
let ( width, height) = image. dimensions ( ) ;
275
325
276
326
if width == 0 || height == 0 {
@@ -279,116 +329,59 @@ impl ImageArray {
279
329
height : height as usize ,
280
330
} ) ;
281
331
}
282
-
283
- let channels = & image. color ( ) . channel_count ( ) ;
284
- let shape = ( height as usize , width as usize , * channels as usize ) ;
285
- let array = Array :: from_shape_vec ( shape, image. clone ( ) . into_bytes ( ) )
286
- . map_err ( |_| AIProxyError :: ImageBytesDecodeError ) ?
287
- . mapv ( f32:: from) ;
288
-
289
- Ok ( ImageArray {
290
- array,
291
- image,
292
- image_format : image_format. to_owned ( ) ,
293
- onnx_transformed : false ,
294
- } )
295
- }
296
-
297
- // Swapping axes from [rows, columns, channels] to [channels, rows, columns] for ONNX
298
- pub fn onnx_transform ( & mut self ) {
299
- if self . onnx_transformed {
300
- return ;
301
- }
302
- self . array . swap_axes ( 1 , 2 ) ;
303
- self . array . swap_axes ( 0 , 1 ) ;
304
- self . onnx_transformed = true ;
305
- }
306
-
307
- pub fn view ( & self ) -> ArrayView < f32 , Ix3 > {
308
- self . array . view ( )
332
+ Ok ( Self { image } )
309
333
}
334
+ }
310
335
311
- pub fn get_bytes ( & self ) -> Result < Vec < u8 > , AIProxyError > {
312
- let mut buffer = Cursor :: new ( Vec :: new ( ) ) ;
313
- let _ = & self
314
- . image
315
- . write_to ( & mut buffer, self . image_format )
316
- . map_err ( |_| AIProxyError :: ImageBytesEncodeError ) ?;
317
- let bytes = buffer. into_inner ( ) ;
318
- Ok ( bytes)
336
+ impl ImageArray {
337
+ fn array_view ( & self ) -> ArrayView < u8 , Ix3 > {
338
+ let shape = (
339
+ self . image . height ( ) as usize ,
340
+ self . image . width ( ) as usize ,
341
+ * CHANNELS as usize ,
342
+ ) ;
343
+ let raw_bytes = self . image . as_raw ( ) ;
344
+ ArrayView :: from_shape ( shape, raw_bytes) . expect ( "Image bytes decode error" )
319
345
}
320
346
347
+ #[ tracing:: instrument( skip( self ) ) ]
321
348
pub fn resize (
322
- & self ,
349
+ & mut self ,
323
350
width : u32 ,
324
351
height : u32 ,
325
352
filter : Option < image:: imageops:: FilterType > ,
326
353
) -> Result < Self , AIProxyError > {
327
- let filter_type = filter . unwrap_or ( image:: imageops :: FilterType :: CatmullRom ) ;
328
- let resized_img = self . image . resize_exact ( width , height , filter_type ) ;
329
- let channels = resized_img . color ( ) . channel_count ( ) ;
330
- let shape = ( height as usize , width as usize , channels as usize ) ;
331
-
332
- let flattened_pixels = resized_img . clone ( ) . into_bytes ( ) ;
333
- let array = Array :: from_shape_vec ( shape , flattened_pixels )
334
- . map_err ( |_ | AIProxyError :: ImageResizeError ) ?
335
- . mapv ( f32 :: from ) ;
336
- Ok ( ImageArray {
337
- array ,
338
- image : resized_img ,
339
- image_format : self . image_format ,
340
- onnx_transformed : false ,
341
- } )
354
+ // Create container for data of destination image
355
+ let ( width , height ) = self . image . dimensions ( ) ;
356
+ let mut dest_image = Image :: new ( width , height , PixelType :: U8x3 ) ;
357
+ let mut resizer = Resizer :: new ( ) ;
358
+ resizer
359
+ . resize (
360
+ & ImageRef :: new ( width , height , self . image . as_raw ( ) , PixelType :: U8x3 )
361
+ . map_err ( |e | AIProxyError :: ImageResizeError ( e . to_string ( ) ) ) ? ,
362
+ & mut dest_image ,
363
+ & ResizeOptions :: new ( ) . resize_alg ( ResizeAlg :: Convolution ( FilterType :: CatmullRom ) ) ,
364
+ )
365
+ . map_err ( |e| AIProxyError :: ImageResizeError ( e . to_string ( ) ) ) ? ;
366
+ let resized_img = RgbImage :: from_raw ( width , height , dest_image . into_vec ( ) )
367
+ . expect ( "Could not get image after resizing" ) ;
368
+ Ok ( ImageArray { image : resized_img } )
342
369
}
343
370
344
- pub fn crop ( & self , x : u32 , y : u32 , width : u32 , height : u32 ) -> Result < Self , AIProxyError > {
345
- let cropped_img = self . image . crop_imm ( x, y, width, height) ;
346
- let channels = cropped_img. color ( ) . channel_count ( ) ;
347
- let shape = ( height as usize , width as usize , channels as usize ) ;
348
-
349
- let flattened_pixels = cropped_img. clone ( ) . into_bytes ( ) ;
350
- let array = Array :: from_shape_vec ( shape, flattened_pixels)
351
- . map_err ( |_| AIProxyError :: ImageCropError ) ?
352
- . mapv ( f32:: from) ;
353
- Ok ( ImageArray {
354
- array,
355
- image : cropped_img,
356
- image_format : self . image_format ,
357
- onnx_transformed : false ,
358
- } )
359
- }
360
-
361
- pub fn image_dim ( & self ) -> ( NonZeroUsize , NonZeroUsize ) {
362
- let shape = self . array . shape ( ) ;
363
- match self . onnx_transformed {
364
- true => (
365
- NonZeroUsize :: new ( shape[ 2 ] ) . expect ( "Array columns should be non-zero" ) ,
366
- NonZeroUsize :: new ( shape[ 1 ] ) . expect ( "Array channels should be non-zero" ) ,
367
- ) , // (width, channels)
368
- false => (
369
- NonZeroUsize :: new ( shape[ 1 ] ) . expect ( "Array columns should be non-zero" ) ,
370
- NonZeroUsize :: new ( shape[ 0 ] ) . expect ( "Array rows should be non-zero" ) ,
371
- ) , // (width, height)
372
- }
373
- }
374
- }
371
+ #[ tracing:: instrument( skip( self ) ) ]
372
+ pub fn crop ( & mut self , x : u32 , y : u32 , width : u32 , height : u32 ) -> Result < Self , AIProxyError > {
373
+ let cropped_img = imageops:: crop ( & mut self . image , x, y, width, height) . to_image ( ) ;
375
374
376
- impl Serialize for ImageArray {
377
- fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
378
- where
379
- S : Serializer ,
380
- {
381
- serializer. serialize_bytes ( & self . get_bytes ( ) . map_err ( S :: Error :: custom) ?)
375
+ Ok ( ImageArray { image : cropped_img } )
382
376
}
383
- }
384
377
385
- impl < ' de > Deserialize < ' de > for ImageArray {
386
- fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
387
- where
388
- D : Deserializer < ' de > ,
389
- {
390
- let bytes : Vec < u8 > = Deserialize :: deserialize ( deserializer ) ? ;
391
- ImageArray :: try_new ( bytes ) . map_err ( D :: Error :: custom )
378
+ pub fn image_dim ( & self ) -> ( NonZeroUsize , NonZeroUsize ) {
379
+ let arr_view = self . array_view ( ) ;
380
+ let shape = arr_view . shape ( ) ;
381
+ (
382
+ NonZeroUsize :: new ( shape [ 1 ] ) . expect ( "Array columns should be non-zero" ) ,
383
+ NonZeroUsize :: new ( shape [ 0 ] ) . expect ( "Array rows should be non-zero" ) ,
384
+ )
392
385
}
393
386
}
394
387
0 commit comments