Skip to content

Commit 9095368

Browse files
authored
fix: asyncio microgen client get_all type (#126)
* feat: create AsyncIter class for mocking * fix: type error on mocked return on batch_get_documents
1 parent a4e5b00 commit 9095368

File tree

4 files changed

+17
-17
lines changed

4 files changed

+17
-17
lines changed

‎google/cloud/firestore_v1/async_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ async def get_all(self, references, field_paths=None, transaction=None):
252252
metadata=self._rpc_metadata,
253253
)
254254

255-
for get_doc_response in response_iterator:
255+
async for get_doc_response in response_iterator:
256256
yield _parse_batch_get(get_doc_response, reference_map, self)
257257

258258
async def collections(self):

‎tests/unit/v1/test__helpers.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ async def __call__(self, *args, **kwargs):
2525
return super(AsyncMock, self).__call__(*args, **kwargs)
2626

2727

28+
class AsyncIter:
29+
def __init__(self, items):
30+
self.items = items
31+
32+
async def __aiter__(self, **_):
33+
for i in self.items:
34+
yield i
35+
36+
2837
class TestGeoPoint(unittest.TestCase):
2938
@staticmethod
3039
def _get_target_class():

‎tests/unit/v1/test_async_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import aiounittest
1919

2020
import mock
21-
from tests.unit.v1.test__helpers import AsyncMock
21+
from tests.unit.v1.test__helpers import AsyncMock, AsyncIter
2222

2323

2424
class TestAsyncClient(aiounittest.AsyncTestCase):
@@ -237,7 +237,7 @@ def _next_page(self):
237237
async def _get_all_helper(self, client, references, document_pbs, **kwargs):
238238
# Create a minimal fake GAPIC with a dummy response.
239239
firestore_api = mock.Mock(spec=["batch_get_documents"])
240-
response_iterator = iter(document_pbs)
240+
response_iterator = AsyncIter(document_pbs)
241241
firestore_api.batch_get_documents.return_value = response_iterator
242242

243243
# Attach the fake GAPIC to a real client.

‎tests/unit/v1/test_async_collection.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,7 @@
1717
import aiounittest
1818

1919
import mock
20-
from tests.unit.v1.test__helpers import AsyncMock
21-
22-
23-
class MockAsyncIter:
24-
def __init__(self, count):
25-
self.count = count
26-
27-
async def __aiter__(self, **_):
28-
for i in range(self.count):
29-
yield i
20+
from tests.unit.v1.test__helpers import AsyncMock, AsyncIter
3021

3122

3223
class TestAsyncCollectionReference(aiounittest.AsyncTestCase):
@@ -258,7 +249,7 @@ async def test_list_documents_w_page_size(self):
258249
async def test_get(self, query_class):
259250
import warnings
260251

261-
query_class.return_value.stream.return_value = MockAsyncIter(3)
252+
query_class.return_value.stream.return_value = AsyncIter(range(3))
262253

263254
collection = self._make_one("collection")
264255
with warnings.catch_warnings(record=True) as warned:
@@ -280,7 +271,7 @@ async def test_get(self, query_class):
280271
async def test_get_with_transaction(self, query_class):
281272
import warnings
282273

283-
query_class.return_value.stream.return_value = MockAsyncIter(3)
274+
query_class.return_value.stream.return_value = AsyncIter(range(3))
284275

285276
collection = self._make_one("collection")
286277
transaction = mock.sentinel.txn
@@ -301,7 +292,7 @@ async def test_get_with_transaction(self, query_class):
301292
@mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True)
302293
@pytest.mark.asyncio
303294
async def test_stream(self, query_class):
304-
query_class.return_value.stream.return_value = MockAsyncIter(3)
295+
query_class.return_value.stream.return_value = AsyncIter(range(3))
305296

306297
collection = self._make_one("collection")
307298
stream_response = collection.stream()
@@ -316,7 +307,7 @@ async def test_stream(self, query_class):
316307
@mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True)
317308
@pytest.mark.asyncio
318309
async def test_stream_with_transaction(self, query_class):
319-
query_class.return_value.stream.return_value = MockAsyncIter(3)
310+
query_class.return_value.stream.return_value = AsyncIter(range(3))
320311

321312
collection = self._make_one("collection")
322313
transaction = mock.sentinel.txn

0 commit comments

Comments
 (0)