@@ -43,6 +43,7 @@ public enum SQLOrder {
4343/// case REAL (includes REAL, NUMERIC, DECIMAL, FLOAT, DOUBLE, DOUBLE_PRECISION)
4444/// case BLOB (includes BLOB, BINARY, VARBINARY)
4545/// case DATE (includes DATE, DATETIME, TIME, TIMESTAMP)
46+ public typealias SQLTableColums = [ ( name: String , type: SQLType ) ]
4647public typealias SQLValues = [ ( type: SQLType , value: Any ? ) ]
4748
4849public protocol SQLiteType {
@@ -377,9 +378,13 @@ open class SQLite: SQLiteType {
377378 var allRows : [ SQLValues ] = [ ]
378379 var rowValues : SQLValues = SQLValues ( [ ] )
379380
381+ guard let resultColumns = try ? getResultColumns ( table, sqlStatement: sqlStatement) else {
382+ throw SQLiteError . Column ( getErrorMessage ( dbPointer: dbPointer) )
383+ }
384+
380385 while sqlite3_step ( sqlStatement) == SQLITE_ROW {
381386 rowValues = SQLValues ( [ ] )
382- for (index, value) in table . columnTypes . enumerated ( ) {
387+ for (index, value) in resultColumns . enumerated ( ) {
383388
384389 let index = Int32 ( index) // column serial number, should start with 0
385390
@@ -438,6 +443,29 @@ open class SQLite: SQLiteType {
438443 return allRows
439444 }
440445
446+ private func getResultColumns( _ table: SQLTable , sqlStatement: OpaquePointer ? ) throws -> SQLTableColums {
447+ var columnNamesToReturn : [ String ] = [ ]
448+ let columnCount = sqlite3_column_count ( sqlStatement)
449+ for index in 0 ..< columnCount {
450+ if let columnName = sqlite3_column_name ( sqlStatement, index) {
451+ if let validatedColumnName = String ( validatingUTF8: columnName) {
452+ columnNamesToReturn. append ( validatedColumnName)
453+ } else {
454+ throw SQLiteError . Column ( getErrorMessage ( dbPointer: dbPointer) )
455+ }
456+ } else {
457+ throw SQLiteError . Column ( getErrorMessage ( dbPointer: dbPointer) )
458+ }
459+ }
460+ var resultColumns : SQLTableColums = [ ]
461+ for (index, column) in table. columns. enumerated ( ) {
462+ if columnNamesToReturn. contains ( column. name) {
463+ resultColumns. append ( ( column. name, table. columns [ index] . type) )
464+ }
465+ }
466+ return resultColumns
467+ }
468+
441469 public func getAllRows( from table: SQLTable ) throws -> [ SQLValues ] {
442470 let sql = " SELECT * FROM \( table. name) ; "
443471 let result = try getRow ( from: table, sql: sql)
0 commit comments