Skip to content
Merged
19 changes: 14 additions & 5 deletions src/connectors/manager.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import { Connector, ConnectorType, ConnectorRegistry, ExecuteOptions, ConnectorConfig } from "./interface.js";
import { SSHTunnel } from "../utils/ssh-tunnel.js";
import type { SSHTunnelConfig } from "../types/ssh.js";
import type { SSHTunnelConfig, SSHTunnelInfo } from "../types/ssh.js";
import type { SourceConfig } from "../types/config.js";
import { buildDSNFromSource } from "../config/toml-loader.js";
import { getDatabaseTypeFromDSN, getDefaultPortForType } from "../utils/dsn-obfuscate.js";
import { redactDSN } from "../config/env.js";
import { SafeURL } from "../utils/safe-url.js";
import { generateRdsAuthToken } from "../utils/aws-rds-signer.js";
import { parseSSHConfig, looksLikeSSHAlias, getDefaultSSHConfigPath } from "../utils/ssh-config-parser.js";
import { TUNNEL_ERROR_MARKER } from "../utils/error-classifier.js";

// Singleton instance for global access
let managerInstance: ConnectorManager | null = null;
Expand Down Expand Up @@ -191,10 +192,18 @@ export class ConnectorManager {

// Create and establish SSH tunnel
const tunnel = new SSHTunnel();
const tunnelInfo = await tunnel.establish(sshConfig, {
targetHost,
targetPort,
});
let tunnelInfo: SSHTunnelInfo;
try {
tunnelInfo = await tunnel.establish(sshConfig, {
targetHost,
targetPort,
});
} catch (error) {
if (error && typeof error === "object") {
(error as Record<string, unknown>)[TUNNEL_ERROR_MARKER] = true;
}
throw error;
}

// Update DSN to use local tunnel endpoint
url.hostname = "127.0.0.1";
Expand Down
47 changes: 45 additions & 2 deletions src/tools/__tests__/custom-tool-handler.test.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import { describe, it, expect } from "vitest";
import { describe, it, expect, vi, afterEach } from "vitest";
import { z } from "zod";
import {
buildZodSchemaFromParameters,
buildInputSchema,
createCustomToolHandler,
} from "../custom-tool-handler.js";
import type { ParameterConfig } from "../../types/config.js";
import { ConnectorManager } from "../../connectors/manager.js";
import type { ToolConfig, ParameterConfig } from "../../types/config.js";

// Auto-mock the connector manager so we control connection/execution behavior
vi.mock("../../connectors/manager.js");

describe("Custom Tool Handler", () => {
describe("buildZodSchemaFromParameters", () => {
Expand Down Expand Up @@ -357,4 +362,42 @@ describe("Custom Tool Handler", () => {
expect(schema.required).toBeUndefined();
});
});

describe("createCustomToolHandler connection error classification", () => {
afterEach(() => {
vi.clearAllMocks();
});

it("returns SOURCE_UNREACHABLE (not a SQL error) when the connector throws a network error", async () => {
const econn: any = new Error("connect ECONNREFUSED 127.0.0.1:5432");
econn.code = "ECONNREFUSED";

vi.mocked(ConnectorManager.ensureConnected).mockResolvedValue(undefined as any);
vi.mocked(ConnectorManager.getCurrentConnector).mockReturnValue({
id: "postgres",
getId: () => "prod",
executeSQL: vi.fn().mockRejectedValue(econn),
} as any);
vi.mocked(ConnectorManager.getSourceConfig).mockReturnValue({
id: "prod",
type: "postgres",
} as any);

const toolConfig: ToolConfig = {
name: "get_user",
source: "prod",
statement: "SELECT * FROM users",
} as any;

const handler = createCustomToolHandler(toolConfig);
const res: any = await handler({}, {});
const payload = JSON.parse(res.content[0].text);

expect(res.isError).toBe(true);
expect(payload.code).toBe("SOURCE_UNREACHABLE");
expect(payload.details.source_id).toBe(toolConfig.source);
// Connection failures must NOT be augmented with SQL-context debugging info
expect(payload.error).not.toContain("SQL:");
});
});
});
59 changes: 59 additions & 0 deletions src/tools/__tests__/execute-sql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,65 @@ describe('execute-sql tool', () => {
expect(parsedResult.error).toBe('Database error');
expect(parsedResult.code).toBe('EXECUTION_ERROR');
});

it('returns SOURCE_UNREACHABLE when the connector throws a network error', async () => {
const econn: any = new Error('connect ECONNREFUSED 127.0.0.1:5432');
econn.code = 'ECONNREFUSED';
mockGetCurrentConnector.mockReturnValue({
id: 'postgres',
getId: () => 'prod',
executeSQL: vi.fn().mockRejectedValue(econn),
} as any);
vi.mocked(ConnectorManager.getSourceConfig).mockReturnValue({ id: 'prod', type: 'postgres' } as any);
vi.mocked(ConnectorManager.ensureConnected).mockResolvedValue(undefined as any);

const handler = createExecuteSqlToolHandler('prod');
const res: any = await handler({ sql: 'SELECT 1' }, {});
const payload = JSON.parse(res.content[0].text);

expect(res.isError).toBe(true);
expect(payload.code).toBe('SOURCE_UNREACHABLE');
expect(payload.details.source_id).toBe('prod');
});

it('falls through to EXECUTION_ERROR when the source config is null', async () => {
const econn: any = new Error('connect ECONNREFUSED 127.0.0.1:5432');
econn.code = 'ECONNREFUSED';
mockGetCurrentConnector.mockReturnValue({
id: 'postgres',
getId: () => 'prod',
executeSQL: vi.fn().mockRejectedValue(econn),
} as any);
vi.mocked(ConnectorManager.getSourceConfig).mockReturnValue(null as any);
vi.mocked(ConnectorManager.ensureConnected).mockResolvedValue(undefined as any);

const handler = createExecuteSqlToolHandler('prod');
const res: any = await handler({ sql: 'SELECT 1' }, {});
const payload = JSON.parse(res.content[0].text);

expect(res.isError).toBe(true);
expect(payload.code).toBe('EXECUTION_ERROR');
});

it('uses the display source id "default" in single-source mode', async () => {
const econn: any = new Error('connect ECONNREFUSED 127.0.0.1:5432');
econn.code = 'ECONNREFUSED';
mockGetCurrentConnector.mockReturnValue({
id: 'postgres',
getId: () => 'default',
executeSQL: vi.fn().mockRejectedValue(econn),
} as any);
vi.mocked(ConnectorManager.getSourceConfig).mockReturnValue({ type: 'postgres' } as any);
vi.mocked(ConnectorManager.ensureConnected).mockResolvedValue(undefined as any);

const handler = createExecuteSqlToolHandler();
const res: any = await handler({ sql: 'SELECT 1' }, {});
const payload = JSON.parse(res.content[0].text);

expect(res.isError).toBe(true);
expect(payload.code).toBe('SOURCE_UNREACHABLE');
expect(payload.details.source_id).toBe('default');
});
});

