Skip to content

Commit 82cea22

Browse files
author
Abubakar Abid
committed
added flagging tests
1 parent e76aa4a commit 82cea22

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

.circleci/config.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
. venv/bin/activate
1919
pip install --upgrade pip
2020
pip install -r gradio.egg-info/requires.txt
21-
pip install shap IPython comet_ml wandb mlflow tensorflow transformers
21+
pip install shap IPython comet_ml wandb mlflow tensorflow transformers huggingface_hub
2222
pip install selenium==4.0.0a6.post2 coverage scikit-image
2323
- run:
2424
command: |

test/test_flagging.py

+36-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
import os
12
import tempfile
23
import unittest
4+
from unittest.mock import MagicMock
5+
6+
import huggingface_hub
37

48
import gradio as gr
59
from gradio import flagging
610

711

812
class TestDefaultFlagging(unittest.TestCase):
9-
def test_default_flagging_handler(self):
13+
def test_default_flagging_callback(self):
1014
with tempfile.TemporaryDirectory() as tmpdirname:
1115
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
1216
io.launch(prevent_thread_lock=True)
@@ -18,7 +22,7 @@ def test_default_flagging_handler(self):
1822

1923

2024
class TestSimpleFlagging(unittest.TestCase):
21-
def test_simple_csv_flagging_handler(self):
25+
def test_simple_csv_flagging_callback(self):
2226
with tempfile.TemporaryDirectory() as tmpdirname:
2327
io = gr.Interface(
2428
lambda x: x,
@@ -29,11 +33,39 @@ def test_simple_csv_flagging_handler(self):
2933
)
3034
io.launch(prevent_thread_lock=True)
3135
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
32-
self.assertEqual(row_count, 0) # no header
36+
self.assertEqual(row_count, 0) # no header in SimpleCSVLogger
3337
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
34-
self.assertEqual(row_count, 1) # no header
38+
self.assertEqual(row_count, 1) # no header in SimpleCSVLogger
3539
io.close()
3640

3741

42+
class TestHuggingFaceDatasetSaver(unittest.TestCase):
43+
def test_saver_setup(self):
44+
huggingface_hub.create_repo = MagicMock()
45+
huggingface_hub.Repository = MagicMock()
46+
flagger = flagging.HuggingFaceDatasetSaver("test", "test")
47+
with tempfile.TemporaryDirectory() as tmpdirname:
48+
flagger.setup(tmpdirname)
49+
huggingface_hub.create_repo.assert_called_once()
50+
51+
def test_saver_flag(self):
52+
huggingface_hub.create_repo = MagicMock()
53+
huggingface_hub.Repository = MagicMock()
54+
with tempfile.TemporaryDirectory() as tmpdirname:
55+
io = gr.Interface(
56+
lambda x: x,
57+
"text",
58+
"text",
59+
flagging_dir=tmpdirname,
60+
flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"),
61+
)
62+
os.mkdir(os.path.join(tmpdirname, "test"))
63+
io.launch(prevent_thread_lock=True)
64+
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
65+
self.assertEqual(row_count, 1) # 2 rows written including header
66+
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
67+
self.assertEqual(row_count, 2) # 3 rows written including header
68+
69+
3870
if __name__ == "__main__":
3971
unittest.main()

0 commit comments

Comments
 (0)