diff --git a/SQLite.xcodeproj/project.pbxproj b/SQLite.xcodeproj/project.pbxproj index 30a40214..034a39bf 100644 --- a/SQLite.xcodeproj/project.pbxproj +++ b/SQLite.xcodeproj/project.pbxproj @@ -81,6 +81,9 @@ 19A17FB80B94E882050AA908 /* FoundationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A1794CC4D7827E997E32A7 /* FoundationTests.swift */; }; 19A17FDA323BAFDEC627E76F /* fixtures in Resources */ = {isa = PBXBuildFile; fileRef = 19A17E2695737FAB5D6086E3 /* fixtures */; }; 19A17FF4A10B44D3937C8CAC /* Errors.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A1710E73A46D5AC721CDA9 /* Errors.swift */; }; + 3717F908221F5D8800B9BD3D /* CustomAggregationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */; }; + 3717F909221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */; }; + 3717F90A221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */; }; 3D67B3E61DB2469200A4F4C6 /* libsqlite3.tbd in Frameworks */ = {isa = PBXBuildFile; fileRef = 3D67B3E51DB2469200A4F4C6 /* libsqlite3.tbd */; }; 3D67B3E71DB246BA00A4F4C6 /* Blob.swift in Sources */ = {isa = PBXBuildFile; fileRef = EE247AEE1C3F06E900AE3E12 /* Blob.swift */; }; 3D67B3E81DB246BA00A4F4C6 /* Connection.swift in Sources */ = {isa = PBXBuildFile; fileRef = EE247AEF1C3F06E900AE3E12 /* Connection.swift */; }; @@ -225,6 +228,7 @@ 19A17B93B48B5560E6E51791 /* Fixtures.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Fixtures.swift; sourceTree = ""; }; 19A17BA55DABB480F9020C8A /* DateAndTimeFunctions.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DateAndTimeFunctions.swift; sourceTree = ""; }; 19A17E2695737FAB5D6086E3 /* fixtures */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = folder; path = fixtures; sourceTree = ""; }; + 3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CustomAggregationTests.swift; sourceTree = ""; }; 3D67B3E51DB2469200A4F4C6 /* libsqlite3.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = libsqlite3.tbd; path = Platforms/WatchOS.platform/Developer/SDKs/WatchOS3.0.sdk/usr/lib/libsqlite3.tbd; sourceTree = DEVELOPER_DIR; }; 49EB68C31F7B3CB400D89D40 /* Coding.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Coding.swift; sourceTree = ""; }; A121AC451CA35C79005A31D1 /* SQLite.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = SQLite.framework; sourceTree = BUILT_PRODUCTS_DIR; }; @@ -401,6 +405,7 @@ EE247B1D1C3F137700AE3E12 /* ConnectionTests.swift */, EE247B1E1C3F137700AE3E12 /* CoreFunctionsTests.swift */, EE247B1F1C3F137700AE3E12 /* CustomFunctionsTests.swift */, + 3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */, EE247B201C3F137700AE3E12 /* ExpressionTests.swift */, EE247B211C3F137700AE3E12 /* FTS4Tests.swift */, EE247B2A1C3F141E00AE3E12 /* OperatorsTests.swift */, @@ -835,6 +840,7 @@ 03A65E921C6BB3030062603F /* SetterTests.swift in Sources */, 03A65E891C6BB3030062603F /* ConnectionTests.swift in Sources */, 03A65E8A1C6BB3030062603F /* CoreFunctionsTests.swift in Sources */, + 3717F90A221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */, 03A65E931C6BB3030062603F /* StatementTests.swift in Sources */, 03A65E911C6BB3030062603F /* SchemaTests.swift in Sources */, 03A65E8D1C6BB3030062603F /* FTS4Tests.swift in Sources */, @@ -923,6 +929,7 @@ EE247B271C3F137700AE3E12 /* CustomFunctionsTests.swift in Sources */, EE247B341C3F142E00AE3E12 /* StatementTests.swift in Sources */, EE247B301C3F141E00AE3E12 /* RTreeTests.swift in Sources */, + 3717F908221F5D8800B9BD3D /* CustomAggregationTests.swift in Sources */, EE247B231C3F137700AE3E12 /* BlobTests.swift in Sources */, EE247B351C3F142E00AE3E12 /* ValueTests.swift in Sources */, EE247B2F1C3F141E00AE3E12 /* QueryTests.swift in Sources */, @@ -981,6 +988,7 @@ EE247B5F1C3F3FC700AE3E12 /* StatementTests.swift in Sources */, EE247B5C1C3F3FC700AE3E12 /* RTreeTests.swift in Sources */, EE247B571C3F3FC700AE3E12 /* CustomFunctionsTests.swift in Sources */, + 3717F909221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */, EE247B601C3F3FC700AE3E12 /* ValueTests.swift in Sources */, EE247B551C3F3FC700AE3E12 /* ConnectionTests.swift in Sources */, EE247B611C3F3FC700AE3E12 /* TestHelpers.swift in Sources */, diff --git a/Sources/SQLite/Core/Connection.swift b/Sources/SQLite/Core/Connection.swift index 1bbf7f73..349c18b6 100644 --- a/Sources/SQLite/Core/Connection.swift +++ b/Sources/SQLite/Core/Connection.swift @@ -597,8 +597,122 @@ public final class Connection { if functions[function] == nil { self.functions[function] = [:] } functions[function]?[argc] = box } + + /// Creates or redefines a custom SQL aggregate. + /// + /// - Parameters: + /// + /// - aggregate: The name of the aggregate to create or redefine. + /// + /// - argumentCount: The number of arguments that the aggregate takes. If + /// `nil`, the aggregate may take any number of arguments. + /// + /// Default: `nil` + /// + /// - deterministic: Whether or not the aggregate is deterministic (_i.e._ + /// the aggregate always returns the same result for a given input). + /// + /// Default: `false` + /// + /// - step: A block of code to run for each row of an aggregation group. + /// The block is called with an array of raw SQL values mapped to the + /// aggregate’s parameters, and an UnsafeMutablePointer to a state + /// variable. + /// + /// - final: A block of code to run after each row of an aggregation group + /// is processed. The block is called with an UnsafeMutablePointer to a + /// state variable, and should return a raw SQL value (or nil). + /// + /// - state: A block of code to run to produce a fresh state variable for + /// each aggregation group. The block should return an + /// UnsafeMutablePointer to the fresh state variable. + public func createAggregation( + _ aggregate: String, + argumentCount: UInt? = nil, + deterministic: Bool = false, + step: @escaping ([Binding?], UnsafeMutablePointer) -> (), + final: @escaping (UnsafeMutablePointer) -> Binding?, + state: @escaping () -> UnsafeMutablePointer) { + + + let argc = argumentCount.map { Int($0) } ?? -1 + let box : Aggregate = { (stepFlag: Int, context: OpaquePointer?, argc: Int32, argv: UnsafeMutablePointer?) in + let ptr = sqlite3_aggregate_context(context, 64)! // needs to be at least as large as uintptr_t; better way to do this? + let p = ptr.assumingMemoryBound(to: UnsafeMutableRawPointer.self) + if stepFlag > 0 { + let arguments: [Binding?] = (0..?) -> Void fileprivate typealias Function = @convention(block) (OpaquePointer?, Int32, UnsafeMutablePointer?) -> Void fileprivate var functions = [String: [Int: Function]]() + fileprivate var aggregations = [String: [Int: Aggregate]]() /// Defines a new collating sequence. /// diff --git a/Sources/SQLite/Typed/CustomFunctions.swift b/Sources/SQLite/Typed/CustomFunctions.swift index 8910a24b..4d3e69ee 100644 --- a/Sources/SQLite/Typed/CustomFunctions.swift +++ b/Sources/SQLite/Typed/CustomFunctions.swift @@ -133,4 +133,69 @@ public extension Connection { } } + // MARK: - + + public func createAggregation( + _ aggregate: String, + argumentCount: UInt? = nil, + deterministic: Bool = false, + initialValue: T, + reduce: @escaping (T, [Binding?]) -> T, + result: @escaping (T) -> Binding? + ) { + + let step: ([Binding?], UnsafeMutablePointer) -> () = { (bindings, ptr) in + let p = ptr.pointee.assumingMemoryBound(to: T.self) + let current = Unmanaged.fromOpaque(p).takeRetainedValue() + let next = reduce(current, bindings) + ptr.pointee = Unmanaged.passRetained(next).toOpaque() + } + + let final: (UnsafeMutablePointer) -> Binding? = { (ptr) in + let p = ptr.pointee.assumingMemoryBound(to: T.self) + let obj = Unmanaged.fromOpaque(p).takeRetainedValue() + let value = result(obj) + ptr.deallocate() + return value + } + + let state: () -> UnsafeMutablePointer = { + let p = UnsafeMutablePointer.allocate(capacity: 1) + p.pointee = Unmanaged.passRetained(initialValue).toOpaque() + return p + } + + createAggregation(aggregate, step: step, final: final, state: state) + } + + public func createAggregation( + _ aggregate: String, + argumentCount: UInt? = nil, + deterministic: Bool = false, + initialValue: T, + reduce: @escaping (T, [Binding?]) -> T, + result: @escaping (T) -> Binding? + ) { + + let step: ([Binding?], UnsafeMutablePointer) -> () = { (bindings, p) in + let current = p.pointee + let next = reduce(current, bindings) + p.pointee = next + } + + let final: (UnsafeMutablePointer) -> Binding? = { (p) in + let v = result(p.pointee) + p.deallocate() + return v + } + + let state: () -> UnsafeMutablePointer = { + let p = UnsafeMutablePointer.allocate(capacity: 1) + p.pointee = initialValue + return p + } + + createAggregation(aggregate, step: step, final: final, state: state) + } + } diff --git a/Tests/SQLiteTests/CustomAggregationTests.swift b/Tests/SQLiteTests/CustomAggregationTests.swift new file mode 100644 index 00000000..f8efc7e4 --- /dev/null +++ b/Tests/SQLiteTests/CustomAggregationTests.swift @@ -0,0 +1,155 @@ +import XCTest +import Foundation +import Dispatch +@testable import SQLite + +#if SQLITE_SWIFT_STANDALONE +import sqlite3 +#elseif SQLITE_SWIFT_SQLCIPHER +import SQLCipher +#elseif os(Linux) +import CSQLite +#else +import SQLite3 +#endif + +class CustomAggregationTests : SQLiteTestCase { + override func setUp() { + super.setUp() + CreateUsersTable() + try! InsertUser("Alice", age: 30, admin: true) + try! InsertUser("Bob", age: 25, admin: true) + try! InsertUser("Eve", age: 28, admin: false) + } + + func testUnsafeCustomSum() { + let step = { (bindings: [Binding?], state: UnsafeMutablePointer) in + if let v = bindings[0] as? Int64 { + state.pointee += v + } + } + + let final = { (state: UnsafeMutablePointer) -> Binding? in + let v = state.pointee + let p = UnsafeMutableBufferPointer(start: state, count: 1) + p.deallocate() + return v + } + let _ = db.createAggregation("mySUM1", step: step, final: final) { + let v = UnsafeMutableBufferPointer.allocate(capacity: 1) + v[0] = 0 + return v.baseAddress! + } + let result = try! db.prepare("SELECT mySUM1(age) AS s FROM users") + let i = result.columnNames.index(of: "s")! + for row in result { + let value = row[i] as? Int64 + XCTAssertEqual(83, value) + } + } + + func testUnsafeCustomSumGrouping() { + let step = { (bindings: [Binding?], state: UnsafeMutablePointer) in + if let v = bindings[0] as? Int64 { + state.pointee += v + } + } + let final = { (state: UnsafeMutablePointer) -> Binding? in + let v = state.pointee + let p = UnsafeMutableBufferPointer(start: state, count: 1) + p.deallocate() + return v + } + let _ = db.createAggregation("mySUM2", step: step, final: final) { + let v = UnsafeMutableBufferPointer.allocate(capacity: 1) + v[0] = 0 + return v.baseAddress! + } + let result = try! db.prepare("SELECT mySUM2(age) AS s FROM users GROUP BY admin ORDER BY s") + let i = result.columnNames.index(of: "s")! + let values = result.compactMap { $0[i] as? Int64 } + XCTAssertTrue(values.elementsEqual([28, 55])) + } + + func testCustomSum() { + let reduce : (Int64, [Binding?]) -> Int64 = { (last, bindings) in + let v = (bindings[0] as? Int64) ?? 0 + return last + v + } + let _ = db.createAggregation("myReduceSUM1", initialValue: Int64(2000), reduce: reduce, result: { $0 }) + let result = try! db.prepare("SELECT myReduceSUM1(age) AS s FROM users") + let i = result.columnNames.index(of: "s")! + for row in result { + let value = row[i] as? Int64 + XCTAssertEqual(2083, value) + } + } + + func testCustomSumGrouping() { + let reduce : (Int64, [Binding?]) -> Int64 = { (last, bindings) in + let v = (bindings[0] as? Int64) ?? 0 + return last + v + } + let _ = db.createAggregation("myReduceSUM2", initialValue: Int64(3000), reduce: reduce, result: { $0 }) + let result = try! db.prepare("SELECT myReduceSUM2(age) AS s FROM users GROUP BY admin ORDER BY s") + let i = result.columnNames.index(of: "s")! + let values = result.compactMap { $0[i] as? Int64 } + XCTAssertTrue(values.elementsEqual([3028, 3055])) + } + + func testCustomStringAgg() { + let initial = String(repeating: " ", count: 64) + let reduce : (String, [Binding?]) -> String = { (last, bindings) in + let v = (bindings[0] as? String) ?? "" + return last + v + } + let _ = db.createAggregation("myReduceSUM3", initialValue: initial, reduce: reduce, result: { $0 }) + let result = try! db.prepare("SELECT myReduceSUM3(email) AS s FROM users") + let i = result.columnNames.index(of: "s")! + for row in result { + let value = row[i] as? String + XCTAssertEqual("\(initial)Alice@example.comBob@example.comEve@example.com", value) + } + } + + func testCustomObjectSum() { + { + let initial = TestObject(value: 1000) + let reduce : (TestObject, [Binding?]) -> TestObject = { (last, bindings) in + let v = (bindings[0] as? Int64) ?? 0 + return TestObject(value: last.value + v) + } + let _ = db.createAggregation("myReduceSUMX", initialValue: initial, reduce: reduce, result: { $0.value }) + // end this scope to ensure that the initial value is retained + // by the createAggregation call. + }(); + { + XCTAssertEqual(TestObject.inits, 1) + let result = try! db.prepare("SELECT myReduceSUMX(age) AS s FROM users") + let i = result.columnNames.index(of: "s")! + for row in result { + let value = row[i] as? Int64 + XCTAssertEqual(1083, value) + } + }() + XCTAssertEqual(TestObject.inits, 4) + XCTAssertEqual(TestObject.deinits, 3) // the initial value is still retained by the aggregate's state block, so deinits is one less than inits + } +} + +/// This class is used to test that aggregation state variables +/// can be reference types and are properly memory managed when +/// crossing the Swift<->C boundary multiple times. +class TestObject { + static var inits = 0 + static var deinits = 0 + + var value: Int64 + init(value: Int64) { + self.value = value + TestObject.inits += 1 + } + deinit { + TestObject.deinits += 1 + } +}