From 4d8c9b8d8516b340d7d248d11c6f9124ff77badb Mon Sep 17 00:00:00 2001 From: MoeexT Date: Mon, 8 Jun 2026 12:19:13 +0800 Subject: [PATCH] feat: task status polling with runtime unreachable detection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - backend-python: global polling coroutine polls all RUNNING tasks every 2s - backend-python: mark task FAILED when runtime unreachable (httpx 60s timeout) - backend-python: singleton scheduler + startup() recovers RUNNING tasks on restart - runtime: Ray job connection failure counter (5 fails → FAILED) - runtime: stall detection (3600s no log progress → FAILED) --- runtime/datamate-python/app/main.py | 6 + .../interface/cleaning_task_routes.py | 7 +- .../service/cleaning_task_scheduler.py | 117 ++++++++++++------ .../datamate/scheduler/job_task_scheduler.py | 10 +- 4 files changed, 95 insertions(+), 45 deletions(-) diff --git a/runtime/datamate-python/app/main.py b/runtime/datamate-python/app/main.py index 1d837677..bc48dbef 100644 --- a/runtime/datamate-python/app/main.py +++ b/runtime/datamate-python/app/main.py @@ -73,6 +73,12 @@ def mask_db_url(url: str) -> Literal[b""] | str: init_executor(max_workers=10, max_concurrent_tasks=5) logger.info("Generation task executor initialized") + # 恢复未完成任务的轮询 + 启动全局状态轮询协程 + from app.module.cleaning.service.cleaning_task_scheduler import get_scheduler + scheduler = get_scheduler() + await scheduler.startup() + logger.info("Cleaning task status polling started") + yield # @shutdown diff --git a/runtime/datamate-python/app/module/cleaning/interface/cleaning_task_routes.py b/runtime/datamate-python/app/module/cleaning/interface/cleaning_task_routes.py index 5ba3a719..8e4cef92 100644 --- a/runtime/datamate-python/app/module/cleaning/interface/cleaning_task_routes.py +++ b/runtime/datamate-python/app/module/cleaning/interface/cleaning_task_routes.py @@ -47,19 +47,16 @@ def _get_task_service(db: AsyncSession) -> CleaningTaskService: 
CleaningTaskScheduler, CleaningTaskValidator, ) + from app.module.cleaning.service.cleaning_task_scheduler import get_scheduler from app.module.cleaning.repository import ( CleaningTaskRepository, CleaningResultRepository, OperatorInstanceRepository, ) - from app.module.cleaning.runtime_client import RuntimeClient from app.module.dataset.service import DatasetManagementService from app.module.shared.common.lineage import LineageService - runtime_client = RuntimeClient() - scheduler = CleaningTaskScheduler( - task_repo=CleaningTaskRepository(None), runtime_client=runtime_client - ) + scheduler = get_scheduler() # 使用全局单例,保持轮询状态跨请求共享 operator_service = _get_operator_service() dataset_service = DatasetManagementService(db) lineage_service = LineageService(db) diff --git a/runtime/datamate-python/app/module/cleaning/service/cleaning_task_scheduler.py b/runtime/datamate-python/app/module/cleaning/service/cleaning_task_scheduler.py index 779dd910..67a2baaf 100644 --- a/runtime/datamate-python/app/module/cleaning/service/cleaning_task_scheduler.py +++ b/runtime/datamate-python/app/module/cleaning/service/cleaning_task_scheduler.py @@ -6,6 +6,20 @@ logger = get_logger(__name__) +# 模块级单例:进程内所有请求共享同一个调度器 +_scheduler: "CleaningTaskScheduler | None" = None + + +def get_scheduler() -> "CleaningTaskScheduler": + """获取全局调度器单例,不存在则创建""" + global _scheduler + if _scheduler is None: + _scheduler = CleaningTaskScheduler( + CleaningTaskRepository(None), + RuntimeClient() + ) + return _scheduler + class CleaningTaskScheduler: """Scheduler for executing cleaning tasks""" @@ -13,7 +27,10 @@ class CleaningTaskScheduler: def __init__(self, task_repo: CleaningTaskRepository, runtime_client: RuntimeClient): self.task_repo = task_repo self.runtime_client = runtime_client - self._polling_tasks: dict[str, asyncio.Task] = {} + self._polling_task_ids: set[str] = set() # 待轮询的任务 ID 集合 + self._polling_started: bool = False + self._poll_failure_count: dict[str, int] = {} # runtime 不可达连续失败计数 + 
self._MAX_POLL_FAILURES = 1 # 一次超时(60s) → 标记 FAILED async def execute_task(self, db: AsyncSession, task_id: str, retry_count: int) -> bool: """Execute cleaning task""" @@ -30,8 +47,7 @@ async def execute_task(self, db: AsyncSession, task_id: str, retry_count: int) - submitted = await self.runtime_client.submit_task(task_id, retry_count) if submitted: - # Start background polling to sync task status from runtime - self._start_status_polling(task_id) + self._polling_task_ids.add(task_id) return submitted @@ -39,10 +55,8 @@ async def stop_task(self, db: AsyncSession, task_id: str) -> bool: """Stop cleaning task""" from app.module.cleaning.schema import CleaningTaskDto, CleaningTaskStatus - # Cancel background polling - if task_id in self._polling_tasks: - self._polling_tasks[task_id].cancel() - del self._polling_tasks[task_id] + self._polling_task_ids.discard(task_id) + self._poll_failure_count.pop(task_id, None) await self.runtime_client.stop_task(task_id) @@ -53,30 +67,72 @@ async def stop_task(self, db: AsyncSession, task_id: str) -> bool: await self.task_repo.update_task(db, task) return True - def _start_status_polling(self, task_id: str): - """Start background task to poll runtime for task status""" - - async def _poll_loop(): - from app.module.cleaning.schema import CleaningTaskDto, CleaningTaskStatus - from app.db.session import AsyncSessionLocal - from datetime import datetime + async def startup(self): + """进程启动时调用:恢复未完成任务的轮询 + 启动全局轮询协程""" + from app.module.cleaning.schema import CleaningTaskStatus + from app.db.session import AsyncSessionLocal + + # 从数据库恢复所有 RUNNING 状态的任务 + try: + async with AsyncSessionLocal() as db: + tasks = await self.task_repo.find_tasks(db, status=CleaningTaskStatus.RUNNING) + for task in tasks: + logger.info(f"[Polling] Recovered RUNNING task from DB: {task.id}") + self._polling_task_ids.add(task.id) + except Exception as e: + logger.error(f"[Polling] Failed to recover tasks from DB: {e}") + + # 启动全局轮询协程 + if not 
self._polling_started: + self._polling_started = True + asyncio.create_task(self._poll_all_tasks()) + logger.info("[Polling] Global polling loop started") + + async def _poll_all_tasks(self): + """全局轮询协程:每 2 秒轮询所有 RUNNING 任务的 runtime 状态""" + from app.module.cleaning.schema import CleaningTaskDto, CleaningTaskStatus + from app.db.session import AsyncSessionLocal + from datetime import datetime - logger.info(f"[Polling] Starting status polling for task {task_id}") - await asyncio.sleep(5) + logger.info("[Polling] Global status polling loop started") + terminal_statuses = {"completed", "failed", "cancelled", "stopped"} - terminal_statuses = {"completed", "failed", "cancelled", "stopped"} - max_polls = 1800 # Max 1 hour (2s interval) - poll_count = 0 + while True: + task_ids = list(self._polling_task_ids) + if not task_ids: + # 没有待轮询任务,休眠后继续等 + await asyncio.sleep(2) + continue - while poll_count < max_polls: + for task_id in task_ids: try: status_data = await self.runtime_client.get_task_status(task_id) - if status_data is None: - poll_count += 1 - await asyncio.sleep(2) + # runtime 不可达,累计失败计数 + count = self._poll_failure_count.get(task_id, 0) + 1 + self._poll_failure_count[task_id] = count + logger.warning( + f"[Polling] Task {task_id} unreachable ({count}/{self._MAX_POLL_FAILURES})" + ) + if count >= self._MAX_POLL_FAILURES: + logger.error( + f"[Polling] Task {task_id} marked FAILED: " + f"runtime unreachable after {count} attempts" + ) + async with AsyncSessionLocal() as db: + task_dto = CleaningTaskDto() + task_dto.id = task_id + task_dto.status = CleaningTaskStatus.FAILED + task_dto.finished_at = datetime.now() + await self.task_repo.update_task(db, task_dto) + await db.commit() + self._polling_task_ids.discard(task_id) + self._poll_failure_count.pop(task_id, None) continue + # runtime 可达,重置失败计数 + self._poll_failure_count.pop(task_id, None) + current_status = (status_data.get("status", "") or "").lower() logger.debug(f"[Polling] Task {task_id} status: 
{current_status}") @@ -94,19 +150,10 @@ async def _poll_loop(): logger.info( f"[Polling] Task {task_id} finished: {current_status}" ) - break + self._polling_task_ids.discard(task_id) + self._poll_failure_count.pop(task_id, None) - except asyncio.CancelledError: - break except Exception as e: logger.error(f"[Polling] Error polling task {task_id}: {e}") - poll_count += 1 - await asyncio.sleep(2) - else: - logger.warning(f"[Polling] Task {task_id} timed out") - - self._polling_tasks.pop(task_id, None) - - task = asyncio.create_task(_poll_loop()) - self._polling_tasks[task_id] = task + await asyncio.sleep(2) diff --git a/runtime/python-executor/datamate/scheduler/job_task_scheduler.py b/runtime/python-executor/datamate/scheduler/job_task_scheduler.py index 29ceb83d..3a266395 100644 --- a/runtime/python-executor/datamate/scheduler/job_task_scheduler.py +++ b/runtime/python-executor/datamate/scheduler/job_task_scheduler.py @@ -75,10 +75,10 @@ async def _execute(self): last_log_position = 0 # 记录已写入的日志位置 connection_failure_count = 0 # 连接失败计数器 max_connection_failures = 5 # 最大连接失败次数阈值 - + # 任务无进展超时:如果 job 一直是 RUNNING 但日志/进度没有任何变化, # 说明 workers 可能已全部丢失,判定为失败 - stall_timeout = int(os.getenv("RAY_JOB_STALL_TIMEOUT", "120")) # 默认 120 秒 + stall_timeout = int(os.getenv("RAY_JOB_STALL_TIMEOUT", "3600")) # 默认 3600 秒 last_log_size = 0 last_active_time = datetime.now() @@ -94,7 +94,7 @@ async def _execute(self): try: info = client.get_job_info(self.job_id) job_status = info.status - + # 连接成功,重置失败计数器 connection_failure_count = 0 @@ -151,13 +151,13 @@ async def _execute(self): logger.error( f"Connection to Ray cluster failed (attempt {connection_failure_count}/{max_connection_failures}): {e}" ) - + if connection_failure_count >= max_connection_failures: self.status = TaskStatus.FAILED self.error = f"Lost connection to Ray cluster after {max_connection_failures} attempts: {str(e)}" logger.error(f"Task {self.task_id} failed: {self.error}") break - + except Exception as e: # 其他异常:记录警告但继续重试 
logger.warning(f"Error checking job status: {e}")