Skip to content

Commit eb2bd26

Browse files
committed
Address PR review feedback: errmsg NULL guards, bind/2 error handling, deserialize database_name, interrupt_mutex ordering
- errmsg conn branch: add conn->db NULL guard, return {:error, :connection_closed} - errmsg stmt branch: return nil instead of {:error, :connection_closed} - bind/2: handle {:error, reason} from bind_parameter_count after release - deserialize: use database_name parameter instead of hardcoded "main" - close: hold interrupt_mutex across sqlite3_close_v2 + NULL assignment
1 parent f2603a8 commit eb2bd26

2 files changed

Lines changed: 36 additions & 19 deletions

File tree

c_src/sqlite3_nif.c

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -394,20 +394,23 @@ exqlite_close(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
394394
}
395395
}
396396

397+
// Hold interrupt_mutex across close+NULL so that any concurrent
398+
// exqlite_interrupt() either completes its sqlite3_interrupt() call
399+
// before we start closing, or blocks until we've both closed and
400+
// NULLed conn->db (then sees NULL and skips).
401+
//
397402
// note: _v2 may not fully close the connection, hence why we check if
398403
// any transaction is open above, to make sure other connections aren't blocked.
399404
// v1 is guaranteed to close or error, but will return error if any
400405
// unfinalized statements, which we likely have, as we rely on the destructors
401406
// to later run to clean those up
407+
enif_mutex_lock(conn->interrupt_mutex);
402408
rc = sqlite3_close_v2(conn->db);
403409
if (rc != SQLITE_OK) {
410+
enif_mutex_unlock(conn->interrupt_mutex);
404411
connection_release_lock(conn);
405412
return make_sqlite3_error_tuple(env, rc, conn->db);
406413
}
407-
408-
// Acquire interrupt_mutex so any concurrent exqlite_interrupt() finishes
409-
// before we NULL out conn->db, eliminating the TOCTOU / use-after-free.
410-
enif_mutex_lock(conn->interrupt_mutex);
411414
conn->db = NULL;
412415
enif_mutex_unlock(conn->interrupt_mutex);
413416

@@ -1123,7 +1126,7 @@ exqlite_deserialize(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
11231126
}
11241127

11251128
memcpy(buffer, serialized.data, size);
1126-
rc = sqlite3_deserialize(conn->db, "main", buffer, size, size, flags);
1129+
rc = sqlite3_deserialize(conn->db, (const char*)database_name.data, buffer, size, size, flags);
11271130
if (rc != SQLITE_OK) {
11281131
sqlite3_free(buffer);
11291132
connection_release_lock(conn);
@@ -1525,13 +1528,17 @@ exqlite_errmsg(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
15251528

15261529
if (enif_get_resource(env, argv[0], connection_type, (void**)&conn)) {
15271530
connection_acquire_lock(conn);
1531+
if (conn->db == NULL) {
1532+
connection_release_lock(conn);
1533+
return make_error_tuple(env, am_connection_closed);
1534+
}
15281535
msg = sqlite3_errmsg(conn->db);
15291536
connection_release_lock(conn);
15301537
} else if (enif_get_resource(env, argv[0], statement_type, (void**)&statement)) {
15311538
statement_acquire_lock(statement);
15321539
if (statement->statement == NULL) {
15331540
statement_release_lock(statement);
1534-
return make_error_tuple(env, am_connection_closed);
1541+
return am_nil;
15351542
}
15361543
msg = sqlite3_errmsg(sqlite3_db_handle(statement->statement));
15371544
statement_release_lock(statement);

lib/exqlite/sqlite3.ex

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -187,25 +187,35 @@ defmodule Exqlite.Sqlite3 do
187187
def bind(stmt, nil), do: bind(stmt, [])
188188

189189
def bind(stmt, args) when is_list(args) do
190-
params_count = bind_parameter_count(stmt)
191-
args_count = length(args)
190+
case bind_parameter_count(stmt) do
191+
{:error, reason} ->
192+
{:error, reason}
192193

193-
if args_count == params_count do
194-
bind_all(args, stmt, 1)
195-
else
196-
raise ArgumentError, "expected #{params_count} arguments, got #{args_count}"
194+
params_count ->
195+
args_count = length(args)
196+
197+
if args_count == params_count do
198+
bind_all(args, stmt, 1)
199+
else
200+
raise ArgumentError, "expected #{params_count} arguments, got #{args_count}"
201+
end
197202
end
198203
end
199204

200205
def bind(stmt, args) when is_map(args) do
201-
params_count = bind_parameter_count(stmt)
202-
args_count = map_size(args)
206+
case bind_parameter_count(stmt) do
207+
{:error, reason} ->
208+
{:error, reason}
203209

204-
if args_count == params_count do
205-
bind_all_named(Map.to_list(args), stmt)
206-
else
207-
raise ArgumentError,
208-
"expected #{params_count} named arguments, got #{args_count}: #{inspect(Map.keys(args))}"
210+
params_count ->
211+
args_count = map_size(args)
212+
213+
if args_count == params_count do
214+
bind_all_named(Map.to_list(args), stmt)
215+
else
216+
raise ArgumentError,
217+
"expected #{params_count} named arguments, got #{args_count}: #{inspect(Map.keys(args))}"
218+
end
209219
end
210220
end
211221

0 commit comments

Comments
 (0)