Skip to content

Add support for custom SQL aggregates #881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions SQLite.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@
19A17FB80B94E882050AA908 /* FoundationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A1794CC4D7827E997E32A7 /* FoundationTests.swift */; };
19A17FDA323BAFDEC627E76F /* fixtures in Resources */ = {isa = PBXBuildFile; fileRef = 19A17E2695737FAB5D6086E3 /* fixtures */; };
19A17FF4A10B44D3937C8CAC /* Errors.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A1710E73A46D5AC721CDA9 /* Errors.swift */; };
3717F908221F5D8800B9BD3D /* CustomAggregationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */; };
3717F909221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */; };
3717F90A221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */; };
3D67B3E61DB2469200A4F4C6 /* libsqlite3.tbd in Frameworks */ = {isa = PBXBuildFile; fileRef = 3D67B3E51DB2469200A4F4C6 /* libsqlite3.tbd */; };
3D67B3E71DB246BA00A4F4C6 /* Blob.swift in Sources */ = {isa = PBXBuildFile; fileRef = EE247AEE1C3F06E900AE3E12 /* Blob.swift */; };
3D67B3E81DB246BA00A4F4C6 /* Connection.swift in Sources */ = {isa = PBXBuildFile; fileRef = EE247AEF1C3F06E900AE3E12 /* Connection.swift */; };
Expand Down Expand Up @@ -225,6 +228,7 @@
19A17B93B48B5560E6E51791 /* Fixtures.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Fixtures.swift; sourceTree = "<group>"; };
19A17BA55DABB480F9020C8A /* DateAndTimeFunctions.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DateAndTimeFunctions.swift; sourceTree = "<group>"; };
19A17E2695737FAB5D6086E3 /* fixtures */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = folder; path = fixtures; sourceTree = "<group>"; };
3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CustomAggregationTests.swift; sourceTree = "<group>"; };
3D67B3E51DB2469200A4F4C6 /* libsqlite3.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = libsqlite3.tbd; path = Platforms/WatchOS.platform/Developer/SDKs/WatchOS3.0.sdk/usr/lib/libsqlite3.tbd; sourceTree = DEVELOPER_DIR; };
49EB68C31F7B3CB400D89D40 /* Coding.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Coding.swift; sourceTree = "<group>"; };
A121AC451CA35C79005A31D1 /* SQLite.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = SQLite.framework; sourceTree = BUILT_PRODUCTS_DIR; };
Expand Down Expand Up @@ -401,6 +405,7 @@
EE247B1D1C3F137700AE3E12 /* ConnectionTests.swift */,
EE247B1E1C3F137700AE3E12 /* CoreFunctionsTests.swift */,
EE247B1F1C3F137700AE3E12 /* CustomFunctionsTests.swift */,
3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */,
EE247B201C3F137700AE3E12 /* ExpressionTests.swift */,
EE247B211C3F137700AE3E12 /* FTS4Tests.swift */,
EE247B2A1C3F141E00AE3E12 /* OperatorsTests.swift */,
Expand Down Expand Up @@ -835,6 +840,7 @@
03A65E921C6BB3030062603F /* SetterTests.swift in Sources */,
03A65E891C6BB3030062603F /* ConnectionTests.swift in Sources */,
03A65E8A1C6BB3030062603F /* CoreFunctionsTests.swift in Sources */,
3717F90A221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */,
03A65E931C6BB3030062603F /* StatementTests.swift in Sources */,
03A65E911C6BB3030062603F /* SchemaTests.swift in Sources */,
03A65E8D1C6BB3030062603F /* FTS4Tests.swift in Sources */,
Expand Down Expand Up @@ -923,6 +929,7 @@
EE247B271C3F137700AE3E12 /* CustomFunctionsTests.swift in Sources */,
EE247B341C3F142E00AE3E12 /* StatementTests.swift in Sources */,
EE247B301C3F141E00AE3E12 /* RTreeTests.swift in Sources */,
3717F908221F5D8800B9BD3D /* CustomAggregationTests.swift in Sources */,
EE247B231C3F137700AE3E12 /* BlobTests.swift in Sources */,
EE247B351C3F142E00AE3E12 /* ValueTests.swift in Sources */,
EE247B2F1C3F141E00AE3E12 /* QueryTests.swift in Sources */,
Expand Down Expand Up @@ -981,6 +988,7 @@
EE247B5F1C3F3FC700AE3E12 /* StatementTests.swift in Sources */,
EE247B5C1C3F3FC700AE3E12 /* RTreeTests.swift in Sources */,
EE247B571C3F3FC700AE3E12 /* CustomFunctionsTests.swift in Sources */,
3717F909221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */,
EE247B601C3F3FC700AE3E12 /* ValueTests.swift in Sources */,
EE247B551C3F3FC700AE3E12 /* ConnectionTests.swift in Sources */,
EE247B611C3F3FC700AE3E12 /* TestHelpers.swift in Sources */,
Expand Down
114 changes: 114 additions & 0 deletions Sources/SQLite/Core/Connection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -597,8 +597,122 @@ public final class Connection {
if functions[function] == nil { self.functions[function] = [:] }
functions[function]?[argc] = box
}

