Skip to content

Commit b77f67b

Browse files
rhermanklinkdpgeorge
authored andcommitted
unix-ffi/sqlite3: Add commit and rollback functionality like CPython.
To increase the similarity between this module and CPythons sqlite3 module the commit() and rollback() as defined in CPythons version have been added, along with the different (auto)commit behaviors present there. The defaults are also set to the same as in CPython, and can be changed with the same parameters in connect(), as is showcased in the new test. Signed-off-by: Robert Klink <rhermanklink@ripe.net>
1 parent 83598cd commit b77f67b

File tree

2 files changed

+137
-38
lines changed

2 files changed

+137
-38
lines changed

‎unix-ffi/sqlite3/sqlite3.py

+95-38
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
sqlite3_open = sq3.func("i", "sqlite3_open", "sp")
1313
# int sqlite3_config(int, ...);
1414
sqlite3_config = sq3.func("i", "sqlite3_config", "ii")
15+
# int sqlite3_get_autocommit(sqlite3*);
16+
sqlite3_get_autocommit = sq3.func("i", "sqlite3_get_autocommit", "p")
1517
# int sqlite3_close_v2(sqlite3*);
1618
sqlite3_close = sq3.func("i", "sqlite3_close_v2", "p")
1719
# int sqlite3_prepare(
@@ -57,6 +59,9 @@
5759

5860
SQLITE_CONFIG_URI = 17
5961

62+
# For compatibility with CPython sqlite3 driver
63+
LEGACY_TRANSACTION_CONTROL = -1
64+
6065

6166
class Error(Exception):
6267
pass
@@ -71,86 +76,138 @@ def get_ptr_size():
7176
return uctypes.sizeof({"ptr": (0 | uctypes.PTR, uctypes.PTR)})
7277

7378

79+
def __prepare_stmt(db, sql):
80+
# Prepares a statement
81+
stmt_ptr = bytes(get_ptr_size())
82+
res = sqlite3_prepare(db, sql, -1, stmt_ptr, None)
83+
check_error(db, res)
84+
return int.from_bytes(stmt_ptr, sys.byteorder)
85+
86+
def __exec_stmt(db, sql):
87+
# Prepares, executes, and finalizes a statement
88+
stmt = __prepare_stmt(db, sql)
89+
sqlite3_step(stmt)
90+
res = sqlite3_finalize(stmt)
91+
check_error(db, res)
92+
93+
def __is_dml(sql):
94+
# Checks if a sql query is a DML, as these get a BEGIN in LEGACY_TRANSACTION_CONTROL
95+
for dml in ["INSERT", "DELETE", "UPDATE", "MERGE"]:
96+
if dml in sql.upper():
97+
return True
98+
return False
99+
100+
74101
class Connections:
75-
def __init__(self, h):
76-
self.h = h
102+
def __init__(self, db, isolation_level, autocommit):
103+
self.db = db
104+
self.isolation_level = isolation_level
105+
self.autocommit = autocommit
106+
107+
def commit(self):
108+
if self.autocommit == LEGACY_TRANSACTION_CONTROL and not sqlite3_get_autocommit(self.db):
109+
__exec_stmt(self.db, "COMMIT")
110+
elif self.autocommit == False:
111+
__exec_stmt(self.db, "COMMIT")
112+
__exec_stmt(self.db, "BEGIN")
113+
114+
def rollback(self):
115+
if self.autocommit == LEGACY_TRANSACTION_CONTROL and not sqlite3_get_autocommit(self.db):
116+
__exec_stmt(self.db, "ROLLBACK")
117+
elif self.autocommit == False:
118+
__exec_stmt(self.db, "ROLLBACK")
119+
__exec_stmt(self.db, "BEGIN")
77120

78121
def cursor(self):
79-
return Cursor(self.h)
122+
return Cursor(self.db, self.isolation_level, self.autocommit)
80123

81124
def close(self):
82-
if self.h:
83-
s = sqlite3_close(self.h)
84-
check_error(self.h, s)
85-
self.h = None
125+
if self.db:
126+
if self.autocommit == False and not sqlite3_get_autocommit(self.db):
127+
__exec_stmt(self.db, "ROLLBACK")
128+
129+
res = sqlite3_close(self.db)
130+
check_error(self.db, res)
131+
self.db = None
86132

87133

88134
class Cursor:
89-
def __init__(self, h):
90-
self.h = h
91-
self.stmnt = None
135+
def __init__(self, db, isolation_level, autocommit):
136+
self.db = db
137+
self.isolation_level = isolation_level
138+
self.autocommit = autocommit
139+
self.stmt = None
140+
141+
def __quote(val):
142+
if isinstance(val, str):
143+
return "'%s'" % val
144+
return str(val)
92145

93146
def execute(self, sql, params=None):
94-
if self.stmnt:
147+
if self.stmt:
95148
# If there is an existing statement, finalize that to free it
96-
res = sqlite3_finalize(self.stmnt)
97-
check_error(self.h, res)
149+
res = sqlite3_finalize(self.stmt)
150+
check_error(self.db, res)
98151

99152
if params:
100-
params = [quote(v) for v in params]
153+
params = [self.__quote(v) for v in params]
101154
sql = sql % tuple(params)
102155