describe('read-only mode enforcement', () => {
Expand Down
27 changes: 27 additions & 0 deletions src/tools/__tests__/search-objects.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,33 @@ describe('search_database_objects tool', () => {
const parsed = parseToolResponse(result);
expect(parsed.code).toBe('SEARCH_ERROR');
});

it('returns AUTH_FAILED when the connector throws a login error', async () => {
const elogin: any = new Error('Login failed for user');
elogin.code = 'ELOGIN';
// make every method the handler might call reject with the auth error
const failing = {
id: 'sqlserver',
getId: () => 'mssql',
getDefaultSchema: vi.fn().mockRejectedValue(elogin),
getSchemas: vi.fn().mockRejectedValue(elogin),
getTables: vi.fn().mockRejectedValue(elogin),
};
mockGetCurrentConnector.mockReturnValue(failing as any);
vi.mocked(ConnectorManager.ensureConnected).mockResolvedValue(undefined as any);
vi.mocked(ConnectorManager.getSourceConfig).mockReturnValue({ id: 'mssql', type: 'sqlserver' } as any);

const handler = createSearchDatabaseObjectsToolHandler('mssql');
const result: any = await handler(
{ object_type: 'table', detail_level: 'names', limit: 100 },
{}
);
const payload = parseToolResponse(result);

expect(result.isError).toBe(true);
expect(payload.code).toBe('AUTH_FAILED');
expect(payload.details.source_id).toBe('mssql');
});
});

