@@ -1,15 +1,15 @@
+ import threading
+ import time
from typing import Any, Dict, List, Optional, Tuple, Union
import logging
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from functools import partial
- import time
from pathlib import Path
import json
import hashlib
- import signal

- from lyrics_transcriber.types import LyricsData, PhraseScore, AnchorSequence, GapSequence, ScoredAnchor, TranscriptionResult, Word
+ from lyrics_transcriber.types import LyricsData, PhraseScore, PhraseType, AnchorSequence, GapSequence, ScoredAnchor, TranscriptionResult, Word
from lyrics_transcriber.correction.phrase_analyzer import PhraseAnalyzer
from lyrics_transcriber.correction.text_utils import clean_text
from lyrics_transcriber.utils.word_utils import WordUtils
@@ -45,7 +45,14 @@ def __init__(
        # Initialize cache directory
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
-         self.logger.debug(f"Initialized AnchorSequenceFinder with cache dir: {self.cache_dir}, timeout: {timeout_seconds}s")
+         self.logger.info(f"Initialized AnchorSequenceFinder with cache dir: {self.cache_dir}, timeout: {timeout_seconds}s")
+
+     def _check_timeout(self, start_time: float, operation_name: str = "operation"):
+         """Check if timeout has occurred and raise exception if so."""
+         if self.timeout_seconds > 0:
+             elapsed_time = time.time() - start_time
+             if elapsed_time > self.timeout_seconds:
+                 raise AnchorSequenceTimeoutError(f"{operation_name} exceeded {self.timeout_seconds} seconds (elapsed: {elapsed_time:.1f}s)")

    def _clean_text(self, text: str) -> str:
        """Clean text by removing punctuation and normalizing whitespace."""
@@ -177,10 +184,6 @@ def _load_from_cache(self, cache_path: Path) -> Optional[List[ScoredAnchor]]:
            self.logger.error(f"Unexpected error loading cache: {type(e).__name__}: {e}")
            return None

-     def _timeout_handler(self, signum, frame):
-         """Handle timeout signal by raising AnchorSequenceTimeoutError."""
-         raise AnchorSequenceTimeoutError(f"Anchor sequence computation exceeded {self.timeout_seconds} seconds")
-
    def _process_ngram_length(
        self,
        n: int,
@@ -286,11 +289,6 @@ def find_anchors(
        """Find anchor sequences that appear in both transcription and references with timeout protection."""
        start_time = time.time()

-         # Set up timeout signal handler
-         if self.timeout_seconds > 0:
-             old_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
-             signal.alarm(self.timeout_seconds)
-
        try:
            cache_key = self._get_cache_key(transcribed, references, transcription_result)
            cache_path = self.cache_dir / f"anchors_{cache_key}.json"
@@ -316,6 +314,9 @@ def find_anchors(
            self.logger.info(f"Cache miss for key {cache_key} - computing anchors with timeout {self.timeout_seconds}s")
            self.logger.info(f"Finding anchor sequences for transcription with length {len(transcribed)}")

+             # Check timeout before starting computation
+             self._check_timeout(start_time, "anchor computation initialization")
+
            # Get all words from transcription
            all_words = []
            for segment in transcription_result.result.segments:
@@ -329,6 +330,9 @@ def find_anchors(
            }
            ref_words = {source: [w for s in lyrics.segments for w in s.words] for source, lyrics in references.items()}

+             # Check timeout after preprocessing
+             self._check_timeout(start_time, "anchor computation preprocessing")
+
            # Filter out very short reference sources for n-gram length calculation
            valid_ref_lengths = [
                len(words) for words in ref_texts_clean.values()
@@ -355,7 +359,10 @@ def find_anchors(

            # Process n-gram lengths in parallel with timeout
            candidate_anchors = []
-             pool_timeout = max(60, self.timeout_seconds // 2)  # Use half the total timeout for pool operations
+             pool_timeout = max(60, self.timeout_seconds // 2) if self.timeout_seconds > 0 else 300  # Half the total timeout for pool operations, or 5 minutes when no timeout is set
+
+             # Check timeout before parallel processing
+             self._check_timeout(start_time, "parallel processing start")

            try:
                with Pool(processes=max(cpu_count() - 1, 1)) as pool:
@@ -368,65 +375,55 @@ def find_anchors(
                    # Collect results with individual timeouts
                    for i, async_result in enumerate(async_results):
                        try:
-                             # Check remaining time
+                             # Check timeout before each result collection
+                             self._check_timeout(start_time, f"collecting n-gram {n_gram_lengths[i]} results")
+
+                             # Check remaining time for pool timeout
                            elapsed_time = time.time() - start_time
-                             remaining_time = max(10, self.timeout_seconds - elapsed_time)
+                             remaining_time = max(10, self.timeout_seconds - elapsed_time) if self.timeout_seconds > 0 else pool_timeout

                            result = async_result.get(timeout=min(pool_timeout, remaining_time))
                            results.append(result)

                            self.logger.debug(f"Completed n-gram length {n_gram_lengths[i]} ({i + 1}/{len(n_gram_lengths)})")

+                         except AnchorSequenceTimeoutError:
+                             # Re-raise timeout errors
+                             raise
                        except Exception as e:
                            self.logger.warning(f"n-gram length {n_gram_lengths[i]} failed or timed out: {str(e)}")
                            results.append([])  # Add empty result to maintain order

                    for anchors in results:
                        candidate_anchors.extend(anchors)

+             except AnchorSequenceTimeoutError:
+                 # Re-raise timeout errors
+                 raise
            except Exception as e:
                self.logger.error(f"Parallel processing failed: {str(e)}")
-                 # Fall back to sequential processing with strict timeout
+                 # Fall back to sequential processing with timeout checks
                self.logger.info("Falling back to sequential processing")
                for n in n_gram_lengths:
-                     elapsed_time = time.time() - start_time
-                     if elapsed_time >= self.timeout_seconds * 0.8:  # Use 80% of timeout
-                         self.logger.warning(f"Stopping sequential processing due to timeout after {elapsed_time:.1f}s")
-                         break
-
                    try:
+                         # Check timeout before each n-gram length
+                         self._check_timeout(start_time, f"sequential processing n-gram {n}")
+
                        anchors = self._process_ngram_length(
                            n, trans_words, all_words, ref_texts_clean, ref_words, self.min_sources
                        )
                        candidate_anchors.extend(anchors)
+                     except AnchorSequenceTimeoutError:
+                         # Re-raise timeout errors
+                         raise
                    except Exception as e:
                        self.logger.warning(f"Sequential processing failed for n-gram length {n}: {str(e)}")
                        continue

            self.logger.info(f"Found {len(candidate_anchors)} candidate anchors in {time.time() - start_time:.1f}s")

            # Check timeout before expensive filtering operation
-             elapsed_time = time.time() - start_time
-             if elapsed_time >= self.timeout_seconds * 0.9:  # Use 90% of timeout
-                 self.logger.warning(f"Skipping overlap filtering due to timeout ({elapsed_time:.1f}s elapsed)")
-                 # Return basic scored anchors without filtering
-                 basic_scored = []
-                 for anchor in candidate_anchors[:100]:  # Limit to first 100 anchors
-                     try:
-                         phrase_score = PhraseScore(
-                             total_score=1.0,
-                             natural_break_score=1.0,
-                             phrase_type=PhraseType.COMPLETE
-                         )
-                         basic_scored.append(ScoredAnchor(anchor=anchor, phrase_score=phrase_score))
-                     except:
-                         continue
-
-                 # Save basic results to cache
-                 if basic_scored:
-                     self._save_to_cache(cache_path, basic_scored)
-
-                 return basic_scored
+             self._check_timeout(start_time, "overlap filtering start")

            filtered_anchors = self._remove_overlapping_sequences(candidate_anchors, transcribed, transcription_result)

@@ -445,10 +442,8 @@ def find_anchors(
            self.logger.error(f"Anchor sequence computation failed: {str(e)}")
            raise
        finally:
-             # Clean up timeout signal
-             if self.timeout_seconds > 0:
-                 signal.alarm(0)
-                 signal.signal(signal.SIGALRM, old_handler)
+             # No cleanup needed for time-based timeout checks
+             pass

    def _score_sequence(self, words: List[str], context: str) -> PhraseScore:
        """Score a sequence based on its phrase quality"""
@@ -625,15 +620,14 @@ def _remove_overlapping_sequences(

        self.logger.info(f"Filtering {len(scored_anchors)} overlapping sequences")
        filtered_scored = []
-         max_filter_time = 60  # Maximum 1 minute for filtering
-         filter_start = time.time()

        for i, scored_anchor in enumerate(scored_anchors):
-             # Check timeout every 100 anchors
+             # Check timeout every 100 anchors using our timeout mechanism
            if i % 100 == 0:
-                 elapsed = time.time() - filter_start
-                 if elapsed > max_filter_time:
-                     self.logger.warning(f"Filtering timed out after {elapsed:.1f}s, returning {len(filtered_scored)} anchors")
+                 try:
+                     self._check_timeout(start_time, f"filtering anchors (processed {i}/{len(scored_anchors)})")
+                 except AnchorSequenceTimeoutError:
+                     self.logger.warning(f"Filtering timed out, returning {len(filtered_scored)} anchors out of {len(scored_anchors)}")
                    break

            overlaps = False
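
For context on the change above: `signal.SIGALRM` is Unix-only and its handler can only fire on the main thread, which is why the alarm-based timeout is replaced with cooperative elapsed-time checks. A minimal, self-contained sketch of the same pattern follows; every name in it is illustrative rather than taken from `lyrics_transcriber`:

```python
import time


class OperationTimeoutError(Exception):
    """Raised when the elapsed-time budget is exhausted (illustrative name)."""


def check_timeout(start_time: float, timeout_seconds: float, operation_name: str = "operation") -> None:
    """Raise OperationTimeoutError once more than timeout_seconds have elapsed.

    A timeout_seconds of zero or less disables the check, mirroring the
    `self.timeout_seconds > 0` guard in the diff.
    """
    if timeout_seconds > 0:
        elapsed = time.time() - start_time
        if elapsed > timeout_seconds:
            raise OperationTimeoutError(f"{operation_name} exceeded {timeout_seconds}s (elapsed: {elapsed:.1f}s)")


def run_steps(steps, timeout_seconds: float):
    """Run callables in order, re-checking a shared budget between steps."""
    start_time = time.time()
    results = []
    for i, step in enumerate(steps):
        # Cooperative checkpoint: works off the main thread and on Windows,
        # unlike signal.alarm(), but cannot interrupt a step that blocks.
        check_timeout(start_time, timeout_seconds, f"step {i}")
        results.append(step())
    return results
```

The tradeoff is that a cooperative check only fires between checkpoints and cannot interrupt a blocking call, which is why the diff still passes explicit timeouts to `AsyncResult.get()`.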
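That `AsyncResult.get()` budgeting caps each result collection at the smaller of the per-pool timeout and whatever remains of the overall budget, with a 10-second floor. A hedged sketch of the combination, using a stand-in worker rather than anything from the library:

```python
import time
from multiprocessing import Pool, TimeoutError as PoolTimeoutError


def slow_square(n: int) -> int:
    """Stand-in worker that simulates a long-running n-gram pass."""
    time.sleep(n)
    return n * n


def collect_with_budget(inputs, overall_timeout: float, pool_timeout: float = 60.0):
    """Collect async results, bounding each .get() by the remaining overall budget."""
    start_time = time.time()
    results = []
    with Pool(processes=2) as pool:
        async_results = [pool.apply_async(slow_square, (n,)) for n in inputs]
        for async_result in async_results:
            if overall_timeout > 0:
                # At least 10s per result, never more than the remaining budget
                remaining = max(10, overall_timeout - (time.time() - start_time))
            else:
                remaining = pool_timeout
            try:
                results.append(async_result.get(timeout=min(pool_timeout, remaining)))
            except PoolTimeoutError:
                results.append(None)  # placeholder keeps results aligned with inputs
    return results


if __name__ == "__main__":  # guard needed on spawn-based platforms (Windows, macOS)
    print(collect_with_budget([1, 2], overall_timeout=30))
```

As in the diff, a timed-out slot gets a placeholder so results stay positionally aligned with the submitted n-gram lengths.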