103-
stmnt_ptr = bytes(get_ptr_size())
104-
res = sqlite3_prepare(self.h, sql, -1, stmnt_ptr, None)
105-
check_error(self.h, res)
106-
self.stmnt = int.from_bytes(stmnt_ptr, sys.byteorder)
107-
self.num_cols = sqlite3_column_count(self.stmnt)
156+
if __is_dml(sql) and self.autocommit == LEGACY_TRANSACTION_CONTROL and sqlite3_get_autocommit(self.db):
157+
# For compatibility with CPython, add functionality for their default transaction
158+
# behavior. Changing autocommit from LEGACY_TRANSACTION_CONTROL will remove this
159+
__exec_stmt(self.db, "BEGIN " + self.isolation_level)
160+
161+
self.stmt = __prepare_stmt(self.db, sql)
162+
self.num_cols = sqlite3_column_count(self.stmt)
108163

109164
if not self.num_cols:
110165
v = self.fetchone()
111166
# If it's not select, actually execute it here
112167
# num_cols == 0 for statements which don't return data (=> modify it)
113168
assert v is None
114-
self.lastrowid = sqlite3_last_insert_rowid(self.h)
169+
self.lastrowid = sqlite3_last_insert_rowid(self.db)
115170

116171
def close(self):
117-
if self.stmnt:
118-
s = sqlite3_finalize(self.stmnt)
119-
check_error(self.h, s)
120-
self.stmnt = None
172+
if self.stmt:
173+
res = sqlite3_finalize(self.stmt)
174+
check_error(self.db, res)
175+
self.stmt = None
121176

122-
def make_row(self):
177+
def __make_row(self):
123178
res = []
124179
for i in range(self.num_cols):
125-
t = sqlite3_column_type(self.stmnt, i)
180+
t = sqlite3_column_type(self.stmt, i)
126181
if t == SQLITE_INTEGER:
127-
res.append(sqlite3_column_int(self.stmnt, i))
182+
res.append(sqlite3_column_int(self.stmt, i))
128183
elif t == SQLITE_FLOAT:
129-
res.append(sqlite3_column_double(self.stmnt, i))
184+
res.append(sqlite3_column_double(self.stmt, i))
130185
elif t == SQLITE_TEXT:
131-
res.append(sqlite3_column_text(self.stmnt, i))
186+
res.append(sqlite3_column_text(self.stmt, i))
132187
else:
133188
raise NotImplementedError
134189
return tuple(res)
135190

136191
def fetchone(self):
137-
res = sqlite3_step(self.stmnt)
192+
res = sqlite3_step(self.stmt)
138193
if res == SQLITE_DONE:
139194
return None
140195
if res == SQLITE_ROW:
141-
return self.make_row()
142-
check_error(self.h, res)
196+
return self.__make_row()
197+
check_error(self.db, res)
198+
143199

200+
def connect(fname, uri=False, isolation_level="", autocommit=LEGACY_TRANSACTION_CONTROL):
201+
if isolation_level not in [None, "", "DEFERRED", "IMMEDIATE", "EXCLUSIVE"]:
202+
raise Error("Invalid option for isolation level")
144203

145-
def connect(fname, uri=False):
146204
sqlite3_config(SQLITE_CONFIG_URI, int(uri))
147205

148206
sqlite_ptr = bytes(get_ptr_size())
149207
sqlite3_open(fname, sqlite_ptr)
150-
return Connections(int.from_bytes(sqlite_ptr, sys.byteorder))
208+
db = int.from_bytes(sqlite_ptr, sys.byteorder)
151209

210+
if autocommit == False:
211+
__exec_stmt(db, "BEGIN")
152212

153-
def quote(val):
154-
if isinstance(val, str):
155-
return "'%s'" % val
156-
return str(val)
213+
return Connections(db, isolation_level, autocommit)

‎unix-ffi/sqlite3/test_sqlite3_3.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import sqlite3
2+
3+
4+
def test_autocommit():
5+
conn = sqlite3.connect(":memory:", autocommit=True)
6+
7+
# First cursor creates table and inserts value (DML)
8+
cur = conn.cursor()
9+
cur.execute("CREATE TABLE foo(a int)")
10+
cur.execute("INSERT INTO foo VALUES (42)")
11+
cur.close()
12+
13+
# Second cursor fetches 42 due to the autocommit
14+
cur = conn.cursor()
15+
cur.execute("SELECT * FROM foo")
16+
assert cur.fetchone() == (42,)
17+
assert cur.fetchone() is None
18+
19+
cur.close()
20+
conn.close()
21+
22+
def test_manual():
23+
conn = sqlite3.connect(":memory:", autocommit=False)
24+
25+
# First cursor creates table, insert rolls back
26+
cur = conn.cursor()
27+
cur.execute("CREATE TABLE foo(a int)")
28+
conn.commit()
29+
cur.execute("INSERT INTO foo VALUES (42)")
30+
cur.close()
31+
conn.rollback()
32+
33+
# Second connection fetches nothing due to the rollback
34+
cur = conn.cursor()
35+
cur.execute("SELECT * FROM foo")
36+
assert cur.fetchone() is None
37+
38+
cur.close()
39+
conn.close()
40+
41+
test_autocommit()
42+
test_manual()

0 commit comments

Comments
 (0)