/// Creates or redefines a custom SQL aggregate.
///
/// - Parameters:
///
/// - aggregate: The name of the aggregate to create or redefine.
///
/// - argumentCount: The number of arguments that the aggregate takes. If
/// `nil`, the aggregate may take any number of arguments.
///
/// Default: `nil`
///
/// - deterministic: Whether or not the aggregate is deterministic (_i.e._
/// the aggregate always returns the same result for a given input).
///
/// Default: `false`
///
/// - step: A block of code to run for each row of an aggregation group.
/// The block is called with an array of raw SQL values mapped to the
/// aggregate’s parameters, and an UnsafeMutablePointer to a state
/// variable.
///
/// - final: A block of code to run after each row of an aggregation group
/// is processed. The block is called with an UnsafeMutablePointer to a
/// state variable, and should return a raw SQL value (or nil).
///
/// - state: A block of code to run to produce a fresh state variable for
/// each aggregation group. The block should return an
/// UnsafeMutablePointer to the fresh state variable.
public func createAggregation<T>(
_ aggregate: String,
argumentCount: UInt? = nil,
deterministic: Bool = false,
step: @escaping ([Binding?], UnsafeMutablePointer<T>) -> (),
final: @escaping (UnsafeMutablePointer<T>) -> Binding?,
state: @escaping () -> UnsafeMutablePointer<T>) {


let argc = argumentCount.map { Int($0) } ?? -1
let box : Aggregate = { (stepFlag: Int, context: OpaquePointer?, argc: Int32, argv: UnsafeMutablePointer<OpaquePointer?>?) in
let ptr = sqlite3_aggregate_context(context, 64)! // needs to be at least as large as uintptr_t; better way to do this?
let p = ptr.assumingMemoryBound(to: UnsafeMutableRawPointer.self)
if stepFlag > 0 {
let arguments: [Binding?] = (0..<Int(argc)).map { idx in
let value = argv![idx]
switch sqlite3_value_type(value) {
case SQLITE_BLOB:
return Blob(bytes: sqlite3_value_blob(value), length: Int(sqlite3_value_bytes(value)))
case SQLITE_FLOAT:
return sqlite3_value_double(value)
case SQLITE_INTEGER:
return sqlite3_value_int64(value)
case SQLITE_NULL:
return nil
case SQLITE_TEXT:
return String(cString: UnsafePointer(sqlite3_value_text(value)))
case let type:
fatalError("unsupported value type: \(type)")
}
}

if ptr.assumingMemoryBound(to: Int64.self).pointee == 0 {
let v = state()
p.pointee = UnsafeMutableRawPointer(mutating: v)
}
step(arguments, p.pointee.assumingMemoryBound(to: T.self))
} else {
let result = final(p.pointee.assumingMemoryBound(to: T.self))
if let result = result as? Blob {
sqlite3_result_blob(context, result.bytes, Int32(result.bytes.count), nil)
} else if let result = result as? Double {
sqlite3_result_double(context, result)
} else if let result = result as? Int64 {
sqlite3_result_int64(context, result)
} else if let result = result as? String {
sqlite3_result_text(context, result, Int32(result.count), SQLITE_TRANSIENT)
} else if result == nil {
sqlite3_result_null(context)
} else {
fatalError("unsupported result type: \(String(describing: result))")
}
}
}

var flags = SQLITE_UTF8
#if !os(Linux)
if deterministic {
flags |= SQLITE_DETERMINISTIC
}
#endif

sqlite3_create_function_v2(
handle,
aggregate,
Int32(argc),
flags,
unsafeBitCast(box, to: UnsafeMutableRawPointer.self),
nil,
{ context, argc, value in
let function = unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)
function(1, context, argc, value)
},
{ context in
let function = unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)
function(0, context, 0, nil)
},
nil
)
if aggregations[aggregate] == nil { self.aggregations[aggregate] = [:] }
aggregations[aggregate]?[argc] = box
}

