diff --git a/index.js b/index.js index c6a80d44d6..b198cab1cb 100644 --- a/index.js +++ b/index.js @@ -27,11 +27,84 @@ exports.Pool = Pool; exports.PoolCluster = PoolCluster; -exports.createServer = function (handler) { +const _serverHandlerKeys = ['query', 'ping', 'quit', 'init_db', 'auth']; + +function _hasHandlerKeys(obj) { + return _serverHandlerKeys.some((k) => typeof obj[k] === 'function'); +} + +function _wrapAuth(authHandler) { + return function (params, cb) { + Promise.resolve() + .then(() => authHandler(params)) + .then(() => cb(null)) + .catch((err) => + cb(null, { message: err.message, code: err.code || 1045 }) + ); + }; +} + +function _buildHandshakeArgs(handlers) { + const args = { + protocolVersion: 10, + serverVersion: handlers.serverVersion || 'mysql2-server', + connectionId: Math.floor(Math.random() * 1000000), + statusFlags: 2, + characterSet: 8, + capabilityFlags: 0xffffff, + }; + if (handlers.auth) { + args.authCallback = _wrapAuth(handlers.auth); + } + return args; +} + +exports.createServer = function (opts = {}) { const Server = require('./lib/server.js'); - const s = new Server(); - if (handler) { - s.on('connection', handler); + const Commands = require('./lib/commands/index.js'); + const { buildHandleCommand } = require('./lib/commands/server/index.js'); + + if (typeof opts === 'function') { + const fn = opts; + const s = new Server({ encoding: 'cesu8' }); + s.on('connection', (conn) => { + conn.on('error', () => {}); + const result = fn(conn); + if (!result || typeof result !== 'object' || !_hasHandlerKeys(result)) { + return; + } + const handlers = result; + const encoding = handlers.encoding || 'cesu8'; + conn.serverConfig = { encoding }; + conn.config.serverOptions = Object.assign({}, conn.config.serverOptions, { + handleCommand: buildHandleCommand(handlers), + encoding, + }); + conn.addCommand( + new Commands.ServerHandshake(_buildHandshakeArgs(handlers)) + ); + }); + return s; + } + + if (_hasHandlerKeys(opts)) { + const handleCommand = buildHandleCommand(opts); + const encoding = opts.encoding || 'cesu8'; + const s = new Server({ handleCommand, encoding }); + s.on('connection', (conn) => { + conn.on('error', () => {}); + conn.serverConfig = { encoding }; + conn.addCommand(new Commands.ServerHandshake(_buildHandshakeArgs(opts))); + }); + return s; + } + + const s = new Server({ + handleCommand: opts.handleCommand, + encoding: opts.encoding || 'cesu8', + }); + if (opts.onConnection) { + s.on('connection', opts.onConnection); } return s; }; diff --git a/lib/base/connection.js b/lib/base/connection.js index 77da3d0f37..12509f47b6 100644 --- a/lib/base/connection.js +++ b/lib/base/connection.js @@ -515,14 +515,26 @@ class BaseConnection extends EventEmitter { ); } } + if ( + !this._command && + this.config.isServer && + this.config.serverOptions?.handleCommand + ) { + const commandCode = packet.peekByte(); + this._command = this.config.serverOptions.handleCommand(commandCode); + } if (!this._command) { const marker = packet.peekByte(); - // If it's an Err Packet, we should use it. if (marker === 0xff) { const error = Packets.Error.fromPacket(packet); this.protocolError(error.message, error.code); + } else if (this.config.isServer && !this.config.serverOptions?.handleCommand) { + this.protocolError( + 'No handleCommand configured for server connection. ' + + 'Provide a handleCommand option to createServer() to handle client commands.', + 'PROTOCOL_UNEXPECTED_PACKET' + ); } else { - // Otherwise, it means it's some other unexpected packet. this.protocolError( 'Unexpected packet while no commands in the queue', 'PROTOCOL_UNEXPECTED_PACKET' @@ -1016,27 +1028,31 @@ class BaseConnection extends EventEmitter { // =================================== // outgoing server connection methods // =================================== + + get _serverEncoding() { + return ( + this.config.serverOptions?.encoding || + (this.serverConfig && this.serverConfig.encoding) || + 'cesu8' + ); + } + writeColumns(columns) { this.writePacket(Packets.ResultSetHeader.toPacket(columns.length)); columns.forEach((column) => { this.writePacket( - Packets.ColumnDefinition.toPacket(column, this.serverConfig.encoding) + Packets.ColumnDefinition.toPacket(column, this._serverEncoding) ); }); this.writeEof(); } - // row is array of columns, not hash writeTextRow(column) { - this.writePacket( - Packets.TextRow.toPacket(column, this.serverConfig.encoding) - ); + this.writePacket(Packets.TextRow.toPacket(column, this._serverEncoding)); } writeBinaryRow(column) { - this.writePacket( - Packets.BinaryRow.toPacket(column, this.serverConfig.encoding) - ); + this.writePacket(Packets.BinaryRow.toPacket(column, this._serverEncoding)); } writeTextResult(rows, columns, binary = false) { @@ -1061,13 +1077,11 @@ class BaseConnection extends EventEmitter { if (!args) { args = { affectedRows: 0 }; } - this.writePacket(Packets.OK.toPacket(args, this.serverConfig.encoding)); + this.writePacket(Packets.OK.toPacket(args, this._serverEncoding)); } writeError(args) { - // if we want to send error before initial hello was sent, use default encoding - const encoding = this.serverConfig ? this.serverConfig.encoding : 'cesu8'; - this.writePacket(Packets.Error.toPacket(args, encoding)); + this.writePacket(Packets.Error.toPacket(args, this._serverEncoding)); } serverHandshake(args) { diff --git a/lib/commands/auth_switch.js b/lib/commands/auth_switch.js index ddbbea192e..0571492ae8 100644 --- a/lib/commands/auth_switch.js +++ b/lib/commands/auth_switch.js @@ -96,9 +96,11 @@ function authSwitchRequest(packet, connection, command) { const authPlugin = getAuthPlugin(pluginName, connection); if (!authPlugin) { - throw new Error( - `Server requests authentication using unknown plugin ${pluginName}. See ${'TODO: add plugins doco here'} on how to configure or author authentication plugins.` + const err = new Error( + `Server requests authentication using unknown plugin ${pluginName}.` ); + connection.emit('error', err); + return; } connection._authPlugin = authPlugin({ connection, command }); Promise.resolve(connection._authPlugin(pluginData)) diff --git a/lib/commands/client_handshake.js b/lib/commands/client_handshake.js index 7de3a7bcbc..e930b26a65 100644 --- a/lib/commands/client_handshake.js +++ b/lib/commands/client_handshake.js @@ -14,6 +14,10 @@ const Command = require('./command.js'); const Packets = require('../packets/index.js'); const ClientConstants = require('../constants/client.js'); const CharsetToEncoding = require('../constants/charset_encodings.js'); + +// TODO: refactor to use plugins +// need to coordinate with ChangeUser command, +// currently it uses sync calculateNativePasswordAuthToken method from here const auth41 = require('../auth_41.js'); const { getAuthPlugin } = require('./auth_switch.js'); const { @@ -60,27 +64,17 @@ class ClientHandshake extends Command { } this.user = connection.config.user; this.password = connection.config.password; - // "password1" is an alias to the original "password" value - // to make it easier to integrate multi-factor authentication this.password1 = connection.config.password; - // "password2" and "password3" are the 2nd and 3rd factor authentication - // passwords, which can be undefined depending on the authentication - // plugin being used this.password2 = connection.config.password2; this.password3 = connection.config.password3; this.passwordSha1 = connection.config.passwordSha1; this.database = connection.config.database; this.authPluginName = this.handshake.authPluginName; - // Optimization: Try to use the server's preferred authentication method - // to avoid an unnecessary auth switch roundtrip const serverAuthMethod = this.handshake.authPluginName; const isSecureConnection = connection.config.ssl || connection.config.socketPath; - // Combine auth plugin data for easier handling - // Note: authPluginData2 can include a trailing NUL byte when PLUGIN_AUTH is set - // We must ensure exactly 20 bytes for the scramble const authPluginData = this.handshake.authPluginData1 && this.handshake.authPluginData2 ? Buffer.concat([ @@ -89,8 +83,6 @@ class ClientHandshake extends Command { ]).slice(0, 20) : Buffer.alloc(20); - // Check if user has custom auth plugin or legacy handler for the server-advertised method - // If so, we must not bypass the auth switch flow with our built-in implementation const hasCustomAuthPlugin = connection.config.authPlugins && Object.prototype.hasOwnProperty.call( @@ -100,8 +92,6 @@ class ClientHandshake extends Command { const hasLegacyAuthSwitchHandler = typeof connection.config.authSwitchHandler === 'function'; - // Determine which auth method to use - // Try to use server's preferred method if we can, otherwise fallback to native const canUseDirectAuth = !hasCustomAuthPlugin && !hasLegacyAuthSwitchHandler && @@ -113,7 +103,6 @@ class ClientHandshake extends Command { ? serverAuthMethod : 'mysql_native_password'; - // Calculate the auth token for the chosen method const authToken = this.calculateAuthToken( clientAuthMethod, this.password, @@ -144,9 +133,6 @@ class ClientHandshake extends Command { }); connection.writePacket(handshakeResponse.toPacket()); - // If we used a non-native auth method in the initial handshake response, - // we need to prepare for potential AuthMoreData packets by creating - // the appropriate auth plugin instance if (clientAuthMethod !== 'mysql_native_password') { this.initializeAuthPlugin(clientAuthMethod, authPluginData, connection); } diff --git a/lib/commands/command.js b/lib/commands/command.js index 3d318868bb..eebba7264c 100644 --- a/lib/commands/command.js +++ b/lib/commands/command.js @@ -24,6 +24,9 @@ class Command extends EventEmitter { if (!this.next) { this.next = this.start; connection._resetSequenceId(); + if (connection.config.isServer && packet) { + connection._bumpSequenceId(1); + } } if (packet && packet.isError()) { const err = packet.asError(connection.clientEncoding); diff --git a/lib/commands/server/index.js b/lib/commands/server/index.js new file mode 100644 index 0000000000..90c03cd80c --- /dev/null +++ b/lib/commands/server/index.js @@ -0,0 +1,56 @@ +'use strict'; + +const CommandCode = require('../../constants/commands.js'); +const ServerQuery = require('./query.js'); +const ServerPing = require('./ping.js'); +const ServerQuit = require('./quit.js'); +const ServerInitDb = require('./init_db.js'); +const { sendError } = require('./send_result.js'); +const Command = require('../command.js'); + +function defaultPing() {} +function defaultQuit() {} +function defaultInitDb() {} + +function buildHandleCommand(handlers) { + const queryHandler = handlers.query; + const pingHandler = handlers.ping || defaultPing; + const quitHandler = handlers.quit || defaultQuit; + const initDbHandler = handlers.init_db || defaultInitDb; + const fallback = handlers.handleCommand; + + return function handleCommand(commandCode) { + switch (commandCode) { + case CommandCode.QUERY: + if (queryHandler) { + return new ServerQuery(queryHandler); + } + break; + case CommandCode.PING: + return new ServerPing(pingHandler); + case CommandCode.QUIT: + return new ServerQuit(quitHandler); + case CommandCode.INIT_DB: + return new ServerInitDb(initDbHandler); + } + + if (fallback) { + return fallback(commandCode); + } + + const cmd = new Command(); + cmd.start = function (_packet, connection) { + sendError(connection, new Error('Command not supported')); + return null; + }; + return cmd; + }; +} + +module.exports = { + ServerQuery, + ServerPing, + ServerQuit, + ServerInitDb, + buildHandleCommand, +}; diff --git a/lib/commands/server/init_db.js b/lib/commands/server/init_db.js new file mode 100644 index 0000000000..0a4166cdc4 --- /dev/null +++ b/lib/commands/server/init_db.js @@ -0,0 +1,50 @@ +'use strict'; + +const Command = require('../command.js'); +const { sendResult, sendError } = require('./send_result.js'); + +class ServerInitDb extends Command { + constructor(handler) { + super(); + this._handler = handler; + } + + start(packet, connection) { + packet.readInt8(); + const encoding = + (connection.clientHelloReply && connection.clientHelloReply.encoding) || + 'utf8'; + const schemaName = packet.readString(undefined, encoding); + let result; + try { + result = this._handler(schemaName); + } catch (err) { + sendError(connection, err); + return null; + } + if (result && typeof result.then === 'function') { + result + .then(() => sendResult(connection, undefined)) + .catch((err) => sendError(connection, err)) + .then(() => { + this.next = null; + this.emit('end'); + connection._command = connection._commands.shift(); + if (connection._command) { + connection.sequenceId = 0; + connection.compressedSequenceId = 0; + connection.handlePacket(); + } + }); + return ServerInitDb.prototype._awaitResult; + } + sendResult(connection, undefined); + return null; + } + + _awaitResult() { + return ServerInitDb.prototype._awaitResult; + } +} + +module.exports = ServerInitDb; diff --git a/lib/commands/server/ping.js b/lib/commands/server/ping.js new file mode 100644 index 0000000000..8d01ee27a9 --- /dev/null +++ b/lib/commands/server/ping.js @@ -0,0 +1,46 @@ +'use strict'; + +const Command = require('../command.js'); +const { sendResult, sendError } = require('./send_result.js'); + +class ServerPing extends Command { + constructor(handler) { + super(); + this._handler = handler; + } + + start(packet, connection) { + packet.readInt8(); + let result; + try { + result = this._handler(); + } catch (err) { + sendError(connection, err); + return null; + } + if (result && typeof result.then === 'function') { + result + .then(() => sendResult(connection, undefined)) + .catch((err) => sendError(connection, err)) + .then(() => { + this.next = null; + this.emit('end'); + connection._command = connection._commands.shift(); + if (connection._command) { + connection.sequenceId = 0; + connection.compressedSequenceId = 0; + connection.handlePacket(); + } + }); + return ServerPing.prototype._awaitResult; + } + sendResult(connection, undefined); + return null; + } + + _awaitResult() { + return ServerPing.prototype._awaitResult; + } +} + +module.exports = ServerPing; diff --git a/lib/commands/server/query.js b/lib/commands/server/query.js new file mode 100644 index 0000000000..63d1834b18 --- /dev/null +++ b/lib/commands/server/query.js @@ -0,0 +1,54 @@ +'use strict'; + +const Command = require('../command.js'); +const Packets = require('../../packets/index.js'); +const { sendResult, sendError } = require('./send_result.js'); + +class ServerQuery extends Command { + constructor(handler) { + super(); + this._handler = handler; + } + + start(packet, connection) { + const encoding = + (connection.clientHelloReply && connection.clientHelloReply.encoding) || + 'utf8'; + const queryPacket = Packets.Query.fromPacket(packet, encoding); + let result; + try { + result = this._handler(queryPacket.query); + } catch (err) { + sendError(connection, err); + return null; + } + if (result && typeof result.then === 'function') { + this._handleAsync(result, connection); + return ServerQuery.prototype._awaitResult; + } + sendResult(connection, result); + return null; + } + + _handleAsync(promise, connection) { + promise + .then((val) => sendResult(connection, val)) + .catch((err) => sendError(connection, err)) + .then(() => { + this.next = null; + this.emit('end'); + connection._command = connection._commands.shift(); + if (connection._command) { + connection.sequenceId = 0; + connection.compressedSequenceId = 0; + connection.handlePacket(); + } + }); + } + + _awaitResult() { + return ServerQuery.prototype._awaitResult; + } +} + +module.exports = ServerQuery; diff --git a/lib/commands/server/quit.js b/lib/commands/server/quit.js new file mode 100644 index 0000000000..4173eb537a --- /dev/null +++ b/lib/commands/server/quit.js @@ -0,0 +1,23 @@ +'use strict'; + +const Command = require('../command.js'); + +class ServerQuit extends Command { + constructor(handler) { + super(); + this._handler = handler; + } + + start(packet, connection) { + packet.readInt8(); + Promise.resolve() + .then(() => this._handler()) + .catch(() => {}) + .then(() => { + connection.stream.end(); + }); + return null; + } +} + +module.exports = ServerQuit; diff --git a/lib/commands/server/send_result.js b/lib/commands/server/send_result.js new file mode 100644 index 0000000000..dba6e750a5 --- /dev/null +++ b/lib/commands/server/send_result.js @@ -0,0 +1,71 @@ +'use strict'; + +const columnDefaults = { + catalog: 'def', + schema: '', + table: '', + orgTable: '', + orgName: '', + characterSet: 33, + columnLength: 255, + columnType: 253, + flags: 0, + decimals: 0, +}; + +function normalizeColumn(col) { + if (typeof col === 'string') { + return Object.assign({}, columnDefaults, { name: col, orgName: col }); + } + return Object.assign({}, columnDefaults, { orgName: col.name }, col); +} + +function inferColumns(row) { + return Object.keys(row).map((name) => + Object.assign({}, columnDefaults, { name, orgName: name }) + ); +} + +function sendResult(connection, result) { + if (result === undefined || result === null) { + connection.writeOk(); + connection.sequenceId = 0; + return; + } + + if (Array.isArray(result)) { + const columns = result.length > 0 ? inferColumns(result[0]) : []; + connection.writeTextResult(result, columns); + connection.sequenceId = 0; + return; + } + + if (typeof result === 'object') { + if (result.rows !== undefined && result.columns !== undefined) { + connection.writeTextResult( + result.rows, + result.columns.map(normalizeColumn) + ); + connection.sequenceId = 0; + return; + } + if (result.affectedRows !== undefined) { + connection.writeOk(result); + connection.sequenceId = 0; + return; + } + } + + connection.writeOk(); + connection.sequenceId = 0; +} + +function sendError(connection, err) { + connection.writeError({ + message: err.message || String(err), + code: err.code || err.errno || 1149, + }); + connection.sequenceId = 0; +} + +module.exports = { sendResult, sendError }; diff --git a/lib/commands/server_handshake.js b/lib/commands/server_handshake.js index 4d4b4c80ee..cb811b2844 100644 --- a/lib/commands/server_handshake.js +++ b/lib/commands/server_handshake.js @@ -28,19 +28,21 @@ class ServerHandshake extends Command { connection.emit('error', new Error('Error generating random bytes')); return; } - connection.writePacket(serverHelloPacket.toPacket(0)); + connection.writePacket(serverHelloPacket.toPacket(10)); }); return ServerHandshake.prototype.readClientReply; } readClientReply(packet, connection) { - // check auth here const clientHelloReply = Packets.HandshakeResponse.fromPacket( packet, this.args.capabilityFlags ); - // TODO check we don't have something similar already connection.clientHelloReply = clientHelloReply; + const yieldToHandleCommand = !!connection.config.serverOptions?.handleCommand; + const nextState = yieldToHandleCommand + ? null + : ServerHandshake.prototype.dispatchCommands; if (this.args.authCallback) { this.args.authCallback( { @@ -52,12 +54,10 @@ class ServerHandshake extends Command { authToken: clientHelloReply.authToken, }, (err, mysqlError) => { - // if (err) if (!mysqlError) { connection.writeOk(); + connection.sequenceId = 0; } else { - // TODO create constants / errorToCode - // 1045 = ER_ACCESS_DENIED_ERROR connection.writeError({ message: mysqlError.message || '', code: mysqlError.code || 1045, @@ -68,8 +68,11 @@ class ServerHandshake extends Command { ); } else { connection.writeOk(); + if (yieldToHandleCommand) { + connection.sequenceId = 0; + } } - return ServerHandshake.prototype.dispatchCommands; + return nextState; } _isStatement(query, name) { diff --git a/lib/connection_config.js b/lib/connection_config.js index c1a238fa60..d45ccac76b 100644 --- a/lib/connection_config.js +++ b/lib/connection_config.js @@ -70,6 +70,7 @@ const validOptions = { waitForConnections: 1, jsonStrings: 1, gracefulEnd: 1, + serverOptions: 1, }; class ConnectionConfig { @@ -94,6 +95,7 @@ class ConnectionConfig { } } this.isServer = options.isServer; + this.serverOptions = options.serverOptions; this.stream = options.stream; this.host = options.host || 'localhost'; this.port = diff --git a/lib/packets/handshake.js b/lib/packets/handshake.js index ee894a0d3a..b48ee42f08 100644 --- a/lib/packets/handshake.js +++ b/lib/packets/handshake.js @@ -1,10 +1,22 @@ 'use strict'; +const crypto = require('crypto'); const Packet = require('../packets/packet'); const ClientConstants = require('../constants/client.js'); // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake +const getSalt = () => + new Promise((accept, reject) => { + crypto.randomBytes(20, (err, data) => { + if (err) { + reject(err); + return; + } + accept(data); + }); + }); + class Handshake { constructor(args) { this.protocolVersion = args.protocolVersion; @@ -15,25 +27,22 @@ class Handshake { this.authPluginData2 = args.authPluginData2; this.characterSet = args.characterSet; this.statusFlags = args.statusFlags; - this.authPluginName = args.authPluginName; + this.authPluginName = args.authPluginName || 'mysql_native_password'; + this.getSalt = args.getSalt || getSalt; } setScrambleData(cb) { - require('crypto').randomBytes(20, (err, data) => { - if (err) { - cb(err); - return; - } - this.authPluginData1 = data.slice(0, 8); - this.authPluginData2 = data.slice(8, 20); + this.getSalt().then((salt) => { + this.authPluginData1 = salt.slice(0, 8); + this.authPluginData2 = salt.slice(8, 20); cb(); }); } - toPacket(sequenceId) { + toPacket() { const length = 68 + Buffer.byteLength(this.serverVersion, 'utf8'); const buffer = Buffer.alloc(length + 4, 0); // zero fill, 10 bytes filler later needs to contain zeros - const packet = new Packet(sequenceId, buffer, 0, length + 4); + const packet = new Packet(0, buffer, 0, length + 4); packet.offset = 4; packet.writeInt8(this.protocolVersion); packet.writeString(this.serverVersion, 'cesu8'); @@ -51,7 +60,7 @@ class Handshake { packet.skip(10); packet.writeBuffer(this.authPluginData2); packet.writeInt8(0); - packet.writeString('mysql_native_password', 'latin1'); + packet.writeString(this.authPluginName, 'latin1'); packet.writeInt8(0); return packet; } diff --git a/lib/packets/handshake_response.js b/lib/packets/handshake_response.js index f1c1680274..469451c717 100644 --- a/lib/packets/handshake_response.js +++ b/lib/packets/handshake_response.js @@ -4,59 +4,14 @@ const ClientConstants = require('../constants/client.js'); const CharsetToEncoding = require('../constants/charset_encodings.js'); const Packet = require('../packets/packet.js'); -const auth41 = require('../auth_41.js'); - class HandshakeResponse { constructor(handshake) { this.user = handshake.user || ''; this.database = handshake.database || ''; - this.password = handshake.password || ''; - this.passwordSha1 = handshake.passwordSha1; - this.authPluginData1 = handshake.authPluginData1; - this.authPluginData2 = handshake.authPluginData2; this.compress = handshake.compress; this.clientFlags = handshake.flags; - - // Accept pre-calculated authToken and authPluginName from caller - // This allows the caller to optimize by using the server's preferred auth method - if ( - handshake.authToken !== undefined && - handshake.authPluginName !== undefined - ) { - // Validate types to fail fast with clear errors - if (!Buffer.isBuffer(handshake.authToken)) { - throw new TypeError( - 'HandshakeResponse authToken must be a Buffer when provided' - ); - } - if (typeof handshake.authPluginName !== 'string') { - throw new TypeError( - 'HandshakeResponse authPluginName must be a string when provided' - ); - } - this.authToken = handshake.authToken; - this.authPluginName = handshake.authPluginName; - } else { - // Fallback to legacy behavior: calculate mysql_native_password token - // TODO: pre-4.1 auth support - let authToken; - if (this.passwordSha1) { - authToken = auth41.calculateTokenFromPasswordSha( - this.passwordSha1, - this.authPluginData1, - this.authPluginData2 - ); - } else { - authToken = auth41.calculateToken( - this.password, - this.authPluginData1, - this.authPluginData2 - ); - } - this.authToken = authToken; - this.authPluginName = 'mysql_native_password'; - } - + this.authToken = handshake.authToken || Buffer.alloc(0); + this.authPluginName = handshake.authPluginName; this.charsetNumber = handshake.charsetNumber; this.encoding = CharsetToEncoding[handshake.charsetNumber]; this.connectAttributes = handshake.connectAttributes; diff --git a/lib/packets/query.js b/lib/packets/query.js index 1523b58b15..22967196a1 100644 --- a/lib/packets/query.js +++ b/lib/packets/query.js @@ -97,6 +97,15 @@ class Query { const p = this.serializeToBuffer(Packet.MockBuffer()); return this.serializeToBuffer(Buffer.allocUnsafe(p.offset)); } + + static fromPacket(packet, encoding) { + const _commandCode = packet.readInt8(); + if (_commandCode !== CommandCode.QUERY) { + throw new Error('Incorrect command code for Query packet'); + } + const query = packet.readString(undefined, encoding); + return new Query(query); + } } module.exports = Query; diff --git a/lib/server.js b/lib/server.js index e0633e8676..5b6afca493 100644 --- a/lib/server.js +++ b/lib/server.js @@ -8,8 +8,9 @@ const ConnectionConfig = require('./connection_config'); // TODO: inherit Server from net.Server class Server extends EventEmitter { - constructor() { + constructor(options) { super(); + this._options = options; this.connections = []; this._server = net.createServer(this._handleConnection.bind(this)); } @@ -18,6 +19,7 @@ class Server extends EventEmitter { const connectionConfig = new ConnectionConfig({ stream: socket, isServer: true, + serverOptions: this._options, }); const connection = new Connection({ config: connectionConfig }); this.emit('connection', connection); diff --git a/test/integration/test-server-api.test.mts b/test/integration/test-server-api.test.mts new file mode 100644 index 0000000000..f5a080cb39 --- /dev/null +++ b/test/integration/test-server-api.test.mts @@ -0,0 +1,205 @@ +import { describe, it, strict } from 'poku'; +import mysql from '../../index.js'; + +type Connection = ReturnType; + +function createTestServer(opts: any): Promise<{ server: any; port: number }> { + return new Promise((resolve) => { + const server = mysql.createServer(opts); + // @ts-expect-error: internal access + server.listen(0, () => { + // @ts-expect-error: internal access + const port = server._server.address().port; + resolve({ server, port }); + }); + }); +} + +function connectAndQuery( + port: number, + sql: string, + opts: Record = {} +): Promise<{ rows?: any; error?: any }> { + return new Promise((resolve) => { + const conn: Connection = mysql.createConnection({ + host: '127.0.0.1', + port, + user: opts.user || 'test', + database: 'test', + }); + conn.on('error', (err: any) => { + resolve({ error: err }); + }); + conn.query(sql, (err: any, rows: any) => { + conn.end(() => {}); + if (err) return resolve({ error: err }); + resolve({ rows }); + }); + }); +} + +await describe('Server API - static handlers', async () => { + await it('should handle query returning array of rows', async () => { + const { server, port } = await createTestServer({ + query() { + return [{ id: 1, name: 'hello' }]; + }, + }); + const { rows } = await connectAndQuery(port, 'SELECT 1'); + server.close(() => {}); + strict.ok(Array.isArray(rows)); + strict.equal(rows.length, 1); + strict.equal(rows[0].id, '1'); + strict.equal(rows[0].name, 'hello'); + }); + + await it('should handle query returning { rows, columns }', async () => { + const { server, port } = await createTestServer({ + query() { + return { + rows: [{ val: 42 }], + columns: [{ name: 'val' }], + }; + }, + }); + const { rows } = await connectAndQuery(port, 'SELECT val'); + server.close(() => {}); + strict.equal(rows[0].val, '42'); + }); + + await it('should handle query returning affectedRows', async () => { + const { server, port } = await createTestServer({ + query() { + return { affectedRows: 3, insertId: 10 }; + }, + }); + const { rows: result } = await connectAndQuery(port, 'INSERT INTO t'); + server.close(() => {}); + strict.equal(result.affectedRows, 3); + strict.equal(result.insertId, 10); + }); + + await it('should handle query throwing error', async () => { + const { server, port } = await createTestServer({ + query() { + throw new Error('Something went wrong'); + }, + }); + const { error } = await connectAndQuery(port, 'SELECT 1'); + server.close(() => {}); + strict.ok(error); + strict.ok(error.message.includes('Something went wrong')); + }); + + await it('should handle ping', async () => { + let pingCalled = false; + const { server, port } = await createTestServer({ + query() { + return []; + }, + ping() { + pingCalled = true; + }, + }); + await new Promise((resolve) => { + const conn: Connection = mysql.createConnection({ + host: '127.0.0.1', + port, + user: 'test', + database: 'test', + }); + conn.ping(() => { + conn.end(() => server.close(() => resolve())); + }); + }); + strict.ok(pingCalled, 'ping handler should have been called'); + }); +}); + +await describe('Server API - factory function', async () => { + await it('should support per-connection state via factory', async () => { + let connectionCount = 0; + const { server, port } = await createTestServer((_conn: any) => { + const myId = ++connectionCount; + return { + query(sql: string) { + return [{ connId: myId, sql }]; + }, + }; + }); + const { rows } = await connectAndQuery(port, 'hello world'); + server.close(() => {}); + strict.equal(rows[0].connId, '1'); + strict.equal(rows[0].sql, 'hello world'); + }); +}); + +await describe('Server API - auth handler', async () => { + await it('should reject connection when auth throws', async () => { + const { server, port } = await createTestServer({ + auth({ user }: { user: string }) { + if (user !== 'admin') throw new Error('Access denied'); + }, + query() { + return []; + }, + }); + const { error } = await connectAndQuery(port, 'SELECT 1'); + server.close(() => {}); + strict.ok(error); + strict.ok(error.message.includes('Access denied')); + }); + + await it('should accept connection when auth resolves', async () => { + const { server, port } = await createTestServer({ + auth({ user }: { user: string }) { + if (user !== 'admin') throw new Error('Access denied'); + }, + query() { + return [{ ok: 1 }]; + }, + }); + const { rows } = await connectAndQuery(port, 'SELECT 1', { + user: 'admin', + }); + server.close(() => {}); + strict.equal(rows[0].ok, '1'); + }); +}); + +await describe('Server API - backward compatibility', async () => { + await it('legacy createServer(handler) + serverHandshake', async () => { + const { server, port } = await createTestServer((conn: any) => { + conn.serverHandshake({ + protocolVersion: 10, + serverVersion: 'legacy-server', + connectionId: 1, + statusFlags: 2, + characterSet: 8, + capabilityFlags: 0xffffff, + }); + conn.on('query', () => { + conn.writeColumns([ + { + catalog: 'def', + schema: '', + table: '', + orgTable: '', + name: 'legacy', + orgName: 'legacy', + characterSet: 33, + columnLength: 255, + columnType: 253, + flags: 0, + decimals: 0, + }, + ]); + conn.writeTextRow(['yes']); + conn.writeEof(); + }); + }); + const { rows } = await connectAndQuery(port, 'SELECT 1'); + server.close(() => {}); + strict.equal(rows[0].legacy, 'yes'); + }); +}); diff --git a/test/unit/packets/test-handshake-response-auth-plugin.test.mts b/test/unit/packets/test-handshake-response-auth-plugin.test.mts index 64d3935122..53db752dc4 100644 --- a/test/unit/packets/test-handshake-response-auth-plugin.test.mts +++ b/test/unit/packets/test-handshake-response-auth-plugin.test.mts @@ -63,52 +63,11 @@ await describe('HandshakeResponse with auth plugin name', async () => { strict.equal(response.authToken, customToken); }); - await it('should fallback to mysql_native_password when not specified', () => { + await it('should default to empty authToken when not specified', () => { const response = new HandshakeResponse(baseConfig); - strict.equal(response.authPluginName, 'mysql_native_password'); strict.ok(Buffer.isBuffer(response.authToken)); - }); - - await it('should throw TypeError for non-Buffer authToken', () => { - let errorThrown = false; - try { - new HandshakeResponse({ - ...baseConfig, - authToken: 'not a buffer' as unknown as Buffer, - authPluginName: 'caching_sha2_password', - }); - } catch (err: unknown) { - if ( - err instanceof TypeError && - err.message.includes('must be a Buffer') - ) { - errorThrown = true; - } - } - strict.ok(errorThrown, 'Should throw TypeError for non-Buffer authToken'); - }); - - await it('should throw TypeError for non-string authPluginName', () => { - let errorThrown = false; - try { - new HandshakeResponse({ - ...baseConfig, - authToken: Buffer.alloc(32), - authPluginName: 12345 as unknown as string, - }); - } catch (err: unknown) { - if ( - err instanceof TypeError && - err.message.includes('must be a string') - ) { - errorThrown = true; - } - } - strict.ok( - errorThrown, - 'Should throw TypeError for non-string authPluginName' - ); + strict.equal(response.authToken.length, 0); }); await it('should handle empty password', () => { diff --git a/test/unit/packets/test-handshake-response-server-flags.test.mts b/test/unit/packets/test-handshake-response-server-flags.test.mts index 12ac37ba4b..9047d208f9 100644 --- a/test/unit/packets/test-handshake-response-server-flags.test.mts +++ b/test/unit/packets/test-handshake-response-server-flags.test.mts @@ -14,11 +14,10 @@ const allFlags = const baseConfig = { user: 'testuser', database: 'testdb', - password: 'testpass', flags: allFlags, charsetNumber: 255, - authPluginData1: Buffer.alloc(8), - authPluginData2: Buffer.alloc(12), + authToken: Buffer.alloc(20), + authPluginName: 'mysql_native_password', connectAttributes: { _client_name: 'test', _pid: '1234' }, }; diff --git a/typings/mysql/index.d.ts b/typings/mysql/index.d.ts index 51248ded37..ad684db450 100644 --- a/typings/mysql/index.d.ts +++ b/typings/mysql/index.d.ts @@ -20,7 +20,14 @@ import { Prepare as BasePrepare, PrepareStatementInfo, } from './lib/protocol/sequences/Prepare.js'; -import { Server } from './lib/Server.js'; +import { + Server, + ServerOptions, + ServerHandlers, + ServerFactory, + ServerResult, + AuthParams, +} from './lib/Server.js'; import { escape as SqlStringEscape, escapeId as SqlStringEscapeId, @@ -40,6 +47,11 @@ export { ExecuteValues, QueryValues, PrepareStatementInfo, + ServerOptions, + ServerHandlers, + ServerFactory, + ServerResult, + AuthParams, }; export * from './lib/protocol/packets/index.js'; @@ -83,7 +95,12 @@ export interface ConnectionConfig extends ConnectionOptions { }; } -export function createServer(handler: (conn: BaseConnection) => any): Server; +export function createServer( + handler: (conn: BaseConnection) => any +): Server; +export function createServer(handlers: ServerHandlers): Server; +export function createServer(factory: ServerFactory): Server; +export function createServer(options: ServerOptions): Server; export type { QueryTraceContext, diff --git a/typings/mysql/lib/Connection.d.ts b/typings/mysql/lib/Connection.d.ts index d73d9e2613..f7f51fc68f 100644 --- a/typings/mysql/lib/Connection.d.ts +++ b/typings/mysql/lib/Connection.d.ts @@ -298,6 +298,8 @@ export interface ConnectionOptions { isServer?: boolean; + serverOptions?: import('./Server.js').ServerOptions; + maxPreparedStatements?: number; namedPlaceholders?: boolean; diff --git a/typings/mysql/lib/Server.d.ts b/typings/mysql/lib/Server.d.ts index 195adeed9c..b67b9ca3e4 100644 --- a/typings/mysql/lib/Server.d.ts +++ b/typings/mysql/lib/Server.d.ts @@ -1,6 +1,40 @@ import { EventEmitter } from 'events'; import { Connection } from './Connection.js'; +export type ServerResult = + | Array> + | { rows: Array>; columns: Array<{ name: string }> } + | { affectedRows: number; insertId?: number } + | void; + +export interface AuthParams { + user: string; + database: string; + address: string; + authPluginData1: Buffer; + authPluginData2: Buffer; + authToken: Buffer; +} + +export interface ServerHandlers { + auth?(params: AuthParams): void | Promise; + query?(sql: string): ServerResult | Promise; + ping?(): void | Promise; + quit?(): void | Promise; + init_db?(schema: string): void | Promise; + handleCommand?(commandCode: number): any; + serverVersion?: string; + encoding?: string; +} + +export type ServerFactory = (connection: Connection) => ServerHandlers; + +export interface ServerOptions { + onConnection?: (conn: Connection) => void; + handleCommand?: (commandCode: number) => any; + encoding?: string; +} + declare class Server extends EventEmitter { connections: Array; diff --git a/website/docs/documentation/mysql-server.mdx b/website/docs/documentation/mysql-server.mdx index 3a72d7831e..86a3e9c3f5 100644 --- a/website/docs/documentation/mysql-server.mdx +++ b/website/docs/documentation/mysql-server.mdx @@ -1,53 +1,278 @@ # MySQL Server API -## Server - -- `createServer()` - creates server instance -- `Server.listen` - listen port / unix socket (same arguments as [net.Server.listen](https://nodejs.org/api/net.html#net_server_listen_port_host_backlog_callback)) - -### Events - -- **connect** - - new incoming connection. - -
- -## Connection - -- `serverHandshake({ serverVersion, protocolVersion, connectionId, statusFlags, characterSet, capabilityFlags })` - - send server handshake initialisation packet, wait handshake response and start listening for commands - - `capabilityFlags` controls which protocol features are advertised. Client flags are masked against these capabilities when parsing the handshake response, so the server only honors mutually-supported features. -- `writeOk({ affectedRows: num, insertId: num })` - - send [OK packet](https://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet) to client -- `writeEof(warnings, statusFlags)` - - send EOF packet -- `writeTextResult(rows, fields)` - - write query result to client. Rows and fields are in the same format as in `connection.query` callback. -- `writeColumns(fields)` - - write fields + EOF packets. -- `writeTextRow(row)` - - write array (not hash!) of values as result row -- **TODO:** binary protocol - -### Events - -Every command packet received by the server will be emitted as a **packet** event with the parameters: - -- **packet:** Packet - - The packet itself -- **knownCommand:** boolean - - is this command known to the server -- **\*commandCode:** number - - the parsed command code (first byte) - -In addition special events are emitted for [commands](https://dev.mysql.com/doc/internals/en/text-protocol.html) received from the client. If no listener is present a fallback behavior will be invoked. - -- `quit()` - - Default: close the connection -- `init_db(schemaName: string)` - - Default: return OK -- `query(sql: string)` - - Please attach a listener to this. Default: return HA_ERR_INTERNAL_ERROR -- `field_list(table: string, fields: string)` - - Default: return ER_WARN_DEPRECATED_SYNTAX -- `ping()` - Default: return OK +mysql2 can act as a MySQL server, handling connections from standard MySQL clients. There are three API styles depending on your needs. + +## Quick Start + +The simplest way to create a server — define handler functions for the commands you want to support: + +```js +const mysql = require('mysql2'); + +const server = mysql.createServer({ + async query(sql) { + if (sql === 'SELECT 1') { + return [{ '1': 1 }]; + } + return { affectedRows: 0 }; + }, +}); + +server.listen(3307); +``` + +Connect with any MySQL client: + +```bash +mysql -h 127.0.0.1 -P 3307 -u any_user -e "SELECT 1" +``` + +## Returning Results + +Handler return values are automatically serialized into MySQL protocol packets: + +### Rows (array of objects) + +Column definitions are inferred from the first row's keys: + +```js +createServer({ + query(sql) { + return [ + { id: 1, name: 'Alice' }, + { id: 2, name: 'Bob' }, + ]; + }, +}); +``` + +### Explicit columns + +For full control over column metadata, return `{ rows, columns }`: + +```js +createServer({ + query(sql) { + return { + rows: [{ total: 42 }], + columns: [{ name: 'total', columnType: 8 /* LONGLONG */ }], + }; + }, +}); +``` + +### OK packet (INSERT/UPDATE/DELETE) + +Return an object with `affectedRows`: + +```js +createServer({ + query(sql) { + return { affectedRows: 3, insertId: 10 }; + }, +}); +``` + +### Void (implicit OK) + +Returning nothing sends an OK packet with zero affected rows: + +```js +createServer({ + ping() { + // no return needed — OK is sent automatically + }, +}); +``` + +### Errors + +Throw an error to send an error packet to the client: + +```js +createServer({ + query(sql) { + throw new Error('Query not supported'); + }, +}); +``` + +## Handlers + +All handlers are optional. Unhandled commands get a sensible default response. + +| Handler | Arguments | Default behavior | +|---------|-----------|-----------------| +| `query(sql)` | SQL string | Error: "Command not supported" | +| `ping()` | (none) | OK | +| `quit()` | (none) | Close connection | +| `init_db(schema)` | Schema name | OK | +| `auth(params)` | `{ user, database, address, authToken, ... }` | Accept all | + +Handlers can be synchronous or `async`: + +```js +createServer({ + async query(sql) { + const rows = await database.query(sql); + return rows; + }, +}); +``` + +## Authentication + +Provide an `auth` handler to control access. Throw to reject, return to accept: + +```js +createServer({ + auth({ user, authToken }) { + if (user !== 'admin') { + throw new Error('Access denied'); + } + }, + query(sql) { + return []; + }, +}); +``` + +The `auth` handler receives: +- `user` — client username +- `database` — requested database +- `address` — client IP address +- `authPluginData1`, `authPluginData2` — server challenge data +- `authToken` — client's authentication response + +## Per-Connection State (Factory Function) + +For servers that need per-connection state (upstream connections, session data), pass a factory function instead of a static object: + +```js +const server = mysql.createServer((connection) => { + // This runs once per client connection. + // Variables here are scoped to this connection. + const upstream = mysql.createConnection({ host: 'real-mysql', user: 'root' }); + + return { + auth({ user }) { + if (user !== 'proxy_user') throw new Error('Denied'); + }, + async query(sql) { + const [rows] = await upstream.promise().query(sql); + return rows; + }, + quit() { + upstream.end(); + }, + }; +}); +``` + +The factory receives the raw `connection` object, available via closure in all handlers — but you rarely need it directly since return values handle serialization. + +## Low-Level API + +For full protocol control (custom commands, streaming, binary protocol), use the `handleCommand` option: + +```js +const Command = require('mysql2/lib/commands/command'); +const CommandCode = require('mysql2/lib/constants/commands'); +const Packets = require('mysql2/lib/packets'); + +const server = mysql.createServer({ + onConnection(conn) { + conn.serverHandshake({ + protocolVersion: 10, + serverVersion: '8.0.0', + connectionId: 1, + statusFlags: 2, + characterSet: 8, + capabilityFlags: 0xffffff, + }); + }, + handleCommand(commandCode) { + switch (commandCode) { + case CommandCode.QUERY: + return new MyQueryCommand(); + case CommandCode.PING: + return new MyPingCommand(); + default: + return new DefaultErrorCommand(); + } + }, +}); +``` + +Each command is a `Command` subclass — a packet-driven state machine: + +```js +class MyQueryCommand extends Command { + start(packet, connection) { + const encoding = connection.clientHelloReply?.encoding || 'utf8'; + const query = Packets.Query.fromPacket(packet, encoding); + connection.writeTextResult( + [{ result: 'hello' }], + [{ name: 'result', catalog: 'def', schema: '', table: '', orgTable: '', + orgName: 'result', characterSet: 33, columnLength: 255, + columnType: 253, flags: 0, decimals: 0 }] + ); + connection.sequenceId = 0; + return null; // command done + } +} +``` + +## Legacy API + +The original event-based API is still fully supported: + +```js +const server = mysql.createServer((conn) => { + conn.serverHandshake({ + protocolVersion: 10, + serverVersion: '5.7.0', + connectionId: 1234, + statusFlags: 2, + characterSet: 8, + capabilityFlags: 0xffffff, + }); + + conn.on('query', (sql) => { + conn.writeTextResult(rows, columns); + }); + + conn.on('ping', () => { + conn.writeOk(); + }); +}); +``` + +### Server write methods + +- `writeOk({ affectedRows, insertId })` — send OK packet +- `writeError({ message, code })` — send Error packet +- `writeEof(warnings, statusFlags)` — send EOF packet +- `writeTextResult(rows, fields)` — write complete query result +- `writeColumns(fields)` — write column definitions + EOF +- `writeTextRow(row)` — write a single row (array of values) +- `writePacket(packet)` — write a raw packet + +## Server Options + +- `Server.listen(port)` — listen on a TCP port (same args as `net.Server.listen`) +- `Server.close(callback)` — close the server + +### `createServer` options + +| Option | Description | +|--------|-------------| +| `query(sql)` | Handle COM_QUERY | +| `ping()` | Handle COM_PING | +| `quit()` | Handle COM_QUIT | +| `init_db(schema)` | Handle COM_INIT_DB | +| `auth(params)` | Authentication handler | +| `handleCommand(code)` | Low-level command factory | +| `onConnection(conn)` | Raw connection handler | +| `encoding` | Server encoding (default: `'cesu8'`) | +| `serverVersion` | Version string sent to clients (default: `'mysql2-server'`) | diff --git a/website/docs/examples/tests/server.mdx b/website/docs/examples/tests/server.mdx index 5c96039b94..651dd9dd4b 100644 --- a/website/docs/examples/tests/server.mdx +++ b/website/docs/examples/tests/server.mdx @@ -4,17 +4,102 @@ import TabItem from '@theme/TabItem'; # Server - + + +```js +'use strict'; + +const mysql = require('mysql2'); + +const server = mysql.createServer({ + query(sql) { + return [ + { greeting: 'Hello from mysql2 server!', query: sql }, + ]; + }, +}); + +server.listen(3333); +console.log('MySQL server listening on port 3333'); +``` + + + + +```js +'use strict'; + +const mysql = require('mysql2'); + +const server = mysql.createServer((connection) => { + const upstream = mysql.createConnection({ + host: 'localhost', + port: 3306, + user: 'root', + }); + + return { + async query(sql) { + const [rows, fields] = await upstream.promise().query(sql); + return { rows, columns: fields }; + }, + quit() { + upstream.end(); + }, + }; +}); + +server.listen(3334); +console.log('MySQL proxy listening on port 3334'); +``` + + + + +```js +'use strict'; + +const mysql = require('mysql2'); +const auth = require('mysql2/lib/auth_41.js'); + +const VALID_USER = 'admin'; +const PASSWORD_DOUBLE_SHA = auth.doubleSha1('secret123'); + +const server = mysql.createServer({ + auth(params) { + if (params.user !== VALID_USER) { + throw new Error(`Unknown user: ${params.user}`); + } + const isValid = auth.verifyToken( + params.authPluginData1, + params.authPluginData2, + params.authToken, + PASSWORD_DOUBLE_SHA, + ); + if (!isValid) { + throw new Error('Wrong password'); + } + }, + query(sql) { + return [{ message: `Hello ${VALID_USER}, you ran: ${sql}` }]; + }, +}); + +server.listen(3335); +console.log('Authenticated MySQL server on port 3335'); +// Connect: mysql -h 127.0.0.1 -P 3335 -u admin -psecret123 +``` + + + ```js 'use strict'; const mysql = require('mysql2'); -const flags = require('mysql2/lib/constants/client.js'); const auth = require('mysql2/lib/auth_41.js'); function authenticate(params, cb) { - console.log(params); const doubleSha = auth.doubleSha1('pass123'); const isValid = auth.verifyToken( params.authPluginData1, @@ -25,34 +110,23 @@ function authenticate(params, cb) { if (isValid) { cb(null); } else { - // for list of codes lib/constants/errors.js - cb(null, { message: 'wrong password dude', code: 1045 }); + cb(null, { message: 'wrong password', code: 1045 }); } } const server = mysql.createServer(); server.listen(3333); server.on('connection', (conn) => { - // we can deny connection here: - // conn.writeError({ message: 'secret', code: 123 }); - // conn.close(); conn.serverHandshake({ protocolVersion: 10, - serverVersion: '5.6.10', // 'node.js rocks', + serverVersion: '5.6.10', connectionId: 1234, statusFlags: 2, characterSet: 8, - // capabilityFlags: 0xffffff, - // capabilityFlags: -2113931265, capabilityFlags: 2181036031, authCallback: authenticate, }); - conn.on('field_list', (table, fields) => { - console.log('FIELD LIST:', table, fields); - conn.writeEof(); - }); - conn.on('query', (query) => { conn.writeColumns([ { @@ -69,10 +143,8 @@ server.on('connection', (conn) => { decimals: 0, }, ]); - conn.writeTextRow(['test тест テスト փորձարկում পরীক্ষা kiểm tra ']); - conn.writeTextRow(['ტესტი પરીક્ષણ מבחן פּרובירן اختبار परीक्षण']); + conn.writeTextRow(['test row data']); conn.writeEof(); - conn.close(); }); }); ```