Skip to content

Commit aa1422e

Browse files
authored
Merge pull request #1 from lincuan/main
Merge PR TEN-framework#7 from lincuan
2 parents 6abbf1b + 624aa66 commit aa1422e

File tree

5 files changed

+665
-0
lines changed

5 files changed

+665
-0
lines changed

include/ten_vad_enhanced.h

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
#ifndef TEN_VAD_H
2+
#define TEN_VAD_H
3+
4+
#if defined(__APPLE__) || defined(__ANDROID__) || defined(__linux__)
5+
#define TENVAD_API __attribute__((visibility("default")))
6+
#elif defined(_WIN32) || defined(__CYGWIN__)
7+
#ifdef TENVAD_EXPORTS
8+
#define TENVAD_API __declspec(dllexport)
9+
#else
10+
#define TENVAD_API __declspec(dllimport)
11+
#endif
12+
#else
13+
#define TENVAD_API
14+
#endif
15+
16+
#include <stddef.h> /* size_t */
17+
#include <stdint.h> /* int16_t */
18+
19+
#ifdef __cplusplus
20+
extern "C"
21+
{
22+
#endif
23+
24+
/**
25+
* @brief Error codes for TEN VAD operations.
26+
*/
27+
typedef enum {
28+
TEN_VAD_SUCCESS = 0, /**< Operation successful */
29+
TEN_VAD_ERROR_INVALID_PARAM = -1, /**< Invalid parameter (e.g., null pointer, invalid hop_size) */
30+
TEN_VAD_ERROR_OUT_OF_MEMORY = -2, /**< Memory allocation failed */
31+
TEN_VAD_ERROR_INVALID_STATE = -3, /**< Invalid VAD handle or state */
32+
TEN_VAD_ERROR_PROCESS_FAILED = -4 /**< Processing error */
33+
} ten_vad_error_t;
34+
35+
/**
36+
* @typedef ten_vad_handle_t
37+
* @brief Opaque handle for ten_vad instance.
38+
*/
39+
typedef void *ten_vad_handle_t;
40+
41+
/**
42+
* @brief Callback function type for VAD processing results.
43+
*
44+
* @param probability Voice activity probability [0.0, 1.0].
45+
* @param flag Binary voice activity decision (0: no voice, 1: voice).
46+
* @param user_data User-defined data passed to the callback.
47+
*/
48+
typedef void (*ten_vad_callback_t)(float probability, int flag, void *user_data);
49+
50+
/**
51+
* @brief Version information for the TEN VAD library.
52+
*/
53+
typedef struct {
54+
int major; /**< Major version number */
55+
int minor; /**< Minor version number */
56+
int patch; /**< Patch version number */
57+
} ten_vad_version_t;
58+
59+
/**
60+
* @brief Create and initialize a ten_vad instance.
61+
*
62+
* @param[out] handle Pointer to receive the vad handle. Must not be NULL.
63+
* @param[in] hop_size Number of samples per analysis frame (e.g., 256). Must be positive.
64+
* @param[in] threshold VAD detection threshold [0.0, 1.0]. Determines voice activity by comparing with output probability.
65+
* @return TEN_VAD_SUCCESS on success, TEN_VAD_ERROR_INVALID_PARAM if handle is NULL or parameters are invalid,
66+
* TEN_VAD_ERROR_OUT_OF_MEMORY if allocation fails.
67+
* @note Must call ten_vad_destroy() to release resources.
68+
* @example
69+
* ten_vad_handle_t handle = NULL;
70+
* ten_vad_error_t ret = ten_vad_create(&handle, 256, 0.5);
71+
* if (ret == TEN_VAD_SUCCESS) {
72+
* // Use handle
73+
* ten_vad_destroy(&handle);
74+
* }
75+
*/
76+
TENVAD_API ten_vad_error_t ten_vad_create(ten_vad_handle_t *handle, size_t hop_size, float threshold);
77+
78+
/**
79+
* @brief Process one audio frame for voice activity detection.
80+
* Must call ten_vad_create() before calling this, and ten_vad_destroy() when done.
81+
*
82+
* @param[in] handle Valid VAD handle returned by ten_vad_create().
83+
* @param[in] audio_data Pointer to an array of int16_t samples, buffer length must equal hop_size.
84+
* @param[in] audio_data_length Size of audio_data buffer, must equal hop_size.
85+
* @param[out] out_probability Pointer to a float (size 1) to receive voice activity probability [0.0, 1.0].
86+
* @param[out] out_flag Pointer to an int (size 1) to receive binary decision: 0 (no voice), 1 (voice).
87+
* @return TEN_VAD_SUCCESS on success, TEN_VAD_ERROR_INVALID_PARAM if parameters are invalid,
88+
* TEN_VAD_ERROR_INVALID_STATE if handle is invalid, TEN_VAD_ERROR_PROCESS_FAILED on processing error.
89+
*/
90+
TENVAD_API ten_vad_error_t ten_vad_process(ten_vad_handle_t handle, const int16_t *audio_data, size_t audio_data_length,
91+
float *out_probability, int *out_flag);
92+
93+
/**
94+
* @brief Destroy a ten_vad instance and release its resources.
95+
*
96+
* @param[in,out] handle Pointer to the ten_vad handle; set to NULL on success.
97+
* @return TEN_VAD_SUCCESS on success, TEN_VAD_ERROR_INVALID_PARAM if the handle pointer itself is NULL.
98+
* @note Safe to call multiple times; subsequent calls where *handle is already NULL return TEN_VAD_SUCCESS.
99+
*/
100+
TENVAD_API ten_vad_error_t ten_vad_destroy(ten_vad_handle_t *handle);
101+
102+
/**
103+
* @brief Update the VAD threshold dynamically.
104+
*
105+
* @param[in] handle Valid VAD handle returned by ten_vad_create().
106+
* @param[in] threshold New VAD detection threshold [0.0, 1.0].
107+
* @return TEN_VAD_SUCCESS on success, TEN_VAD_ERROR_INVALID_PARAM if handle or threshold is invalid.
108+
*/
109+
TENVAD_API ten_vad_error_t ten_vad_set_threshold(ten_vad_handle_t handle, float threshold);
110+
111+
/**
112+
* @brief Register a callback for VAD processing results.
113+
*
114+
* @param[in] handle Valid VAD handle.
115+
* @param[in] callback Callback function to invoke after ten_vad_process.
116+
* @param[in] user_data User-defined data to pass to the callback.
117+
* @return TEN_VAD_SUCCESS on success, TEN_VAD_ERROR_INVALID_PARAM if handle or callback is invalid.
118+
*/
119+
TENVAD_API ten_vad_error_t ten_vad_register_callback(ten_vad_handle_t handle, ten_vad_callback_t callback, void *user_data);
120+
121+
/**
122+
* @brief Get the ten_vad library version string.
123+
*
124+
* @return The version string (e.g., "1.0.0").
125+
*/
126+
TENVAD_API const char *ten_vad_get_version(void);
127+
128+
/**
129+
* @brief Get the ten_vad library version.
130+
*
131+
* @param[out] version Pointer to a ten_vad_version_t structure to receive version information.
132+
* @return TEN_VAD_SUCCESS on success, TEN_VAD_ERROR_INVALID_PARAM if version is NULL.
133+
*/
134+
TENVAD_API ten_vad_error_t ten_vad_get_version_struct(ten_vad_version_t *version);
135+
136+
#ifdef __cplusplus
137+
}
138+
#endif
139+
140+
#endif /* TEN_VAD_H */

