
Commit c988e48

Merge pull request #837 from ftnext/feature/pass-prompt-openai
2 parents 0747dcc + 72da569

2 files changed: 29 additions & 7 deletions


speech_recognition/recognizers/whisper_api/base.py

Lines changed: 7 additions & 1 deletion
@@ -1,7 +1,10 @@
+import logging
 from io import BytesIO
 
 from speech_recognition.audio import AudioData
 
+logger = logging.getLogger(__name__)
+
 
 class OpenAICompatibleRecognizer:
     def __init__(self, client) -> None:
@@ -16,7 +19,10 @@ def recognize(self, audio_data: "AudioData", model: str, **kwargs) -> str:
         wav_data = BytesIO(audio_data.get_wav_data())
         wav_data.name = "SpeechRecognition_audio.wav"
 
+        parameters = {"model": model, **kwargs}
+        logger.debug(parameters)
+
         transcript = self.client.audio.transcriptions.create(
-            file=wav_data, model=model, **kwargs
+            file=wav_data, **parameters
         )
         return transcript.text
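
For context, a rough sketch of how the merged parameters now reach the OpenAI client after this change. The client construction, the placeholder audio path, and the prompt text below are illustrative assumptions, not part of the diff; an OPENAI_API_KEY environment variable is also assumed.

# Sketch only: extra keyword arguments such as `prompt` flow through
# OpenAICompatibleRecognizer.recognize into client.audio.transcriptions.create.
from openai import OpenAI

import speech_recognition as sr
from speech_recognition.recognizers.whisper_api.base import OpenAICompatibleRecognizer

audio_data = sr.AudioData.from_file("speech.wav")  # placeholder file path

recognizer = OpenAICompatibleRecognizer(OpenAI())  # OpenAI() reads OPENAI_API_KEY
# `prompt` lands in **kwargs, is merged into `parameters`, logged via
# logger.debug(parameters), and forwarded with the transcription request.
text = recognizer.recognize(audio_data, model="whisper-1", prompt="SpeechRecognition, PyAudio")
print(text)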

speech_recognition/recognizers/whisper_api/openai.py

Lines changed: 22 additions & 6 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import logging
 from typing import Literal
 
 from typing_extensions import Unpack
@@ -65,16 +66,31 @@ def recognize(
     parser = argparse.ArgumentParser()
     parser.add_argument("audio_file")
     parser.add_argument(
-        "--model", choices=get_args(WhisperModel), default="whisper-1"
+        "-m", "--model", choices=get_args(WhisperModel), default="whisper-1"
     )
     parser.add_argument("-l", "--language")
+    parser.add_argument("-p", "--prompt")
+    parser.add_argument("-v", "--verbose", action="store_true")
     args = parser.parse_args()
 
+    if args.verbose:
+        speech_recognition_logger = logging.getLogger("speech_recognition")
+        speech_recognition_logger.setLevel(logging.DEBUG)
+
+        console_handler = logging.StreamHandler()
+        console_formatter = logging.Formatter(
+            "%(asctime)s | %(levelname)s | %(name)s:%(funcName)s:%(lineno)d - %(message)s"
+        )
+        console_handler.setFormatter(console_formatter)
+        speech_recognition_logger.addHandler(console_handler)
+
     audio_data = sr.AudioData.from_file(args.audio_file)
+
+    recognize_args = {"model": args.model}
     if args.language:
-        transcription = recognize(
-            None, audio_data, model=args.model, language=args.language
-        )
-    else:
-        transcription = recognize(None, audio_data, model=args.model)
+        recognize_args["language"] = args.language
+    if args.prompt:
+        recognize_args["prompt"] = args.prompt
+
+    transcription = recognize(None, audio_data, **recognize_args)
     print(transcription)
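
As a rough usage sketch, the updated __main__ block with the new -m/--model, -p/--prompt, and -v/--verbose options is approximately equivalent to the following; the audio path and prompt text are placeholders, and OPENAI_API_KEY is assumed to be set.

# Rough programmatic equivalent of running the script with
# -m whisper-1 -p "SpeechRecognition, PyAudio" -v (values are illustrative).
import logging

import speech_recognition as sr
from speech_recognition.recognizers.whisper_api.openai import recognize

# -v/--verbose: make the logger.debug(parameters) call added in base.py visible
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s:%(funcName)s:%(lineno)d - %(message)s"
)
logging.getLogger("speech_recognition").setLevel(logging.DEBUG)

audio_data = sr.AudioData.from_file("speech.wav")  # placeholder path

# -m/--model, -l/--language, -p/--prompt are collected into recognize_args
# and expanded as keyword arguments, mirroring the __main__ block above.
recognize_args = {"model": "whisper-1", "prompt": "SpeechRecognition, PyAudio"}
transcription = recognize(None, audio_data, **recognize_args)
print(transcription)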