describe('case insensitivity', () => {
Expand Down
6 changes: 6 additions & 0 deletions src/tools/custom-tool-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
isAllowedInReadonlyMode,
createReadonlyViolationMessage,
trackToolRequest,
tryClassifyConnectionError,
} from "../utils/tool-handler-helpers.js";

/**
Expand Down Expand Up @@ -213,6 +214,11 @@ export function createCustomToolHandler(toolConfig: ToolConfig) {
success = false;
errorMessage = (error as Error).message;

// A connection/access failure is not a SQL problem — classify and return
// it cleanly, ahead of the ZodError / SQL-context augmentation below.
const classified = tryClassifyConnectionError(error, toolConfig.source, toolConfig.source);
if (classified) return classified;

// Provide helpful error messages for common issues
if (error instanceof z.ZodError) {
const issues = error.issues.map((i) => `${i.path.join(".")}: ${i.message}`).join("; ");
Expand Down
3 changes: 3 additions & 0 deletions src/tools/execute-sql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { BUILTIN_TOOL_EXECUTE_SQL } from "./builtin-tools.js";
import {
getEffectiveSourceId,
trackToolRequest,
tryClassifyConnectionError,
} from "../utils/tool-handler-helpers.js";
import { splitSQLStatements } from "../utils/sql-parser.js";

Expand Down Expand Up @@ -81,6 +82,8 @@ export function createExecuteSqlToolHandler(sourceId?: string) {
} catch (error) {
success = false;
errorMessage = (error as Error).message;
const classified = tryClassifyConnectionError(error, sourceId, effectiveSourceId);
if (classified) return classified;
return createToolErrorResponse(errorMessage, "EXECUTION_ERROR");
} finally {
// Track the request
Expand Down
3 changes: 3 additions & 0 deletions src/tools/search-objects.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { quoteQualifiedIdentifier } from "../utils/identifier-quoter.js";
import {
getEffectiveSourceId,
trackToolRequest,
tryClassifyConnectionError,
} from "../utils/tool-handler-helpers.js";

/**
Expand Down Expand Up @@ -714,6 +715,8 @@ export function createSearchDatabaseObjectsToolHandler(sourceId?: string) {
} catch (error) {
success = false;
errorMessage = (error as Error).message;
const classified = tryClassifyConnectionError(error, sourceId, effectiveSourceId);
if (classified) return classified;
return createToolErrorResponse(
`Error searching database objects: ${errorMessage}`,
"SEARCH_ERROR"
Expand Down
46 changes: 46 additions & 0 deletions src/utils/__tests__/error-classifier.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { describe, it, expect } from "vitest";
import { classifyConnectionError, TUNNEL_ERROR_MARKER } from "../error-classifier.js";

describe("classifyConnectionError", () => {
it("classifies network socket errors as SOURCE_UNREACHABLE", () => {
for (const code of ["ECONNREFUSED", "ETIMEDOUT", "ENOTFOUND", "EHOSTUNREACH", "ENETUNREACH", "ECONNRESET"]) {
const result = classifyConnectionError({ code }, "postgres", "staging");
expect(result?.code).toBe("SOURCE_UNREACHABLE");
expect(result?.message).toContain("staging");
}
});

it("classifies postgres auth errors as AUTH_FAILED", () => {
expect(classifyConnectionError({ code: "28P01" }, "postgres", "prod")?.code).toBe("AUTH_FAILED");
expect(classifyConnectionError({ code: "28000" }, "postgres", "prod")?.code).toBe("AUTH_FAILED");
});

it("classifies mysql/mariadb auth errors via code or errno", () => {
expect(classifyConnectionError({ code: "ER_ACCESS_DENIED_ERROR" }, "mysql", "m")?.code).toBe("AUTH_FAILED");
expect(classifyConnectionError({ errno: 1045 }, "mariadb", "m")?.code).toBe("AUTH_FAILED");
// 1698 = ER_ACCESS_DENIED_NO_PASSWORD_ERROR
expect(classifyConnectionError({ errno: 1698 }, "mysql", "m")?.code).toBe("AUTH_FAILED");
expect(classifyConnectionError({ errno: 1698 }, "mariadb", "m")?.code).toBe("AUTH_FAILED");
});

it("classifies sqlserver login errors as AUTH_FAILED", () => {
expect(classifyConnectionError({ code: "ELOGIN" }, "sqlserver", "s")?.code).toBe("AUTH_FAILED");
});

it("classifies marked SSH tunnel errors as TUNNEL_FAILED, ahead of network code", () => {
const err: any = { code: "ECONNREFUSED" };
err[TUNNEL_ERROR_MARKER] = true;
expect(classifyConnectionError(err, "postgres", "viaBastion")?.code).toBe("TUNNEL_FAILED");
});

it("returns null for unrecognized errors and non-objects", () => {
expect(classifyConnectionError({ code: "42601" }, "postgres", "x")).toBeNull(); // syntax error
expect(classifyConnectionError(new Error("boom"), "postgres", "x")).toBeNull();
expect(classifyConnectionError("nope", "postgres", "x")).toBeNull();
expect(classifyConnectionError(null, "postgres", "x")).toBeNull();
});

it("does not treat a mysql auth code as auth for a postgres source", () => {
expect(classifyConnectionError({ errno: 1045 }, "postgres", "x")).toBeNull();
});
});
90 changes: 90 additions & 0 deletions src/utils/error-classifier.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import type { ConnectorType } from "../connectors/interface.js";

/**
* Distinct error codes for connection/access failures, so an MCP client can
* tell "the source is down / mis-credentialed" (restore access) from "your
* query is wrong" (fix the SQL). Anything not matched here is left to the
* caller's existing generic error path.
*/
export type ConnectionErrorCode = "SOURCE_UNREACHABLE" | "AUTH_FAILED" | "TUNNEL_FAILED";

/**
* Property set on errors thrown while establishing an SSH tunnel, so the
* classifier can distinguish TUNNEL_FAILED from a plain network failure
* without parsing message text. Set in ConnectorManager.connectSource.
*/
export const TUNNEL_ERROR_MARKER = "__dbhubSSHTunnelError";

// Node socket-level codes that mean "could not reach / lost the source".
// Timeout (ETIMEDOUT) is folded in here: refused vs timed-out differ at the
// TCP level but call for the same remediation.
const NETWORK_CODES = new Set([
"ECONNREFUSED",
"ETIMEDOUT",
"ENOTFOUND",
"EHOSTUNREACH",
"ENETUNREACH",
"ECONNRESET",
]);

// Per-connector authentication failure signals. Keyed by code or errno.
const AUTH_CODES: Record<ConnectorType, ReadonlyArray<string | number>> = {
postgres: ["28P01", "28000"],
mysql: ["ER_ACCESS_DENIED_ERROR", 1045, 1698],
mariadb: ["ER_ACCESS_DENIED_ERROR", 1045, 1698],
sqlserver: ["ELOGIN"],
sqlite: [], // no network/auth layer
};

function unreachableMessage(sourceId: string): string {
return `Source '${sourceId}' is unreachable. ` +
`Verify the database is running and reachable (host, port, network), then retry.`;
}

function authMessage(sourceId: string): string {
return `Authentication failed for source '${sourceId}'. ` +
`Verify the credentials/access for this source are valid, then retry.`;
}

function tunnelMessage(sourceId: string): string {
return `SSH tunnel for source '${sourceId}' failed to establish. ` +
`Verify SSH host/credentials and bastion reachability, then retry.`;
}

/**
* Classify a thrown error from a connect attempt or query into a connection
* failure category. Returns null when the error is not a recognized
* connection/access failure (caller should fall back to its generic handling).
* Pure; never throws.
*/
export function classifyConnectionError(
error: unknown,
connectorType: ConnectorType,
sourceId: string
): { code: ConnectionErrorCode; message: string } | null {
if (!error || typeof error !== "object") {
return null;
}
const err = error as Record<string, unknown>;

// Tunnel marker wins over the underlying network code.
if (err[TUNNEL_ERROR_MARKER] === true) {
return { code: "TUNNEL_FAILED", message: tunnelMessage(sourceId) };
}

const code = err.code;
if (typeof code === "string" && NETWORK_CODES.has(code)) {
return { code: "SOURCE_UNREACHABLE", message: unreachableMessage(sourceId) };
}

const authCodes = AUTH_CODES[connectorType];
const errno = err.errno;
if (
(typeof code === "string" && authCodes.includes(code)) ||
(typeof errno === "number" && authCodes.includes(errno))
) {
return { code: "AUTH_FAILED", message: authMessage(sourceId) };
}

return null;
}
Loading
Loading