15
15
import unittest
16
16
17
17
from google .cloud .spanner_admin_database_v1 .types import spanner_database_admin
18
+ from google .cloud .spanner_dbapi import Connection
19
+ from google .cloud .spanner_dbapi .parsed_statement import AutocommitDmlMode
18
20
from google .cloud .spanner_v1 .testing .mock_database_admin import DatabaseAdminServicer
19
21
from google .cloud .spanner_v1 .testing .mock_spanner import (
20
22
start_mock_server ,
29
31
FixedSizePool ,
30
32
BatchCreateSessionsRequest ,
31
33
ExecuteSqlRequest ,
34
+ BeginTransactionRequest ,
35
+ TransactionOptions ,
32
36
)
33
37
from google .cloud .spanner_v1 .database import Database
34
38
from google .cloud .spanner_v1 .instance import Instance
@@ -62,6 +66,10 @@ def tearDownClass(cls):
62
66
TestBasics .server .stop (grace = None )
63
67
TestBasics .server = None
64
68
69
+ def teardown_method (self , * args , ** kwargs ):
70
+ TestBasics .spanner_service .clear_requests ()
71
+ TestBasics .database_admin_service .clear_requests ()
72
+
65
73
def _add_select1_result (self ):
66
74
result = result_set .ResultSet (
67
75
dict (
@@ -88,6 +96,19 @@ def _add_select1_result(self):
88
96
result .rows .extend (["1" ])
89
97
TestBasics .spanner_service .mock_spanner .add_result ("select 1" , result )
90
98
99
+ def add_update_count (
100
+ self ,
101
+ sql : str ,
102
+ count : int ,
103
+ dml_mode : AutocommitDmlMode = AutocommitDmlMode .TRANSACTIONAL ,
104
+ ):
105
+ if dml_mode == AutocommitDmlMode .PARTITIONED_NON_ATOMIC :
106
+ stats = dict (row_count_lower_bound = count )
107
+ else :
108
+ stats = dict (row_count_exact = count )
109
+ result = result_set .ResultSet (dict (stats = result_set .ResultSetStats (stats )))
110
+ TestBasics .spanner_service .mock_spanner .add_result (sql , result )
111
+
91
112
@property
92
113
def client (self ) -> Client :
93
114
if self ._client is None :
@@ -145,3 +166,27 @@ def test_create_table(self):
145
166
)
146
167
operation = database_admin_api .update_database_ddl (request )
147
168
operation .result (1 )
169
+
170
+ # TODO: Move this to a separate class once the mock server test setup has
171
+ # been re-factored to use a base class for the boiler plate code.
172
+ def test_dbapi_partitioned_dml (self ):
173
+ sql = "UPDATE singers SET foo='bar' WHERE active = true"
174
+ self .add_update_count (sql , 100 , AutocommitDmlMode .PARTITIONED_NON_ATOMIC )
175
+ connection = Connection (self .instance , self .database )
176
+ connection .autocommit = True
177
+ connection .set_autocommit_dml_mode (AutocommitDmlMode .PARTITIONED_NON_ATOMIC )
178
+ with connection .cursor () as cursor :
179
+ # Note: SQLAlchemy uses [] as the list of parameters for statements
180
+ # with no parameters.
181
+ cursor .execute (sql , [])
182
+ self .assertEqual (100 , cursor .rowcount )
183
+
184
+ requests = self .spanner_service .requests
185
+ self .assertEqual (3 , len (requests ), msg = requests )
186
+ self .assertTrue (isinstance (requests [0 ], BatchCreateSessionsRequest ))
187
+ self .assertTrue (isinstance (requests [1 ], BeginTransactionRequest ))
188
+ self .assertTrue (isinstance (requests [2 ], ExecuteSqlRequest ))
189
+ begin_request : BeginTransactionRequest = requests [1 ]
190
+ self .assertEqual (
191
+ TransactionOptions (dict (partitioned_dml = {})), begin_request .options
192
+ )
0 commit comments