include/ten_vad_enhanced.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import logging
2+
import platform
3+
import os
4+
from ctypes import c_int, c_int32, c_float, c_size_t, CDLL, c_void_p, POINTER
5+
import numpy as np
6+
from typing import Tuple, Callable, Optional
7+
import asyncio
8+
9+
10+
logging.basicConfig(level=logging.INFO)
11+
logger = logging.getLogger(__name__)
12+
13+
class TenVad:
    """Voice Activity Detection (VAD) backed by the native TEN VAD C library.

    Args:
        hop_size (int, optional): Number of samples in each audio frame. Defaults to 256.
        threshold (float, optional): Speech detection threshold (0 to 1). Defaults to 0.5.
        callback (Callable[[float, int], None], optional): Called with
            ``(probability, flag)`` after every successfully processed frame.

    Raises:
        ValueError: If hop_size or threshold is invalid.
        FileNotFoundError: If the VAD library cannot be found.
        RuntimeError: If VAD handler creation fails.
    """

    def __init__(self, hop_size: int = 256, threshold: float = 0.5, callback: Optional[Callable[[float, int], None]] = None):
        if hop_size <= 0:
            raise ValueError("[TEN VAD]: hop_size must be positive")
        if not 0 <= threshold <= 1:
            raise ValueError("[TEN VAD]: threshold must be between 0 and 1")

        self.hop_size = hop_size
        self.threshold = threshold
        self.callback = callback
        # Holds the exact buffer handed to the C library so it cannot be
        # garbage-collected while native code may still read from it.
        self._audio_data_ref = None

        # Load the shared library for the current platform.
        self.vad_library = CDLL(self._find_library())
        self.vad_handler = c_void_p(0)
        self.out_probability = c_float()
        self.out_flags = c_int32()

        # Declare the C function signatures so ctypes converts arguments correctly.
        self.vad_library.ten_vad_create.argtypes = [POINTER(c_void_p), c_size_t, c_float]
        self.vad_library.ten_vad_create.restype = c_int
        self.vad_library.ten_vad_destroy.argtypes = [POINTER(c_void_p)]
        self.vad_library.ten_vad_destroy.restype = c_int
        self.vad_library.ten_vad_process.argtypes = [c_void_p, c_void_p, c_size_t, POINTER(c_float), POINTER(c_int32)]
        self.vad_library.ten_vad_process.restype = c_int

        self.create_and_init_handler()

    @staticmethod
    def _find_library() -> str:
        """Return the path of the first TEN VAD shared library that exists.

        Candidates are checked in this order: the platform-specific ``lib``
        directory next to this module, a local ``ten_vad_library`` directory,
        and finally the ``TEN_VAD_LIB_PATH`` environment variable.

        Raises:
            FileNotFoundError: If none of the candidate paths exists.
        """
        # abspath (not relpath) keeps the lookup independent of the current
        # working directory; relpath also raises ValueError on Windows when the
        # module and the cwd live on different drives.
        base_dir = os.path.dirname(os.path.abspath(__file__))
        system = platform.system().lower()
        arch = platform.machine()
        lib_name = {
            "linux": "libten_vad.so",
            "windows": "ten_vad.dll",
        }.get(system, "libten_vad.dylib")
        candidates = [
            os.path.join(base_dir, f"../lib/{system}/{arch}/{lib_name}"),
            os.path.join(base_dir, f"./ten_vad_library/{lib_name}"),
            os.environ.get("TEN_VAD_LIB_PATH", ""),
        ]
        for path in candidates:
            if path and os.path.exists(path):
                return path
        raise FileNotFoundError(f"[TEN VAD]: Could not find {lib_name} library")

    def create_and_init_handler(self) -> None:
        """Create the native VAD handle using the current hop size and threshold.

        Raises:
            RuntimeError: If handler creation fails.
        """
        result = self.vad_library.ten_vad_create(
            POINTER(c_void_p)(self.vad_handler),
            c_size_t(self.hop_size),
            c_float(self.threshold),
        )
        if result != 0:
            logger.error("[TEN VAD]: Failed to create handler, error code: %d", result)
            raise RuntimeError(f"[TEN VAD]: create handler failure with error code: {result}")

    def __del__(self) -> None:
        """Release the native VAD handle, if one was created.

        Failures are logged rather than raised: exceptions escaping a finalizer
        cannot be caught by callers. The attributes are looked up defensively
        because ``__del__`` also runs on instances whose ``__init__`` failed
        before they were assigned.
        """
        handler = getattr(self, "vad_handler", None)
        library = getattr(self, "vad_library", None)
        if not handler or library is None:
            return
        result = library.ten_vad_destroy(POINTER(c_void_p)(handler))
        if result != 0:
            logger.error("[TEN VAD]: Failed to destroy handler, error code: %d", result)

    def _prepare_frame(self, audio_data: np.ndarray) -> np.ndarray:
        """Validate one audio frame and return it as a contiguous 1-D int16 array.

        Args:
            audio_data (np.ndarray): Audio data of shape (hop_size,) and type int16.

        Returns:
            np.ndarray: A C-contiguous array whose memory can be passed to C.

        Raises:
            TypeError: If audio_data is not a NumPy array or has incorrect type.
            ValueError: If audio_data shape or size is invalid.
        """
        if not isinstance(audio_data, np.ndarray):
            raise TypeError("[TEN VAD]: audio_data must be a NumPy array")
        audio_data = np.squeeze(audio_data)
        if audio_data.size == 0:
            raise ValueError("[TEN VAD]: audio_data is empty")
        if len(audio_data.shape) != 1 or audio_data.shape[0] != self.hop_size:
            raise ValueError(f"[TEN VAD]: audio data shape should be [{self.hop_size}]")
        if audio_data.dtype != np.int16:
            raise TypeError("[TEN VAD]: audio data type must be int16")
        if not audio_data.flags.c_contiguous:
            # Strided views (e.g. one channel of interleaved audio) must be copied.
            audio_data = np.ascontiguousarray(audio_data, dtype=np.int16)
        return audio_data

    @staticmethod
    def _frame_pointer(frame: np.ndarray) -> c_void_p:
        """Return the address of the first sample of ``frame`` as a ``c_void_p``."""
        return c_void_p(frame.__array_interface__["data"][0])

    def get_input_data(self, audio_data: np.ndarray) -> c_void_p:
        """Prepare audio data for processing.

        The validated (and possibly copied) buffer is stored on the instance so
        that the returned pointer stays valid until the next frame is prepared.

        Args:
            audio_data (np.ndarray): Audio data of shape (hop_size,) and type int16.

        Returns:
            c_void_p: Pointer to the audio data.

        Raises:
            TypeError: If audio_data is not a NumPy array or has incorrect type.
            ValueError: If audio_data shape or size is invalid.
        """
        frame = self._prepare_frame(audio_data)
        # Keep the buffer the pointer refers to, not the caller's original
        # array: for non-contiguous input these are different objects.
        self._audio_data_ref = frame
        return self._frame_pointer(frame)

    def set_threshold(self, threshold: float) -> None:
        """Update the VAD threshold dynamically.

        Uses the library's ``ten_vad_set_threshold`` when it is exported, which
        keeps the existing handle. Otherwise the handle is destroyed and
        recreated with the new threshold.

        Args:
            threshold (float): New threshold value (0 to 1).

        Raises:
            ValueError: If threshold is not between 0 and 1.
            RuntimeError: If the native update or handler reinitialization fails.
        """
        if not 0 <= threshold <= 1:
            raise ValueError("[TEN VAD]: threshold must be between 0 and 1")

        if self.vad_handler:
            # Older library builds may not export the setter; ctypes raises
            # AttributeError for a missing symbol, which getattr turns into None.
            native_setter = getattr(self.vad_library, "ten_vad_set_threshold", None)
            if native_setter is not None:
                native_setter.argtypes = [c_void_p, c_float]
                native_setter.restype = c_int
                result = native_setter(self.vad_handler, c_float(threshold))
                if result != 0:
                    logger.error("[TEN VAD]: Failed to set threshold, error code: %d", result)
                    raise RuntimeError(f"[TEN VAD]: set threshold failure with error code: {result}")
                self.threshold = threshold
                return

            destroy_result = self.vad_library.ten_vad_destroy(POINTER(c_void_p)(self.vad_handler))
            if destroy_result != 0:
                logger.warning(
                    "[TEN VAD]: Failed to destroy handler before re-creation, error code: %d",
                    destroy_result,
                )

        self.threshold = threshold
        self.create_and_init_handler()

    def _process_internal(self, audio_data: np.ndarray) -> Tuple[float, int]:
        """Run the native VAD on one frame without invoking the callback.

        Args:
            audio_data (np.ndarray): Audio data to process.

        Returns:
            Tuple[float, int]: Speech probability and detection flag.

        Raises:
            TypeError: If audio_data is not an int16 NumPy array.
            ValueError: If audio_data shape or size is invalid.
            RuntimeError: If processing fails.
        """
        input_pointer = self.get_input_data(audio_data)
        result = self.vad_library.ten_vad_process(
            self.vad_handler,
            input_pointer,
            c_size_t(self.hop_size),
            POINTER(c_float)(self.out_probability),
            POINTER(c_int32)(self.out_flags),
        )
        if result != 0:
            logger.error("[TEN VAD]: Process failed, error code: %d", result)
            raise RuntimeError(f"[TEN VAD]: process failed with error code: {result}")
        return self.out_probability.value, self.out_flags.value

    def process(self, audio_data: np.ndarray) -> Tuple[float, int]:
        """Process an audio frame and return VAD results.

        Args:
            audio_data (np.ndarray): Audio data of shape (hop_size,) and type int16.

        Returns:
            Tuple[float, int]: Speech probability and detection flag.

        Raises:
            TypeError: If audio_data is not an int16 NumPy array.
            ValueError: If audio_data shape or size is invalid.
            RuntimeError: If VAD processing fails.
        """
        prob, flag = self._process_internal(audio_data)
        if self.callback:
            self.callback(prob, flag)
        return prob, flag

    async def process_async(self, audio_data: np.ndarray) -> Tuple[float, int]:
        """Asynchronously process an audio frame and return VAD results.

        The native call runs in the event loop's default executor. The output
        buffers are shared on the instance, so calls on the same instance
        should be awaited one at a time.

        Args:
            audio_data (np.ndarray): Audio data of shape (hop_size,) and type int16.

        Returns:
            Tuple[float, int]: Speech probability and detection flag.

        Raises:
            TypeError: If audio_data is not an int16 NumPy array.
            ValueError: If audio_data shape or size is invalid.
            RuntimeError: If VAD processing fails.
        """
        frame = self._prepare_frame(audio_data)
        self._audio_data_ref = frame

        def _run_native() -> int:
            # `frame` is captured by this closure, so the buffer outlives the
            # native call even if `_audio_data_ref` is replaced in the meantime.
            return self.vad_library.ten_vad_process(
                self.vad_handler,
                self._frame_pointer(frame),
                c_size_t(self.hop_size),
                POINTER(c_float)(self.out_probability),
                POINTER(c_int32)(self.out_flags),
            )

        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(None, _run_native)
        if result != 0:
            logger.error("[TEN VAD]: Async process failed, error code: %d", result)
            raise RuntimeError(f"[TEN VAD]: async process failed with error code: {result}")
        prob, flag = self.out_probability.value, self.out_flags.value
        if self.callback:
            self.callback(prob, flag)
        return prob, flag

0 commit comments

Comments
 (0)