Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions runtime/datamate-python/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ def mask_db_url(url: str) -> Literal[b""] | str:
init_executor(max_workers=10, max_concurrent_tasks=5)
logger.info("Generation task executor initialized")

# 恢复未完成任务的轮询 + 启动全局状态轮询协程
from app.module.cleaning.service.cleaning_task_scheduler import get_scheduler
scheduler = get_scheduler()
await scheduler.startup()
logger.info("Cleaning task status polling started")

yield

# @shutdown
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,16 @@ def _get_task_service(db: AsyncSession) -> CleaningTaskService:
CleaningTaskScheduler,
CleaningTaskValidator,
)
from app.module.cleaning.service.cleaning_task_scheduler import get_scheduler
from app.module.cleaning.repository import (
CleaningTaskRepository,
CleaningResultRepository,
OperatorInstanceRepository,
)
from app.module.cleaning.runtime_client import RuntimeClient
from app.module.dataset.service import DatasetManagementService
from app.module.shared.common.lineage import LineageService

runtime_client = RuntimeClient()
scheduler = CleaningTaskScheduler(
task_repo=CleaningTaskRepository(None), runtime_client=runtime_client
)
scheduler = get_scheduler() # 使用全局单例,保持轮询状态跨请求共享
operator_service = _get_operator_service()
dataset_service = DatasetManagementService(db)
lineage_service = LineageService(db)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,31 @@

logger = get_logger(__name__)

# 模块级单例:进程内所有请求共享同一个调度器
_scheduler: "CleaningTaskScheduler | None" = None


def get_scheduler() -> "CleaningTaskScheduler":
"""获取全局调度器单例,不存在则创建"""
global _scheduler
if _scheduler is None:
_scheduler = CleaningTaskScheduler(
CleaningTaskRepository(None),
RuntimeClient()
)
return _scheduler


class CleaningTaskScheduler:
"""Scheduler for executing cleaning tasks"""

def __init__(self, task_repo: CleaningTaskRepository, runtime_client: RuntimeClient):
self.task_repo = task_repo
self.runtime_client = runtime_client
self._polling_tasks: dict[str, asyncio.Task] = {}
self._polling_task_ids: set[str] = set() # 待轮询的任务 ID 集合
self._polling_started: bool = False
self._poll_failure_count: dict[str, int] = {} # runtime 不可达连续失败计数
self._MAX_POLL_FAILURES = 1 # 一次超时(60s) → 标记 FAILED

async def execute_task(self, db: AsyncSession, task_id: str, retry_count: int) -> bool:
"""Execute cleaning task"""
Expand All @@ -30,19 +47,16 @@ async def execute_task(self, db: AsyncSession, task_id: str, retry_count: int) -
submitted = await self.runtime_client.submit_task(task_id, retry_count)

if submitted:
# Start background polling to sync task status from runtime
self._start_status_polling(task_id)
self._polling_task_ids.add(task_id)

return submitted

async def stop_task(self, db: AsyncSession, task_id: str) -> bool:
"""Stop cleaning task"""
from app.module.cleaning.schema import CleaningTaskDto, CleaningTaskStatus

# Cancel background polling
if task_id in self._polling_tasks:
self._polling_tasks[task_id].cancel()
del self._polling_tasks[task_id]
self._polling_task_ids.discard(task_id)
self._poll_failure_count.pop(task_id, None)

await self.runtime_client.stop_task(task_id)

Expand All @@ -53,30 +67,72 @@ async def stop_task(self, db: AsyncSession, task_id: str) -> bool:
await self.task_repo.update_task(db, task)
return True

def _start_status_polling(self, task_id: str):
"""Start background task to poll runtime for task status"""

async def _poll_loop():
from app.module.cleaning.schema import CleaningTaskDto, CleaningTaskStatus
from app.db.session import AsyncSessionLocal
from datetime import datetime
async def startup(self):
"""进程启动时调用:恢复未完成任务的轮询 + 启动全局轮询协程"""
from app.module.cleaning.schema import CleaningTaskStatus
from app.db.session import AsyncSessionLocal

# 从数据库恢复所有 RUNNING 状态的任务
try:
async with AsyncSessionLocal() as db:
tasks = await self.task_repo.find_tasks(db, status=CleaningTaskStatus.RUNNING)
for task in tasks:
logger.info(f"[Polling] Recovered RUNNING task from DB: {task.id}")
self._polling_task_ids.add(task.id)
except Exception as e:
logger.error(f"[Polling] Failed to recover tasks from DB: {e}")

