Skip to content

Commit 758bf48

Browse files
authored
fix: dbapi raised AttributeError with [] as arguments (#1257)
If the cursor.execute(sql, args) function was called with an empty array instead of None, it would raise an AttributeError like this: AttributeError: 'list' object has no attribute 'items' This is for example automatically done by SQLAlchemy when executing a raw statement on a dbapi connection.
1 parent a81af3b commit 758bf48

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

‎google/cloud/spanner_v1/transaction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def _make_params_pb(params, param_types):
308308
:raises ValueError:
309309
If ``params`` is None but ``param_types`` is not None.
310310
"""
311-
if params is not None:
311+
if params:
312312
return Struct(
313313
fields={key: _make_value_pb(value) for key, value in params.items()}
314314
)

‎tests/mockserver_tests/test_basics.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import unittest
1616

1717
from google.cloud.spanner_admin_database_v1.types import spanner_database_admin
18+
from google.cloud.spanner_dbapi import Connection
19+
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
1820
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
1921
from google.cloud.spanner_v1.testing.mock_spanner import (
2022
start_mock_server,
@@ -29,6 +31,8 @@
2931
FixedSizePool,
3032
BatchCreateSessionsRequest,
3133
ExecuteSqlRequest,
34+
BeginTransactionRequest,
35+
TransactionOptions,
3236
)
3337
from google.cloud.spanner_v1.database import Database
3438
from google.cloud.spanner_v1.instance import Instance
@@ -62,6 +66,10 @@ def tearDownClass(cls):
6266
TestBasics.server.stop(grace=None)
6367
TestBasics.server = None
6468

69+
def teardown_method(self, *args, **kwargs):
70+
TestBasics.spanner_service.clear_requests()
71+
TestBasics.database_admin_service.clear_requests()
72+
6573
def _add_select1_result(self):
6674
result = result_set.ResultSet(
6775
dict(
@@ -88,6 +96,19 @@ def _add_select1_result(self):
8896
result.rows.extend(["1"])
8997
TestBasics.spanner_service.mock_spanner.add_result("select 1", result)
9098

99+
def add_update_count(
100+
self,
101+
sql: str,
102+
count: int,
103+
dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL,
104+
):
105+
if dml_mode == AutocommitDmlMode.PARTITIONED_NON_ATOMIC:
106+
stats = dict(row_count_lower_bound=count)
107+
else:
108+
stats = dict(row_count_exact=count)
109+
result = result_set.ResultSet(dict(stats=result_set.ResultSetStats(stats)))
110+
TestBasics.spanner_service.mock_spanner.add_result(sql, result)
111+
91112
@property
92113
def client(self) -> Client:
93114
if self._client is None:
@@ -145,3 +166,27 @@ def test_create_table(self):
145166
)
146167
operation = database_admin_api.update_database_ddl(request)
147168
operation.result(1)
169+
170+
# TODO: Move this to a separate class once the mock server test setup has
171+
# been re-factored to use a base class for the boiler plate code.
172+
def test_dbapi_partitioned_dml(self):
173+
sql = "UPDATE singers SET foo='bar' WHERE active = true"
174+
self.add_update_count(sql, 100, AutocommitDmlMode.PARTITIONED_NON_ATOMIC)
175+
connection = Connection(self.instance, self.database)
176+
connection.autocommit = True
177+
connection.set_autocommit_dml_mode(AutocommitDmlMode.PARTITIONED_NON_ATOMIC)
178+
with connection.cursor() as cursor:
179+
# Note: SQLAlchemy uses [] as the list of parameters for statements
180+
# with no parameters.
181+
cursor.execute(sql, [])
182+
self.assertEqual(100, cursor.rowcount)
183+
184+
requests = self.spanner_service.requests
185+
self.assertEqual(3, len(requests), msg=requests)
186+
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
187+
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
188+
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
189+
begin_request: BeginTransactionRequest = requests[1]
190+
self.assertEqual(
191+
TransactionOptions(dict(partitioned_dml={})), begin_request.options
192+
)

0 commit comments

Comments
 (0)