Skip to content

Commit d6963e2

Browse files
authored
feat: enable instance-level connection (#931)
1 parent 6c8672b commit d6963e2

File tree

4 files changed

+102
-9
lines changed

4 files changed

+102
-9
lines changed

‎google/cloud/spanner_dbapi/connection.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class Connection:
8383
should end a that a new one should be started when the next statement is executed.
8484
"""
8585

86-
def __init__(self, instance, database, read_only=False):
86+
def __init__(self, instance, database=None, read_only=False):
8787
self._instance = instance
8888
self._database = database
8989
self._ddl_statements = []
@@ -242,6 +242,8 @@ def _session_checkout(self):
242242
:rtype: :class:`google.cloud.spanner_v1.session.Session`
243243
:returns: Cloud Spanner session object ready to use.
244244
"""
245+
if self.database is None:
246+
raise ValueError("Database needs to be passed for this operation")
245247
if not self._session:
246248
self._session = self.database._pool.get()
247249

@@ -252,6 +254,8 @@ def _release_session(self):
252254
253255
The session will be returned into the sessions pool.
254256
"""
257+
if self.database is None:
258+
raise ValueError("Database needs to be passed for this operation")
255259
self.database._pool.put(self._session)
256260
self._session = None
257261

@@ -368,7 +372,7 @@ def close(self):
368372
if self.inside_transaction:
369373
self._transaction.rollback()
370374

371-
if self._own_pool:
375+
if self._own_pool and self.database:
372376
self.database._pool.clear()
373377

374378
self.is_closed = True
@@ -378,6 +382,8 @@ def commit(self):
378382
379383
This method is non-operational in autocommit mode.
380384
"""
385+
if self.database is None:
386+
raise ValueError("Database needs to be passed for this operation")
381387
self._snapshot = None
382388

383389
if self._autocommit:
@@ -420,6 +426,8 @@ def cursor(self):
420426

421427
@check_not_closed
422428
def run_prior_DDL_statements(self):
429+
if self.database is None:
430+
raise ValueError("Database needs to be passed for this operation")
423431
if self._ddl_statements:
424432
ddl_statements = self._ddl_statements
425433
self._ddl_statements = []
@@ -474,6 +482,8 @@ def validate(self):
474482
:raises: :class:`google.cloud.exceptions.NotFound`: if the linked instance
475483
or database doesn't exist.
476484
"""
485+
if self.database is None:
486+
raise ValueError("Database needs to be passed for this operation")
477487
with self.database.snapshot() as snapshot:
478488
result = list(snapshot.execute_sql("SELECT 1"))
479489
if result != [[1]]:
@@ -492,7 +502,7 @@ def __exit__(self, etype, value, traceback):
492502

493503
def connect(
494504
instance_id,
495-
database_id,
505+
database_id=None,
496506
project=None,
497507
credentials=None,
498508
pool=None,
@@ -505,7 +515,7 @@ def connect(
505515
:param instance_id: The ID of the instance to connect to.
506516
507517
:type database_id: str
508-
:param database_id: The ID of the database to connect to.
518+
:param database_id: (Optional) The ID of the database to connect to.
509519
510520
:type project: str
511521
:param project: (Optional) The ID of the project which owns the
@@ -557,7 +567,9 @@ def connect(
557567
raise ValueError("project in url does not match client object project")
558568

559569
instance = client.instance(instance_id)
560-
conn = Connection(instance, instance.database(database_id, pool=pool))
570+
conn = Connection(
571+
instance, instance.database(database_id, pool=pool) if database_id else None
572+
)
561573
if pool is not None:
562574
conn._own_pool = False
563575

‎google/cloud/spanner_dbapi/cursor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ def execute(self, sql, args=None):
228228
:type args: list
229229
:param args: Additional parameters to supplement the SQL query.
230230
"""
231+
if self.connection.database is None:
232+
raise ValueError("Database needs to be passed for this operation")
231233
self._itr = None
232234
self._result_set = None
233235
self._row_count = _UNSET_COUNT
@@ -301,6 +303,8 @@ def executemany(self, operation, seq_of_params):
301303
:param seq_of_params: Sequence of additional parameters to run
302304
the query with.
303305
"""
306+
if self.connection.database is None:
307+
raise ValueError("Database needs to be passed for this operation")
304308
self._itr = None
305309
self._result_set = None
306310
self._row_count = _UNSET_COUNT
@@ -444,6 +448,8 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params):
444448
self._row_count = _UNSET_COUNT
445449

446450
def _handle_DQL(self, sql, params):
451+
if self.connection.database is None:
452+
raise ValueError("Database needs to be passed for this operation")
447453
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
448454
if self.connection.read_only and not self.connection.autocommit:
449455
# initiate or use the existing multi-use snapshot
@@ -484,6 +490,8 @@ def list_tables(self):
484490
def run_sql_in_snapshot(self, sql, params=None, param_types=None):
485491
# Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions
486492
# hence this method exists to circumvent that limit.
493+
if self.connection.database is None:
494+
raise ValueError("Database needs to be passed for this operation")
487495
self.connection.run_prior_DDL_statements()
488496

489497
with self.connection.database.snapshot() as snapshot:

‎tests/unit/spanner_dbapi/test_connection.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,14 @@ def test__session_checkout(self, mock_database):
169169
connection._session_checkout()
170170
self.assertEqual(connection._session, "db_session")
171171

172+
def test__session_checkout_database_error(self):
173+
from google.cloud.spanner_dbapi import Connection
174+
175+
connection = Connection(INSTANCE)
176+
177+
with pytest.raises(ValueError):
178+
connection._session_checkout()
179+
172180
@mock.patch("google.cloud.spanner_v1.database.Database")
173181
def test__release_session(self, mock_database):
174182
from google.cloud.spanner_dbapi import Connection
@@ -182,6 +190,13 @@ def test__release_session(self, mock_database):
182190
pool.put.assert_called_once_with("session")
183191
self.assertIsNone(connection._session)
184192

193+
def test__release_session_database_error(self):
194+
from google.cloud.spanner_dbapi import Connection
195+
196+
connection = Connection(INSTANCE)
197+
with pytest.raises(ValueError):
198+
connection._release_session()
199+
185200
def test_transaction_checkout(self):
186201
from google.cloud.spanner_dbapi import Connection
187202

@@ -294,6 +309,14 @@ def test_commit(self, mock_warn):
294309
AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2
295310
)
296311

312+
def test_commit_database_error(self):
313+
from google.cloud.spanner_dbapi import Connection
314+
315+
connection = Connection(INSTANCE)
316+
317+
with pytest.raises(ValueError):
318+
connection.commit()
319+
297320
@mock.patch.object(warnings, "warn")
298321
def test_rollback(self, mock_warn):
299322
from google.cloud.spanner_dbapi import Connection
@@ -347,6 +370,13 @@ def test_run_prior_DDL_statements(self, mock_database):
347370
with self.assertRaises(InterfaceError):
348371
connection.run_prior_DDL_statements()
349372

373+
def test_run_prior_DDL_statements_database_error(self):
374+
from google.cloud.spanner_dbapi import Connection
375+
376+
connection = Connection(INSTANCE)
377+
with pytest.raises(ValueError):
378+
connection.run_prior_DDL_statements()
379+
350380
def test_as_context_manager(self):
351381
connection = self._make_connection()
352382
with connection as conn:
@@ -766,6 +796,14 @@ def test_validate_error(self):
766796

767797
snapshot_obj.execute_sql.assert_called_once_with("SELECT 1")
768798

799+
def test_validate_database_error(self):
800+
from google.cloud.spanner_dbapi import Connection
801+
802+
connection = Connection(INSTANCE)
803+
804+
with pytest.raises(ValueError):
805+
connection.validate()
806+
769807
def test_validate_closed(self):
770808
from google.cloud.spanner_dbapi.exceptions import InterfaceError
771809

@@ -916,16 +954,14 @@ def test_request_priority(self):
916954
sql, params, param_types=param_types, request_options=None
917955
)
918956

919-
@mock.patch("google.cloud.spanner_v1.Client")
920-
def test_custom_client_connection(self, mock_client):
957+
def test_custom_client_connection(self):
921958
from google.cloud.spanner_dbapi import connect
922959

923960
client = _Client()
924961
connection = connect("test-instance", "test-database", client=client)
925962
self.assertTrue(connection.instance._client == client)
926963

927-
@mock.patch("google.cloud.spanner_v1.Client")
928-
def test_invalid_custom_client_connection(self, mock_client):
964+
def test_invalid_custom_client_connection(self):
929965
from google.cloud.spanner_dbapi import connect
930966

931967
client = _Client()
@@ -937,6 +973,12 @@ def test_invalid_custom_client_connection(self, mock_client):
937973
client=client,
938974
)
939975

976+
def test_connection_wo_database(self):
977+
from google.cloud.spanner_dbapi import connect
978+
979+
connection = connect("test-instance")
980+
self.assertTrue(connection.database is None)
981+
940982

941983
def exit_ctx_func(self, exc_type, exc_value, traceback):
942984
"""Context __exit__ method mock."""

‎tests/unit/spanner_dbapi/test_cursor.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,13 @@ def test_execute_attribute_error(self):
163163
with self.assertRaises(AttributeError):
164164
cursor.execute(sql="SELECT 1")
165165

166+
def test_execute_database_error(self):
167+
connection = self._make_connection(self.INSTANCE)
168+
cursor = self._make_one(connection)
169+
170+
with self.assertRaises(ValueError):
171+
cursor.execute(sql="SELECT 1")
172+
166173
def test_execute_autocommit_off(self):
167174
from google.cloud.spanner_dbapi.utils import PeekIterator
168175

@@ -607,6 +614,16 @@ def test_executemany_insert_batch_aborted(self):
607614
)
608615
self.assertIsInstance(connection._statements[0][1], ResultsChecksum)
609616

617+
@mock.patch("google.cloud.spanner_v1.Client")
618+
def test_executemany_database_error(self, mock_client):
619+
from google.cloud.spanner_dbapi import connect
620+
621+
connection = connect("test-instance")
622+
cursor = connection.cursor()
623+
624+
with self.assertRaises(ValueError):
625+
cursor.executemany("""SELECT * FROM table1 WHERE "col1" = @a1""", ())
626+
610627
@unittest.skipIf(
611628
sys.version_info[0] < 3, "Python 2 has an outdated iterator definition"
612629
)
@@ -754,6 +771,13 @@ def test_handle_dql_priority(self):
754771
sql, None, None, request_options=RequestOptions(priority=1)
755772
)
756773

774+
def test_handle_dql_database_error(self):
775+
connection = self._make_connection(self.INSTANCE)
776+
cursor = self._make_one(connection)
777+
778+
with self.assertRaises(ValueError):
779+
cursor._handle_DQL("sql", params=None)
780+
757781
def test_context(self):
758782
connection = self._make_connection(self.INSTANCE, self.DATABASE)
759783
cursor = self._make_one(connection)
@@ -814,6 +838,13 @@ def test_run_sql_in_snapshot(self):
814838
mock_snapshot.execute_sql.return_value = results
815839
self.assertEqual(cursor.run_sql_in_snapshot("sql"), list(results))
816840

841+
def test_run_sql_in_snapshot_database_error(self):
842+
connection = self._make_connection(self.INSTANCE)
843+
cursor = self._make_one(connection)
844+
845+
with self.assertRaises(ValueError):
846+
cursor.run_sql_in_snapshot("sql")
847+
817848
def test_get_table_column_schema(self):
818849
from google.cloud.spanner_dbapi.cursor import ColumnDetails
819850
from google.cloud.spanner_dbapi import _helpers

0 commit comments

Comments
 (0)