From 62235d81ae68de107436d3f5d59b87b5161440f9 Mon Sep 17 00:00:00 2001 From: Rolled Date: Sun, 14 Jun 2026 07:05:19 -0700 Subject: [PATCH] fix: close leaked SQLAlchemy sessions in process_watcher_task (#5611) Co-authored-by: Shahar Glazner --- keep/api/bl/dismissal_expiry_bl.py | 95 +++++----- keep/api/bl/maintenance_windows_bl.py | 251 +++++++++++++------------- tests/test_session_cleanup.py | 140 ++++++++++++++ 3 files changed, 317 insertions(+), 169 deletions(-) create mode 100644 tests/test_session_cleanup.py diff --git a/keep/api/bl/dismissal_expiry_bl.py b/keep/api/bl/dismissal_expiry_bl.py index 8e1aa8892d..c7d18a9a66 100644 --- a/keep/api/bl/dismissal_expiry_bl.py +++ b/keep/api/bl/dismissal_expiry_bl.py @@ -12,7 +12,7 @@ from sqlmodel import Session, select from keep.api.core.db import get_session_sync from keep.api.core.db_utils import get_json_extract_field -from keep.api.core.elastic import ElasticClient +from keep.api.core.elastic import ElasticClient from keep.api.core.dependencies import get_pusher_client from keep.api.models.action_type import ActionType from keep.api.models.alert import AlertDto @@ -20,33 +20,33 @@ class DismissalExpiryBl: - + @staticmethod def get_alerts_with_expired_dismissals(session: Session) -> List[AlertEnrichment]: """ Get all AlertEnrichment records that have expired dismissedUntil timestamps. - + Returns enrichment records where: - 1. dismissed = true + 1. dismissed = true 2. dismissedUntil is not null and not "forever" 3. dismissedUntil timestamp is in the past - + Args: session: Database session - + Returns: List of AlertEnrichment objects with expired dismissals """ logger = logging.getLogger(__name__) now = datetime.datetime.now(datetime.timezone.utc) - + logger.info("Searching for enrichments with expired dismissals") - + # Query for enrichments with dismissed=true and dismissedUntil set # Use the proper helper function for cross-database compatibility dismissed_field = get_json_extract_field(session, AlertEnrichment.enrichments, "dismissed") dismissed_until_field = get_json_extract_field(session, AlertEnrichment.enrichments, "dismissUntil") - + # Build cross-database compatible boolean comparison # Different databases store/extract JSON booleans differently: # - SQLite: json_extract can return 1/0 for true/false OR "True"/"False"/"true"/"false" strings depending on how data was stored @@ -61,7 +61,7 @@ def get_alerts_with_expired_dismissals(session: Session) -> List[AlertEnrichment else: # For MySQL, compare with lowercase string "true" dismissed_condition = dismissed_field == "true" - + query = session.exec( select(AlertEnrichment).where( dismissed_condition, @@ -71,24 +71,24 @@ def get_alerts_with_expired_dismissals(session: Session) -> List[AlertEnrichment dismissed_until_field != "forever", ) ) - + candidate_enrichments = query.all() - + logger.info(f"Found {len(candidate_enrichments)} candidate enrichments with dismissals") - + # Filter in Python for safety and clarity (parsing ISO timestamps) expired_enrichments = [] for enrichment in candidate_enrichments: dismiss_until_str = enrichment.enrichments.get("dismissUntil") if not dismiss_until_str or dismiss_until_str == "forever": continue - + try: - # Parse the dismissedUntil timestamp + # Parse the dismissedUntil timestamp dismiss_until = datetime.datetime.strptime( dismiss_until_str, "%Y-%m-%dT%H:%M:%S.%fZ" ).replace(tzinfo=datetime.timezone.utc) - + # Check if it's expired (current time > dismissedUntil) if now > dismiss_until: logger.info( @@ -101,54 +101,55 @@ def get_alerts_with_expired_dismissals(session: Session) -> List[AlertEnrichment } ) expired_enrichments.append(enrichment) - + except (ValueError, TypeError) as e: # Log invalid timestamp but don't fail logger.warning( f"Invalid dismissedUntil timestamp for fingerprint {enrichment.alert_fingerprint}: {dismiss_until_str}", extra={ - "tenant_id": enrichment.tenant_id, + "tenant_id": enrichment.tenant_id, "fingerprint": enrichment.alert_fingerprint, "error": str(e) } ) continue - + logger.info(f"Found {len(expired_enrichments)} enrichments with expired dismissals") return expired_enrichments - + @staticmethod def check_dismissal_expiry(logger: logging.Logger, session: Optional[Session] = None): """ Check for alerts with expired dismissedUntil and restore them. - + This function: 1. Finds AlertEnrichment records with expired dismissedUntil timestamps 2. Updates their enrichments to set dismissed=false and dismissedUntil=null - 3. Cleans up disposable fields + 3. Cleans up disposable fields 4. Updates Elasticsearch indexes 5. Notifies UI of changes 6. Adds audit trail - + Args: logger: Logger instance for detailed logging session: Optional database session (creates new if None) """ logger.info("Starting dismissal expiry check") - + + _owns_session = session is None if session is None: session = get_session_sync() - + try: # Find enrichments with expired dismissedUntil expired_enrichments = DismissalExpiryBl.get_alerts_with_expired_dismissals(session) - + if not expired_enrichments: logger.info("No enrichments with expired dismissals found") return - + logger.info(f"Processing {len(expired_enrichments)} expired dismissal enrichments") - + # Process each expired enrichment for enrichment in expired_enrichments: logger.info( @@ -159,16 +160,16 @@ def check_dismissal_expiry(logger: logging.Logger, session: Optional[Session] = "dismissed_until": enrichment.enrichments.get("dismissedUntil") } ) - + # Store original values for audit original_dismissed = enrichment.enrichments.get("dismissed", False) original_dismissed_until = enrichment.enrichments.get("dismissedUntil") - + # Update enrichment - set back to not dismissed new_enrichments = enrichment.enrichments.copy() new_enrichments["dismissed"] = False new_enrichments["dismissUntil"] = None # Clear the original field - + # Reset status if it was set to suppressed during dismissal enrichment_status = enrichment.enrichments.get("status") if enrichment_status == "suppressed": @@ -183,7 +184,7 @@ def check_dismissal_expiry(logger: logging.Logger, session: Optional[Session] = "removed_status": enrichment_status } ) - + # Clean up ALL disposable fields (use pattern matching instead of hardcoded list) cleaned_fields = [] keys_to_remove = [] @@ -191,11 +192,11 @@ def check_dismissal_expiry(logger: logging.Logger, session: Optional[Session] = if field_name.startswith("disposable_"): keys_to_remove.append(field_name) cleaned_fields.append(field_name) - + # Remove the disposable fields for field_name in keys_to_remove: new_enrichments.pop(field_name) - + if cleaned_fields: logger.info( f"Cleaned up disposable fields: {cleaned_fields}", @@ -204,11 +205,11 @@ def check_dismissal_expiry(logger: logging.Logger, session: Optional[Session] = "fingerprint": enrichment.alert_fingerprint } ) - + # Update the enrichment record enrichment.enrichments = new_enrichments session.add(enrichment) - + # Add audit trail try: audit = AlertAudit( @@ -237,7 +238,7 @@ def check_dismissal_expiry(logger: logging.Logger, session: Optional[Session] = "fingerprint": enrichment.alert_fingerprint } ) - + # Update Elasticsearch index try: # Get the latest alert for this fingerprint to create AlertDto @@ -248,11 +249,11 @@ def check_dismissal_expiry(logger: logging.Logger, session: Optional[Session] = .order_by(Alert.timestamp.desc()) .limit(1) ).first() - + if latest_alert: # Create AlertDto with updated enrichments alert_data = latest_alert.event.copy() - + # Only update specific enrichment fields, don't override alert event data with None values enrichment_fields = ['dismissed', 'dismissUntil', 'note', 'assignee', 'status'] for field in enrichment_fields: @@ -261,9 +262,9 @@ def check_dismissal_expiry(logger: logging.Logger, session: Optional[Session] = elif field in new_enrichments and new_enrichments[field] is None and field in ['dismissed', 'dismissUntil']: # For dismissal fields, None is a valid value (means not dismissed) alert_data[field] = new_enrichments[field] - + alert_dto = AlertDto(**alert_data) - + elastic_client = ElasticClient(enrichment.tenant_id) elastic_client.index_alert(alert_dto) logger.info( @@ -281,7 +282,7 @@ def check_dismissal_expiry(logger: logging.Logger, session: Optional[Session] = "fingerprint": enrichment.alert_fingerprint } ) - + except Exception as e: logger.error( f"Failed to update Elasticsearch for fingerprint {enrichment.alert_fingerprint}: {e}", @@ -290,7 +291,7 @@ def check_dismissal_expiry(logger: logging.Logger, session: Optional[Session] = "fingerprint": enrichment.alert_fingerprint } ) - + # Notify UI of change try: pusher_client = get_pusher_client() @@ -299,7 +300,7 @@ def check_dismissal_expiry(logger: logging.Logger, session: Optional[Session] = f"private-{enrichment.tenant_id}", "alert-update", { - "fingerprint": enrichment.alert_fingerprint, + "fingerprint": enrichment.alert_fingerprint, "action": "dismissal_expired" } ) @@ -318,17 +319,19 @@ def check_dismissal_expiry(logger: logging.Logger, session: Optional[Session] = "fingerprint": enrichment.alert_fingerprint } ) - + # Commit all changes session.commit() logger.info( f"Successfully processed {len(expired_enrichments)} expired dismissal enrichments", extra={"processed_count": len(expired_enrichments)} ) - + except Exception as e: logger.error(f"Error during dismissal expiry check: {e}", exc_info=True) session.rollback() raise finally: + if _owns_session: + session.close() logger.info("Dismissal expiry check completed") diff --git a/keep/api/bl/maintenance_windows_bl.py b/keep/api/bl/maintenance_windows_bl.py index dd3b9fb255..5f71ae6158 100644 --- a/keep/api/bl/maintenance_windows_bl.py +++ b/keep/api/bl/maintenance_windows_bl.py @@ -179,145 +179,150 @@ def recover_strategy( """ logger.info("Starting recover strategy for maintenance windows review.") env = celpy.Environment() + _owns_session = session is None if session is None: session = get_session_sync() - windows = get_maintenance_windows_started(session) - alerts_in_maint = get_alerts_by_status(AlertStatus.MAINTENANCE, session) - fingerprints_to_check: set = set() - for alert in alerts_in_maint: - active = False - for window in windows: - w_start = window.start_time - w_end = window.end_time - is_enable = window.enabled - if window.tenant_id != alert.tenant_id: - continue - # Check active windows - if ( - w_start < alert.timestamp - and alert.timestamp < w_end - and w_end > datetime.datetime.utcnow() - and is_enable - ): - logger.info("Checking alert %s in maintenance window %s", alert.id, window.id) - is_in_cel = MaintenanceWindowsBl.evaluate_cel( - window, alert, env, logger, {"tenant_id": alert.tenant_id, "alert_id": alert.id} + try: + windows = get_maintenance_windows_started(session) + alerts_in_maint = get_alerts_by_status(AlertStatus.MAINTENANCE, session) + fingerprints_to_check: set = set() + for alert in alerts_in_maint: + active = False + for window in windows: + w_start = window.start_time + w_end = window.end_time + is_enable = window.enabled + if window.tenant_id != alert.tenant_id: + continue + # Check active windows + if ( + w_start < alert.timestamp + and alert.timestamp < w_end + and w_end > datetime.datetime.utcnow() + and is_enable + ): + logger.info("Checking alert %s in maintenance window %s", alert.id, window.id) + is_in_cel = MaintenanceWindowsBl.evaluate_cel( + window, alert, env, logger, {"tenant_id": alert.tenant_id, "alert_id": alert.id} + ) + # Recover source structure + if not isinstance(alert.event.get("source"), list): + alert.event["source"] = [alert.event["source"]] + if is_in_cel: + active = True + set_maintenance_windows_trace(alert, window, session) + logger.info("Alert %s is blocked due to the maintenance window: %s.", alert.id, window.id) + break + if not active: + recover_prev_alert_status(alert, session) + fingerprints_to_check.add((alert.tenant_id, alert.fingerprint)) + add_audit( + tenant_id=alert.tenant_id, + fingerprint=alert.fingerprint, + user_id="system", + action=ActionType.MAINTENANCE_EXPIRED, + description=( + f"Alert {alert.id} has recover its previous status, " + f"from {alert.event.get('previous_status')} to {alert.event.get('status')}" + ), ) - # Recover source structure - if not isinstance(alert.event.get("source"), list): - alert.event["source"] = [alert.event["source"]] - if is_in_cel: - active = True - set_maintenance_windows_trace(alert, window, session) - logger.info("Alert %s is blocked due to the maintenance window: %s.", alert.id, window.id) - break - if not active: - recover_prev_alert_status(alert, session) - fingerprints_to_check.add((alert.tenant_id, alert.fingerprint)) - add_audit( - tenant_id=alert.tenant_id, - fingerprint=alert.fingerprint, - user_id="system", - action=ActionType.MAINTENANCE_EXPIRED, - description=( - f"Alert {alert.id} has recover its previous status, " - f"from {alert.event.get('previous_status')} to {alert.event.get('status')}" - ), - ) - for (tenant, fp) in fingerprints_to_check: - last_alert = get_last_alert_by_fingerprint(tenant, fp, session) - alert = get_alert_by_event_id(tenant, str(last_alert.alert_id), session) - if "previous_status" not in alert.event: - logger.info( - f"Alert {alert.id} does not have previous status, cannot proceed with recover strategy", - extra={"tenant_id": tenant, "fingerprint": fp, "alert_id": alert.id, "alert.status": alert.event.get("status")}, - ) - continue - if not isinstance(alert.event.get("source"), list): - alert.event["source"] = [alert.event["source"]] - alert_dto = AlertDto(**alert.event) - with tracer.start_as_current_span("mw_recover_strategy_push_to_workflows"): - try: - # Now run any workflow that should run based on this alert - # TODO: this should publish event - workflow_manager = WorkflowManager.get_instance() - # insert the events to the workflow manager process queue - logger.info("Adding event to the workflow manager queue") - workflow_manager.insert_events(tenant, [alert_dto]) - logger.info("Added event to the workflow manager queue") - except Exception: - logger.exception( - "Failed to run workflows based on alerts", - extra={ - "provider_type": alert_dto.providerType, - "provider_id": alert_dto.providerId, - "tenant_id": tenant, - }, + for (tenant, fp) in fingerprints_to_check: + last_alert = get_last_alert_by_fingerprint(tenant, fp, session) + alert = get_alert_by_event_id(tenant, str(last_alert.alert_id), session) + if "previous_status" not in alert.event: + logger.info( + f"Alert {alert.id} does not have previous status, cannot proceed with recover strategy", + extra={"tenant_id": tenant, "fingerprint": fp, "alert_id": alert.id, "alert.status": alert.event.get("status")}, ) - - with tracer.start_as_current_span("mw_recover_strategy_run_rules_engine"): - # Now we need to run the rules engine - if KEEP_CORRELATION_ENABLED: - incidents = [] + continue + if not isinstance(alert.event.get("source"), list): + alert.event["source"] = [alert.event["source"]] + alert_dto = AlertDto(**alert.event) + with tracer.start_as_current_span("mw_recover_strategy_push_to_workflows"): try: - rules_engine = RulesEngine(tenant_id=tenant) - # handle incidents, also handle workflow execution as - incidents = rules_engine.run_rules( - [alert_dto], session=session - ) + # Now run any workflow that should run based on this alert + # TODO: this should publish event + workflow_manager = WorkflowManager.get_instance() + # insert the events to the workflow manager process queue + logger.info("Adding event to the workflow manager queue") + workflow_manager.insert_events(tenant, [alert_dto]) + logger.info("Added event to the workflow manager queue") except Exception: logger.exception( - "Failed to run rules engine", + "Failed to run workflows based on alerts", extra={ "provider_type": alert_dto.providerType, "provider_id": alert_dto.providerId, "tenant_id": tenant, }, ) - pusher_cache = get_notification_cache() - if incidents and pusher_cache.should_notify(tenant, "incident-change"): - pusher_client = get_pusher_client() + + with tracer.start_as_current_span("mw_recover_strategy_run_rules_engine"): + # Now we need to run the rules engine + if KEEP_CORRELATION_ENABLED: + incidents = [] try: - pusher_client.trigger( - f"private-{tenant}", - "incident-change", - {}, + rules_engine = RulesEngine(tenant_id=tenant) + # handle incidents, also handle workflow execution as + incidents = rules_engine.run_rules( + [alert_dto], session=session ) except Exception: - logger.exception("Failed to tell the client to pull incidents") + logger.exception( + "Failed to run rules engine", + extra={ + "provider_type": alert_dto.providerType, + "provider_id": alert_dto.providerId, + "tenant_id": tenant, + }, + ) + pusher_cache = get_notification_cache() + if incidents and pusher_cache.should_notify(tenant, "incident-change"): + pusher_client = get_pusher_client() + try: + pusher_client.trigger( + f"private-{tenant}", + "incident-change", + {}, + ) + except Exception: + logger.exception("Failed to tell the client to pull incidents") - try: - presets = get_all_presets_dtos(tenant) - rules_engine = RulesEngine(tenant_id=tenant) - presets_do_update = [] - for preset_dto in presets: - # filter the alerts based on the search query - filtered_alerts = rules_engine.filter_alerts( - [alert_dto], preset_dto.cel_query - ) - # if not related alerts, no need to update - if not filtered_alerts: - continue - presets_do_update.append(preset_dto) - if pusher_cache.should_notify(tenant, "poll-presets"): - try: - pusher_client.trigger( - f"private-{tenant}", - "poll-presets", - json.dumps( - [p.name.lower() for p in presets_do_update], default=str - ), + try: + presets = get_all_presets_dtos(tenant) + rules_engine = RulesEngine(tenant_id=tenant) + presets_do_update = [] + for preset_dto in presets: + # filter the alerts based on the search query + filtered_alerts = rules_engine.filter_alerts( + [alert_dto], preset_dto.cel_query ) - except Exception: - logger.exception("Failed to send presets via pusher") - except Exception: - logger.exception( - "Failed to send presets via pusher", - extra={ - "provider_type": alert_dto.providerType, - "provider_id": alert_dto.providerId, - "tenant_id": tenant, - }, - ) - logger.info("Finished recover strategy for maintenance windows review.") \ No newline at end of file + # if not related alerts, no need to update + if not filtered_alerts: + continue + presets_do_update.append(preset_dto) + if pusher_cache.should_notify(tenant, "poll-presets"): + try: + pusher_client.trigger( + f"private-{tenant}", + "poll-presets", + json.dumps( + [p.name.lower() for p in presets_do_update], default=str + ), + ) + except Exception: + logger.exception("Failed to send presets via pusher") + except Exception: + logger.exception( + "Failed to send presets via pusher", + extra={ + "provider_type": alert_dto.providerType, + "provider_id": alert_dto.providerId, + "tenant_id": tenant, + }, + ) + logger.info("Finished recover strategy for maintenance windows review.") + finally: + if _owns_session: + session.close() diff --git a/tests/test_session_cleanup.py b/tests/test_session_cleanup.py new file mode 100644 index 0000000000..b24a8acffc --- /dev/null +++ b/tests/test_session_cleanup.py @@ -0,0 +1,140 @@ +""" +Regression tests for SQLAlchemy session cleanup in background tasks. + +When recover_strategy() and check_dismissal_expiry() are called from +process_watcher_task via run_in_executor (session=None), they create +sessions internally via get_session_sync(). These sessions must be +closed in a finally block to prevent connection pool exhaustion. + +See: https://github.com/keephq/keep/issues/5496 +""" + +import logging +from unittest.mock import MagicMock, patch + +import pytest + +from keep.api.bl.dismissal_expiry_bl import DismissalExpiryBl +from keep.api.bl.maintenance_windows_bl import MaintenanceWindowsBl + + +class TestRecoverStrategySessionCleanup: + """MaintenanceWindowsBl.recover_strategy session lifecycle tests.""" + + @patch("keep.api.bl.maintenance_windows_bl.get_alerts_by_status", return_value=[]) + @patch( + "keep.api.bl.maintenance_windows_bl.get_maintenance_windows_started", + return_value=[], + ) + @patch("keep.api.bl.maintenance_windows_bl.get_session_sync") + def test_closes_session_when_created_internally( + self, mock_get_session, mock_get_windows, mock_get_alerts + ): + """Session created via get_session_sync must be closed after execution.""" + mock_session = MagicMock() + mock_get_session.return_value = mock_session + + MaintenanceWindowsBl.recover_strategy(logger=logging.getLogger(__name__)) + + mock_get_session.assert_called_once() + mock_session.close.assert_called_once() + + @patch("keep.api.bl.maintenance_windows_bl.get_alerts_by_status", return_value=[]) + @patch( + "keep.api.bl.maintenance_windows_bl.get_maintenance_windows_started", + return_value=[], + ) + @patch("keep.api.bl.maintenance_windows_bl.get_session_sync") + def test_does_not_close_caller_provided_session( + self, mock_get_session, mock_get_windows, mock_get_alerts + ): + """When a caller provides a session, recover_strategy must not close it.""" + caller_session = MagicMock() + + MaintenanceWindowsBl.recover_strategy( + logger=logging.getLogger(__name__), session=caller_session + ) + + mock_get_session.assert_not_called() + caller_session.close.assert_not_called() + + @patch("keep.api.bl.maintenance_windows_bl.get_alerts_by_status") + @patch( + "keep.api.bl.maintenance_windows_bl.get_maintenance_windows_started", + return_value=[], + ) + @patch("keep.api.bl.maintenance_windows_bl.get_session_sync") + def test_closes_session_on_exception( + self, mock_get_session, mock_get_windows, mock_get_alerts + ): + """Session must be closed even when an exception occurs mid-execution.""" + mock_session = MagicMock() + mock_get_session.return_value = mock_session + mock_get_alerts.side_effect = RuntimeError("simulated DB error") + + with pytest.raises(RuntimeError, match="simulated DB error"): + MaintenanceWindowsBl.recover_strategy( + logger=logging.getLogger(__name__) + ) + + mock_session.close.assert_called_once() + + +class TestCheckDismissalExpirySessionCleanup: + """DismissalExpiryBl.check_dismissal_expiry session lifecycle tests.""" + + @patch( + "keep.api.bl.dismissal_expiry_bl.DismissalExpiryBl.get_alerts_with_expired_dismissals", + return_value=[], + ) + @patch("keep.api.bl.dismissal_expiry_bl.get_session_sync") + def test_closes_session_when_created_internally( + self, mock_get_session, mock_get_expired + ): + """Session created via get_session_sync must be closed after execution.""" + mock_session = MagicMock() + mock_get_session.return_value = mock_session + + DismissalExpiryBl.check_dismissal_expiry( + logger=logging.getLogger(__name__) + ) + + mock_get_session.assert_called_once() + mock_session.close.assert_called_once() + + @patch( + "keep.api.bl.dismissal_expiry_bl.DismissalExpiryBl.get_alerts_with_expired_dismissals", + return_value=[], + ) + @patch("keep.api.bl.dismissal_expiry_bl.get_session_sync") + def test_does_not_close_caller_provided_session( + self, mock_get_session, mock_get_expired + ): + """When a caller provides a session, check_dismissal_expiry must not close it.""" + caller_session = MagicMock() + + DismissalExpiryBl.check_dismissal_expiry( + logger=logging.getLogger(__name__), session=caller_session + ) + + mock_get_session.assert_not_called() + caller_session.close.assert_not_called() + + @patch( + "keep.api.bl.dismissal_expiry_bl.DismissalExpiryBl.get_alerts_with_expired_dismissals" + ) + @patch("keep.api.bl.dismissal_expiry_bl.get_session_sync") + def test_closes_session_on_exception( + self, mock_get_session, mock_get_expired + ): + """Session must be closed even when an exception occurs mid-execution.""" + mock_session = MagicMock() + mock_get_session.return_value = mock_session + mock_get_expired.side_effect = RuntimeError("simulated DB error") + + with pytest.raises(RuntimeError, match="simulated DB error"): + DismissalExpiryBl.check_dismissal_expiry( + logger=logging.getLogger(__name__) + ) + + mock_session.close.assert_called_once()