@@ -330,13 +330,14 @@ def _post_training_quantization_ov(
330
330
if datamodule is None :
331
331
msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression"
332
332
raise ValueError (msg )
333
+ datamodule .setup ("fit" )
333
334
334
335
model_input = model .input (0 )
335
336
336
337
if model_input .partial_shape [0 ].is_static :
337
338
datamodule .train_batch_size = model_input .shape [0 ]
338
339
339
- dataloader = datamodule .train_dataloader ()
340
+ dataloader = datamodule .val_dataloader ()
340
341
if len (dataloader .dataset ) < 300 :
341
342
logger .warning (
342
343
f">300 images recommended for INT8 quantization, found only { len (dataloader .dataset )} images" ,
@@ -373,6 +374,8 @@ def _accuracy_control_quantization_ov(
373
374
if datamodule is None :
374
375
msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression"
375
376
raise ValueError (msg )
377
+ datamodule .setup ("fit" )
378
+
376
379
if metric is None :
377
380
msg = "Metric must be provided for OpenVINO INT8_ACQ compression"
378
381
raise ValueError (msg )
@@ -383,14 +386,14 @@ def _accuracy_control_quantization_ov(
383
386
datamodule .train_batch_size = model_input .shape [0 ]
384
387
datamodule .eval_batch_size = model_input .shape [0 ]
385
388
386
- dataloader = datamodule .train_dataloader ()
389
+ dataloader = datamodule .val_dataloader ()
387
390
if len (dataloader .dataset ) < 300 :
388
391
logger .warning (
389
392
f">300 images recommended for INT8 quantization, found only { len (dataloader .dataset )} images" ,
390
393
)
391
394
392
395
calibration_dataset = nncf .Dataset (dataloader , lambda x : x ["image" ])
393
- validation_dataset = nncf .Dataset (datamodule .val_dataloader ())
396
+ validation_dataset = nncf .Dataset (datamodule .test_dataloader ())
394
397
395
398
if isinstance (metric , str ):
396
399
metric = create_metric_collection ([metric ])[metric ]
0 commit comments