Commit 7c4d6c3 (1 parent: a0e5466)

feat(tests): add datasource job monitoring functionality to track all jobs related to a datasource and improve event handling
fix(tests): update job progress logging to handle content events and improve clarity in logs

1 file changed: tests/test_csv_upload.py (+122 lines, -23 lines)
@@ -473,12 +473,33 @@ def subscribe_to_job(client, job_id, callback=None, timeout=300): # noqa: C901
 
     # Define a default callback if none provided
     if callback is None:
+        # Track the previous event type
+        previous_event_type = [
+            None
+        ]  # Using list to allow modification in nested function
 
         def default_callback(event):
             event_type = event.get("event_type", "unknown")
             progress = event.get("progress")
+
+            # Log when event type changes
+            if (
+                previous_event_type[0] is not None
+                and previous_event_type[0] != event_type
+            ):
+                if event_type.endswith("LLMContent"):
+                    logger.info(f"Job event: {event_type}")
+                elif previous_event_type[0].endswith("LLMContent"):
+                    print()
+
+            # Update the previous event type
+            previous_event_type[0] = event_type
+
             if progress is not None:
                 logger.info(f"Job progress: {progress}%")
+            elif event_type.endswith("LLMContent"):
+                if content := event.get("payload", {}).get("content", ""):
+                    print(f"\033[36m{content}\033[0m", end="", flush=True)
             else:
                 logger.info(f"Job event: {event_type}")
 
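Two details in this hunk are worth noting. The \033[36m...\033[0m wrapper in the new streaming branch is the ANSI escape sequence for cyan text (reset afterwards), so LLM content stands out in the terminal, and the content := ... branch uses the walrus operator, so this test now requires Python 3.8 or later. Separately, previous_event_type is wrapped in a one-element list because default_callback closes over it and must mutate it between calls; a plain variable would need a nonlocal declaration. A minimal standalone sketch of the same pattern, with illustrative names not taken from the test file:

def make_callback():
    # A one-element list acts as a mutable cell that the closure can update;
    # "nonlocal previous" on a plain variable would work equally well here.
    previous = [None]

    def callback(event_type):
        if previous[0] is not None and previous[0] != event_type:
            print(f"event type changed: {previous[0]} -> {event_type}")
        previous[0] = event_type

    return callback

cb = make_callback()
cb("JobStarted")
cb("JobLLMContent")  # prints: event type changed: JobStarted -> JobLLMContent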

@@ -617,6 +638,97 @@ def process_events():
     processor_thread.join(timeout=2)
 
 
+def subscribe_to_datasource_jobs(client, datasource_id, callback=None, timeout=300):
+    """
+    Subscribe to all jobs related to a datasource.
+
+    Args:
+        client: InfactoryClient instance
+        datasource_id: ID of the datasource to monitor jobs for
+        callback: Optional callback function to process events
+        timeout: Maximum time to wait in seconds
+
+    Returns:
+        List of job IDs that were monitored
+    """
+    logger = logging.getLogger("datasource_jobs")
+    logger.info(f"Monitoring all jobs for datasource {datasource_id}")
+
+    start_time = time.time()
+    monitored_jobs = set()
+    threads = []
+
+    while True:
+        # Check for timeout
+        if time.time() - start_time > timeout:
+            logger.info(f"Monitoring timeout reached after {timeout} seconds")
+            break
+
+        try:
+            # Query for all jobs with this datasource as source
+            response = client.http_client.get(
+                f"{client.base_url}/v1/jobs/status",
+                params={"source": "datasource", "source_id": datasource_id},
+            )
+
+            if response.status_code != 200:
+                logger.error(
+                    f"Error getting jobs: {response.status_code} {response.text}"
+                )
+                time.sleep(5)
+                continue
+
+            jobs = response.json()
+            if not isinstance(jobs, list):
+                if isinstance(jobs, dict) and "jobs" in jobs:
+                    jobs = jobs["jobs"]
+                else:
+                    logger.warning(f"Unexpected jobs response format: {jobs}")
+                    time.sleep(5)
+                    continue
+
+            # Start monitoring any new jobs we find
+            for job in jobs:
+                job_id = job.get("id")
+                if job_id and job_id not in monitored_jobs:
+                    logger.info(
+                        f"Found new job to monitor: {job_id} (status: {job.get('status')})"
+                    )
+                    monitored_jobs.add(job_id)
+
+                    # Start a thread to monitor this specific job
+                    thread = threading.Thread(
+                        target=subscribe_to_job,
+                        args=(client, job_id, callback),
+                        kwargs={"timeout": timeout - (time.time() - start_time)},
+                    )
+                    thread.daemon = True
+                    thread.start()
+                    threads.append(thread)
+
+            # Check if all jobs are completed
+            all_completed = True
+            for job in jobs:
+                if job.get("status") not in ["completed", "failed", "error"]:
+                    all_completed = False
+                    break
+
+            if jobs and all_completed:
+                logger.info("All jobs for datasource have completed")
+                break
+
+        except Exception as e:
+            logger.error(f"Error monitoring datasource jobs: {e}")
+
+        time.sleep(5)
+
+    # Wait for all monitoring threads to finish
+    for thread in threads:
+        thread.join(timeout=1)
+
+    return list(monitored_jobs)
+
+
 def main():  # noqa: C901
     # Initialize client
     print_step(1, "Initialize client and authenticate")
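A rough usage sketch for the new helper, assuming an already-authenticated InfactoryClient like the one built in main(); the api_key argument, the "ds-example-id" value, and the custom callback are illustrative assumptions, not taken from this diff:

# Hypothetical usage; the InfactoryClient constructor arguments are assumed.
client = InfactoryClient(api_key="nf-...")

def log_event(event):
    # Events carry at least "event_type" and sometimes "progress",
    # matching what default_callback reads above.
    print(event.get("event_type", "unknown"), event.get("progress"))

job_ids = subscribe_to_datasource_jobs(
    client, "ds-example-id", callback=log_event, timeout=120
)
print(f"Watched {len(job_ids)} job(s): {', '.join(job_ids)}")

Note the design choice in the kwargs: each per-job thread receives timeout - (time.time() - start_time), so jobs discovered late in the window only get the remaining time, and the whole watch stays bounded by the outer timeout instead of being extended by every newly found job.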
@@ -739,36 +851,23 @@ def main(): # noqa: C901
     print("File upload request sent successfully!")
 
     # Step 5: Monitor job progress
-    print_step(5, "Monitor job progress")
+    print_step(5, "Monitor all jobs for datasource")
 
-    # Try to subscribe to streaming updates first
-    print("Attempting to subscribe to streaming job updates...")
-    subscription_thread = threading.Thread(
-        target=subscribe_to_job, args=(client, job_id), kwargs={"timeout": 300}
+    # Subscribe to all jobs for this datasource instead of just one job
+    print(f"Monitoring all jobs for datasource {datasource.id}...")
+    monitored_jobs = subscribe_to_datasource_jobs(client, datasource.id, timeout=300)
+    print(
+        f"Monitored {len(monitored_jobs)} jobs for datasource: {', '.join(monitored_jobs)}"
     )
-    subscription_thread.daemon = True
-    subscription_thread.start()
 
-    # Also start monitoring with polling as a fallback
-    print("Starting event monitoring thread...")
-    event_monitoring_thread = threading.Thread(
-        target=monitor_job_events, args=(client, job_id, 300, True)
-    )
-    event_monitoring_thread.daemon = True
-    event_monitoring_thread.start()
-
-    # Monitor job status directly
+    # We can still check our initial job as well
     job_success, job_status = wait_for_job_completion(
-        client, job_id, timeout=300, poll_interval=2  # 5 minutes timeout
+        client, job_id, timeout=300, poll_interval=2
    )
 
-    # Wait for monitoring threads to finish
-    subscription_thread.join(timeout=2)
-    event_monitoring_thread.join(timeout=2)
-
     if not job_success:
-        print(f"Job failed with status: {job_status}")
-        # sys.exit(1)
+        print(f"Initial job failed with status: {job_status}")
+        # Do not exit, we'll continue to see what other jobs did
 
     # Step 6: Wait for and list datalines
     print_step(6, "Wait for and list datalines")
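The hunk above keeps the direct wait_for_job_completion check on the initial job. That helper is defined earlier in this file and is not shown in this diff; for readers without the full file, a polling helper of that shape might look roughly like the sketch below (the endpoint path, the job_id query parameter, and the status values are assumptions extrapolated from subscribe_to_datasource_jobs above, not the file's actual implementation):

def wait_for_job_completion(client, job_id, timeout=300, poll_interval=2):
    # Poll until the job reaches a terminal status or the timeout expires.
    # Endpoint and field names mirror the ones used above; treat them as
    # assumptions rather than the file's real code.
    deadline = time.time() + timeout
    while time.time() < deadline:
        response = client.http_client.get(
            f"{client.base_url}/v1/jobs/status", params={"job_id": job_id}
        )
        if response.status_code == 200:
            status = response.json().get("status")
            if status in ("completed", "failed", "error"):
                return status == "completed", status
        time.sleep(poll_interval)
    return False, "timeout"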
