1
+ import os
1
2
import tempfile
2
3
import unittest
4
+ from unittest .mock import MagicMock
5
+
6
+ import huggingface_hub
3
7
4
8
import gradio as gr
5
9
from gradio import flagging
6
10
7
11
8
12
class TestDefaultFlagging (unittest .TestCase ):
9
- def test_default_flagging_handler (self ):
13
+ def test_default_flagging_callback (self ):
10
14
with tempfile .TemporaryDirectory () as tmpdirname :
11
15
io = gr .Interface (lambda x : x , "text" , "text" , flagging_dir = tmpdirname )
12
16
io .launch (prevent_thread_lock = True )
@@ -18,7 +22,7 @@ def test_default_flagging_handler(self):
18
22
19
23
20
24
class TestSimpleFlagging (unittest .TestCase ):
21
- def test_simple_csv_flagging_handler (self ):
25
+ def test_simple_csv_flagging_callback (self ):
22
26
with tempfile .TemporaryDirectory () as tmpdirname :
23
27
io = gr .Interface (
24
28
lambda x : x ,
@@ -29,11 +33,39 @@ def test_simple_csv_flagging_handler(self):
29
33
)
30
34
io .launch (prevent_thread_lock = True )
31
35
row_count = io .flagging_callback .flag (io , ["test" ], ["test" ])
32
- self .assertEqual (row_count , 0 ) # no header
36
+ self .assertEqual (row_count , 0 ) # no header in SimpleCSVLogger
33
37
row_count = io .flagging_callback .flag (io , ["test" ], ["test" ])
34
- self .assertEqual (row_count , 1 ) # no header
38
+ self .assertEqual (row_count , 1 ) # no header in SimpleCSVLogger
35
39
io .close ()
36
40
37
41
42
+ class TestHuggingFaceDatasetSaver (unittest .TestCase ):
43
+ def test_saver_setup (self ):
44
+ huggingface_hub .create_repo = MagicMock ()
45
+ huggingface_hub .Repository = MagicMock ()
46
+ flagger = flagging .HuggingFaceDatasetSaver ("test" , "test" )
47
+ with tempfile .TemporaryDirectory () as tmpdirname :
48
+ flagger .setup (tmpdirname )
49
+ huggingface_hub .create_repo .assert_called_once ()
50
+
51
+ def test_saver_flag (self ):
52
+ huggingface_hub .create_repo = MagicMock ()
53
+ huggingface_hub .Repository = MagicMock ()
54
+ with tempfile .TemporaryDirectory () as tmpdirname :
55
+ io = gr .Interface (
56
+ lambda x : x ,
57
+ "text" ,
58
+ "text" ,
59
+ flagging_dir = tmpdirname ,
60
+ flagging_callback = flagging .HuggingFaceDatasetSaver ("test" , "test" ),
61
+ )
62
+ os .mkdir (os .path .join (tmpdirname , "test" ))
63
+ io .launch (prevent_thread_lock = True )
64
+ row_count = io .flagging_callback .flag (io , ["test" ], ["test" ])
65
+ self .assertEqual (row_count , 1 ) # 2 rows written including header
66
+ row_count = io .flagging_callback .flag (io , ["test" ], ["test" ])
67
+ self .assertEqual (row_count , 2 ) # 3 rows written including header
68
+
69
+
38
70
if __name__ == "__main__" :
39
71
unittest .main ()
0 commit comments