fileprivate typealias Aggregate = @convention(block) (Int, OpaquePointer?, Int32, UnsafeMutablePointer<OpaquePointer?>?) -> Void
fileprivate typealias Function = @convention(block) (OpaquePointer?, Int32, UnsafeMutablePointer<OpaquePointer?>?) -> Void
fileprivate var functions = [String: [Int: Function]]()
fileprivate var aggregations = [String: [Int: Aggregate]]()

/// Defines a new collating sequence.
///
Expand Down
65 changes: 65 additions & 0 deletions Sources/SQLite/Typed/CustomFunctions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,69 @@ public extension Connection {
}
}

// MARK: -

public func createAggregation<T: AnyObject>(
_ aggregate: String,
argumentCount: UInt? = nil,
deterministic: Bool = false,
initialValue: T,
reduce: @escaping (T, [Binding?]) -> T,
result: @escaping (T) -> Binding?
) {

let step: ([Binding?], UnsafeMutablePointer<UnsafeMutableRawPointer>) -> () = { (bindings, ptr) in
let p = ptr.pointee.assumingMemoryBound(to: T.self)
let current = Unmanaged<T>.fromOpaque(p).takeRetainedValue()
let next = reduce(current, bindings)
ptr.pointee = Unmanaged.passRetained(next).toOpaque()
}

let final: (UnsafeMutablePointer<UnsafeMutableRawPointer>) -> Binding? = { (ptr) in
let p = ptr.pointee.assumingMemoryBound(to: T.self)
let obj = Unmanaged<T>.fromOpaque(p).takeRetainedValue()
let value = result(obj)
ptr.deallocate()
return value
}

let state: () -> UnsafeMutablePointer<UnsafeMutableRawPointer> = {
let p = UnsafeMutablePointer<UnsafeMutableRawPointer>.allocate(capacity: 1)
p.pointee = Unmanaged.passRetained(initialValue).toOpaque()
return p
}

createAggregation(aggregate, step: step, final: final, state: state)
}

public func createAggregation<T>(
_ aggregate: String,
argumentCount: UInt? = nil,
deterministic: Bool = false,
initialValue: T,
reduce: @escaping (T, [Binding?]) -> T,
result: @escaping (T) -> Binding?
) {

let step: ([Binding?], UnsafeMutablePointer<T>) -> () = { (bindings, p) in
let current = p.pointee
let next = reduce(current, bindings)
p.pointee = next
}

let final: (UnsafeMutablePointer<T>) -> Binding? = { (p) in
let v = result(p.pointee)
p.deallocate()
return v
}

let state: () -> UnsafeMutablePointer<T> = {
let p = UnsafeMutablePointer<T>.allocate(capacity: 1)
p.pointee = initialValue
return p
}

createAggregation(aggregate, step: step, final: final, state: state)
}

}
155 changes: 155 additions & 0 deletions Tests/SQLiteTests/CustomAggregationTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import XCTest
import Foundation
import Dispatch
@testable import SQLite

#if SQLITE_SWIFT_STANDALONE
import sqlite3
#elseif SQLITE_SWIFT_SQLCIPHER
import SQLCipher
#elseif os(Linux)
import CSQLite
#else
import SQLite3
#endif

