From 9061b8864d5bf99871622d741a422356bcc82120 Mon Sep 17 00:00:00 2001
From: Ekaitz Zarraga <ekaitz@elenq.tech>
Date: Fri, 2 Aug 2024 20:48:16 +0200
Subject: WIP: Prepared statements (numbers work, strings dont)

---
 src/duckdb/PreparedStatement.zig | 141 +++++++++++++++++++++++++++++++++++++++
 src/duckdb/db.zig                |  41 +++++++++++-
 2 files changed, 181 insertions(+), 1 deletion(-)
 create mode 100644 src/duckdb/PreparedStatement.zig

diff --git a/src/duckdb/PreparedStatement.zig b/src/duckdb/PreparedStatement.zig
new file mode 100644
index 0000000..e81ac4c
--- /dev/null
+++ b/src/duckdb/PreparedStatement.zig
@@ -0,0 +1,141 @@
+const std = @import("std");
+const c = @cImport({
+    @cInclude("duckdb.h");
+});
+const Result = @import("Result.zig").Result;
+
+_q : c.duckdb_prepared_statement,
+_current: usize,
+
+const Self = @This();
+
+pub fn init(stmt: c.duckdb_prepared_statement) !Self{
+    return .{
+        ._q = stmt,
+        ._current = 1,
+    };
+}
+
+pub fn bindInt(self: *Self, param: anytype) !void{
+    const state = switch(@TypeOf(param)) {
+        u8  => c.duckdb_bind_uint8 (self._q, self._current, param),
+        u16 => c.duckdb_bind_uint16(self._q, self._current, param),
+        u32 => c.duckdb_bind_uint32(self._q, self._current, param),
+        u64 => c.duckdb_bind_uint64(self._q, self._current, param),
+        i8  => c.duckdb_bind_int8 (self._q, self._current, param),
+        i16 => c.duckdb_bind_int16(self._q, self._current, param),
+        i32 => c.duckdb_bind_int32(self._q, self._current, param),
+        i64 => c.duckdb_bind_int64(self._q, self._current, param),
+        else => @compileError("Invalid int type for binding: "
+            ++ @typeName(@TypeOf(param))),
+    };
+    if ( state == c.DuckDBError ) {
+        return error.DuckDbBindError;
+    }
+    self._current += 1;
+}
+
+pub fn bindFloat(self: *Self, param: anytype) !void{
+    const state = switch(@TypeOf(param)) {
+        f32 => c.duckdb_bind_float (self._q, self._current, param),
+        f64 => c.duckdb_bind_double(self._q, self._current, param),
+        else => @compileError("Invalid float type for binding: "
+            ++ @typeName(@TypeOf(param))),
+    };
+    if ( state == c.DuckDBError ) {
+        return error.DuckDbBindError;
+    }
+    self._current += 1;
+}
+
+pub fn bindBool(self: *Self, param: bool) !void{
+    const state = c.duckdb_bind_bool(self._q, self._current, param);
+    if ( state == c.DuckDBError ) {
+        return error.DuckDbBindError;
+    }
+    self._current += 1;
+}
+
+pub fn bindNull(self: *Self) !void{
+    const state = c.duckdb_bind_bool(self._q, self._current);
+    if ( state == c.DuckDBError ) {
+        return error.DuckDbBindError;
+    }
+    self._current += 1;
+}
+
+pub fn bindString(self: *Self, param: []const u8) !void{
+    const state = c.duckdb_bind_varchar_length(self._q, self._current,
+        param.ptr, param.len);
+    if ( state == c.DuckDBError ) {
+        return error.DuckDbBindError;
+    }
+    self._current += 1;
+}
+
+pub fn bind(self: *Self, param: anytype) !void {
+    switch (@typeInfo(@TypeOf(param))) {
+        .Null    => return try self.bindNull(),
+        .Bool    => return try self.bindBool(param),
+        .Int     => return try self.bindInt(param),
+        .Float   => return try self.bindFloat(param),
+        .Array   => |arr| {
+            if (arr.child == u8){
+                return try self.bindString(param);
+            } else {
+                return error.UnbindableType;
+            }
+        },
+        .Pointer => |ptr| {
+            if (ptr.size == .Slice and ptr.child == u8){
+                return try self.bindString(param);
+            } else {
+                return error.UnbindableType;
+            }
+        },
+        else     => @compileError("Invalid type for binding: "
+            ++ @typeName(@TypeOf(param))),
+    }
+}
+
+pub fn bindAll(self: *Self, params: anytype) !void{
+    const param_count: usize = c.duckdb_nparams(self._q);
+    switch (@typeInfo(@TypeOf(params))) {
+        .Array   => |arr| {
+            if (arr.len == param_count){
+                inline for (params) |p|{
+                    try self.bind(p);
+                }
+                return;
+            } else {
+                return error.BindingLengthsDoNotMatch;
+            }
+        },
+        .Struct  => |str|{
+            const x = str.fields;
+            if (x.len == param_count){
+                inline for (x) |field|{
+                    try self.bind(@field(params, field.name));
+                }
+                return;
+            } else {
+                return error.BindingLengthsDoNotMatch;
+            }
+        },
+        else     => @compileError("Invalid type for binding: "
+            ++ @typeName(@TypeOf(params))),
+    }
+}
+
+pub fn exec(self: *Self, comptime T: type) !Result(T){
+    var result: c.duckdb_result = undefined;
+    const state = c.duckdb_execute_prepared(self._q, &result);
+    if ( state == c.DuckDBError ){
+        return error.DuckDbExecError;
+    }
+    return try Result(T).init(result);
+}
+
+pub fn deinit(self: *Self) void{
+    c.duckdb_destroy_prepare(&self._q);
+}
diff --git a/src/duckdb/db.zig b/src/duckdb/db.zig
index 4b5ebe3..4fad7cb 100644
--- a/src/duckdb/db.zig
+++ b/src/duckdb/db.zig
@@ -4,7 +4,7 @@ const c = @cImport({
     @cInclude("duckdb.h");
 });
 const Result = @import("Result.zig").Result;