# 启动全局轮询协程
if not self._polling_started:
self._polling_started = True
asyncio.create_task(self._poll_all_tasks())
logger.info("[Polling] Global polling loop started")

async def _poll_all_tasks(self):
"""全局轮询协程:每 2 秒轮询所有 RUNNING 任务的 runtime 状态"""
from app.module.cleaning.schema import CleaningTaskDto, CleaningTaskStatus
from app.db.session import AsyncSessionLocal
from datetime import datetime

logger.info(f"[Polling] Starting status polling for task {task_id}")
await asyncio.sleep(5)
logger.info("[Polling] Global status polling loop started")
terminal_statuses = {"completed", "failed", "cancelled", "stopped"}

terminal_statuses = {"completed", "failed", "cancelled", "stopped"}
max_polls = 1800 # Max 1 hour (2s interval)
poll_count = 0
while True:
task_ids = list(self._polling_task_ids)
if not task_ids:
# 没有待轮询任务,休眠后继续等
await asyncio.sleep(2)
continue

while poll_count < max_polls:
for task_id in task_ids:
try:
status_data = await self.runtime_client.get_task_status(task_id)

if status_data is None:
poll_count += 1
await asyncio.sleep(2)
# runtime 不可达,累计失败计数
count = self._poll_failure_count.get(task_id, 0) + 1
self._poll_failure_count[task_id] = count
logger.warning(
f"[Polling] Task {task_id} unreachable ({count}/{self._MAX_POLL_FAILURES})"
)
if count >= self._MAX_POLL_FAILURES:
logger.error(
f"[Polling] Task {task_id} marked FAILED: "
f"runtime unreachable after {count} attempts"
)
async with AsyncSessionLocal() as db:
task_dto = CleaningTaskDto()
task_dto.id = task_id
task_dto.status = CleaningTaskStatus.FAILED
task_dto.finished_at = datetime.now()
await self.task_repo.update_task(db, task_dto)
await db.commit()
self._polling_task_ids.discard(task_id)
self._poll_failure_count.pop(task_id, None)
continue

# runtime 可达,重置失败计数
self._poll_failure_count.pop(task_id, None)

current_status = (status_data.get("status", "") or "").lower()
logger.debug(f"[Polling] Task {task_id} status: {current_status}")

Expand All @@ -94,19 +150,10 @@ async def _poll_loop():
logger.info(
f"[Polling] Task {task_id} finished: {current_status}"
)
break
self._polling_task_ids.discard(task_id)
self._poll_failure_count.pop(task_id, None)

except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"[Polling] Error polling task {task_id}: {e}")

poll_count += 1
await asyncio.sleep(2)
else:
logger.warning(f"[Polling] Task {task_id} timed out")

self._polling_tasks.pop(task_id, None)

task = asyncio.create_task(_poll_loop())
self._polling_tasks[task_id] = task
await asyncio.sleep(2)
10 changes: 5 additions & 5 deletions runtime/python-executor/datamate/scheduler/job_task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ async def _execute(self):
last_log_position = 0 # 记录已写入的日志位置
connection_failure_count = 0 # 连接失败计数器
max_connection_failures = 5 # 最大连接失败次数阈值

# 任务无进展超时:如果 job 一直是 RUNNING 但日志/进度没有任何变化,
# 说明 workers 可能已全部丢失,判定为失败
stall_timeout = int(os.getenv("RAY_JOB_STALL_TIMEOUT", "120")) # 默认 120 秒
stall_timeout = int(os.getenv("RAY_JOB_STALL_TIMEOUT", "3600")) # 默认 120 秒
last_log_size = 0
last_active_time = datetime.now()

Expand All @@ -94,7 +94,7 @@ async def _execute(self):
try:
info = client.get_job_info(self.job_id)
job_status = info.status

# 连接成功,重置失败计数器
connection_failure_count = 0

Expand Down Expand Up @@ -151,13 +151,13 @@ async def _execute(self):
logger.error(
f"Connection to Ray cluster failed (attempt {connection_failure_count}/{max_connection_failures}): {e}"
)

if connection_failure_count >= max_connection_failures:
self.status = TaskStatus.FAILED
self.error = f"Lost connection to Ray cluster after {max_connection_failures} attempts: {str(e)}"
logger.error(f"Task {self.task_id} failed: {self.error}")
break

except Exception as e:
# 其他异常:记录警告但继续重试
logger.warning(f"Error checking job status: {e}")
Expand Down
Loading