class CustomAggregationTests : SQLiteTestCase {
override func setUp() {
super.setUp()
CreateUsersTable()
try! InsertUser("Alice", age: 30, admin: true)
try! InsertUser("Bob", age: 25, admin: true)
try! InsertUser("Eve", age: 28, admin: false)
}

func testUnsafeCustomSum() {
let step = { (bindings: [Binding?], state: UnsafeMutablePointer<Int64>) in
if let v = bindings[0] as? Int64 {
state.pointee += v
}
}

let final = { (state: UnsafeMutablePointer<Int64>) -> Binding? in
let v = state.pointee
let p = UnsafeMutableBufferPointer(start: state, count: 1)
p.deallocate()
return v
}
let _ = db.createAggregation("mySUM1", step: step, final: final) {
let v = UnsafeMutableBufferPointer<Int64>.allocate(capacity: 1)
v[0] = 0
return v.baseAddress!
}
let result = try! db.prepare("SELECT mySUM1(age) AS s FROM users")
let i = result.columnNames.index(of: "s")!
for row in result {
let value = row[i] as? Int64
XCTAssertEqual(83, value)
}
}

func testUnsafeCustomSumGrouping() {
let step = { (bindings: [Binding?], state: UnsafeMutablePointer<Int64>) in
if let v = bindings[0] as? Int64 {
state.pointee += v
}
}
let final = { (state: UnsafeMutablePointer<Int64>) -> Binding? in
let v = state.pointee
let p = UnsafeMutableBufferPointer(start: state, count: 1)
p.deallocate()
return v
}
let _ = db.createAggregation("mySUM2", step: step, final: final) {
let v = UnsafeMutableBufferPointer<Int64>.allocate(capacity: 1)
v[0] = 0
return v.baseAddress!
}
let result = try! db.prepare("SELECT mySUM2(age) AS s FROM users GROUP BY admin ORDER BY s")
let i = result.columnNames.index(of: "s")!
let values = result.compactMap { $0[i] as? Int64 }
XCTAssertTrue(values.elementsEqual([28, 55]))
}

func testCustomSum() {
let reduce : (Int64, [Binding?]) -> Int64 = { (last, bindings) in
let v = (bindings[0] as? Int64) ?? 0
return last + v
}
let _ = db.createAggregation("myReduceSUM1", initialValue: Int64(2000), reduce: reduce, result: { $0 })
let result = try! db.prepare("SELECT myReduceSUM1(age) AS s FROM users")
let i = result.columnNames.index(of: "s")!
for row in result {
let value = row[i] as? Int64
XCTAssertEqual(2083, value)
}
}

func testCustomSumGrouping() {
let reduce : (Int64, [Binding?]) -> Int64 = { (last, bindings) in
let v = (bindings[0] as? Int64) ?? 0
return last + v
}
let _ = db.createAggregation("myReduceSUM2", initialValue: Int64(3000), reduce: reduce, result: { $0 })
let result = try! db.prepare("SELECT myReduceSUM2(age) AS s FROM users GROUP BY admin ORDER BY s")
let i = result.columnNames.index(of: "s")!
let values = result.compactMap { $0[i] as? Int64 }
XCTAssertTrue(values.elementsEqual([3028, 3055]))
}

func testCustomStringAgg() {
let initial = String(repeating: " ", count: 64)
let reduce : (String, [Binding?]) -> String = { (last, bindings) in
let v = (bindings[0] as? String) ?? ""
return last + v
}
let _ = db.createAggregation("myReduceSUM3", initialValue: initial, reduce: reduce, result: { $0 })
let result = try! db.prepare("SELECT myReduceSUM3(email) AS s FROM users")
let i = result.columnNames.index(of: "s")!
for row in result {
let value = row[i] as? String
XCTAssertEqual("\(initial)Alice@example.comBob@example.comEve@example.com", value)
}
}

func testCustomObjectSum() {
{
let initial = TestObject(value: 1000)
let reduce : (TestObject, [Binding?]) -> TestObject = { (last, bindings) in
let v = (bindings[0] as? Int64) ?? 0
return TestObject(value: last.value + v)
}
let _ = db.createAggregation("myReduceSUMX", initialValue: initial, reduce: reduce, result: { $0.value })
// end this scope to ensure that the initial value is retained
// by the createAggregation call.
}();
{
XCTAssertEqual(TestObject.inits, 1)
let result = try! db.prepare("SELECT myReduceSUMX(age) AS s FROM users")
let i = result.columnNames.index(of: "s")!
for row in result {
let value = row[i] as? Int64
XCTAssertEqual(1083, value)
}
}()
XCTAssertEqual(TestObject.inits, 4)
XCTAssertEqual(TestObject.deinits, 3) // the initial value is still retained by the aggregate's state block, so deinits is one less than inits
}
}

/// This class is used to test that aggregation state variables
/// can be reference types and are properly memory managed when
/// crossing the Swift<->C boundary multiple times.
class TestObject {
static var inits = 0
static var deinits = 0

var value: Int64
init(value: Int64) {
self.value = value
TestObject.inits += 1
}
deinit {
TestObject.deinits += 1
}
}