-
+const PreparedStatement = @import("PreparedStatement.zig");
 
 const Connection = struct {
     _conn: c.duckdb_connection,
@@ -37,6 +37,17 @@ const Connection = struct {
         var x = try self.query(q, void);
         defer x.deinit();
     }
+
+    /// Make a prepared query. Caller needs to call prepared_query.deinit()
+    pub fn prepareStatement(self: *Connection, q: [:0]const u8)
+        !PreparedStatement {
+        var stmt: c.duckdb_prepared_statement = undefined;
+        const state = c.duckdb_prepare(self._conn, q, &stmt);
+        if ( state == c.DuckDBError ){
+            return error.DuckDBError;
+        }
+        return PreparedStatement.init(stmt);
+    }
 };
 
 
@@ -142,3 +153,31 @@ test "String queries" {
     w = try result.next();
     try std.testing.expect(std.mem.eql(u8, w.primer, "A very long string that is not inlined"));
 }
+
+test "Prepared queries" {
+    var database = try Database.init(":memory:");
+    defer database.deinit();
+    var connection = try database.connect();
+    defer connection.deinit();
+
+    try connection.run("CREATE TABLE ints (i INTEGER NOT NULL );");
+    var prepared = try connection.prepareStatement("INSERT INTO ints VALUES (?), (?);");
+    defer prepared.deinit();
+
+    const uno: i32 = 1;
+    const dos: i32 = 2;
+    try prepared.bindAll(.{uno, dos});
+    var res = try prepared.exec(void);
+    res.deinit();
+
+    const s: type = struct {
+        primer: i32,
+    };
+    var result = try connection.query("SELECT * FROM ints;", s);
+    defer result.deinit();
+
+    var r = try result.next();
+    try std.testing.expect(r.primer == uno);
+    r = try result.next();
+    try std.testing.expect(r.primer == dos);
+}
-- 
cgit v1.2.3