Skip to content

Commit 30bb3fb

Browse files
feat: improve type information (#176)
Co-authored-by: Tres Seaver <tseaver@palladion.com>
1 parent e8f6c4d commit 30bb3fb

File tree

11 files changed

+74
-48
lines changed

11 files changed

+74
-48
lines changed

‎google/cloud/firestore_v1/_helpers.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from google.cloud.firestore_v1.types import common
3333
from google.cloud.firestore_v1.types import document
3434
from google.cloud.firestore_v1.types import write
35-
from typing import Any, Generator, List, NoReturn, Optional, Tuple
35+
from typing import Any, Generator, List, NoReturn, Optional, Tuple, Union
3636

3737
_EmptyDict: transforms.Sentinel
3838
_GRPC_ERROR_MAPPING: dict
@@ -69,7 +69,7 @@ def __init__(self, latitude, longitude) -> None:
6969
self.latitude = latitude
7070
self.longitude = longitude
7171

72-
def to_protobuf(self) -> Any:
72+
def to_protobuf(self) -> latlng_pb2.LatLng:
7373
"""Convert the current object to protobuf.
7474
7575
Returns:
@@ -253,7 +253,9 @@ def reference_value_to_document(reference_value, client) -> Any:
253253
return document
254254

255255

256-
def decode_value(value, client) -> Any:
256+
def decode_value(
257+
value, client
258+
) -> Union[None, bool, int, float, list, datetime.datetime, str, bytes, dict, GeoPoint]:
257259
"""Converts a Firestore protobuf ``Value`` to a native Python value.
258260
259261
Args:
@@ -316,7 +318,7 @@ def decode_dict(value_fields, client) -> dict:
316318
return {key: decode_value(value, client) for key, value in value_fields.items()}
317319

318320

319-
def get_doc_id(document_pb, expected_prefix) -> Any:
321+
def get_doc_id(document_pb, expected_prefix) -> str:
320322
"""Parse a document ID from a document protobuf.
321323
322324
Args:
@@ -887,7 +889,7 @@ class ReadAfterWriteError(Exception):
887889
"""
888890

889891

890-
def get_transaction_id(transaction, read_operation=True) -> Any:
892+
def get_transaction_id(transaction, read_operation=True) -> Union[bytes, None]:
891893
"""Get the transaction ID from a ``Transaction`` object.
892894
893895
Args:

‎google/cloud/firestore_v1/async_document.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
from google.api_core import exceptions # type: ignore
2727
from google.cloud.firestore_v1 import _helpers
28+
from google.cloud.firestore_v1.types import write
29+
from google.protobuf import timestamp_pb2
2830
from typing import Any, AsyncGenerator, Coroutine, Iterable, Union
2931

3032

@@ -61,7 +63,7 @@ async def create(
6163
document_data: dict,
6264
retry: retries.Retry = gapic_v1.method.DEFAULT,
6365
timeout: float = None,
64-
) -> Coroutine:
66+
) -> write.WriteResult:
6567
"""Create the current document in the Firestore database.
6668
6769
Args:
@@ -91,7 +93,7 @@ async def set(
9193
merge: bool = False,
9294
retry: retries.Retry = gapic_v1.method.DEFAULT,
9395
timeout: float = None,
94-
) -> Coroutine:
96+
) -> write.WriteResult:
9597
"""Replace the current document in the Firestore database.
9698
9799
A write ``option`` can be specified to indicate preconditions of
@@ -131,7 +133,7 @@ async def update(
131133
option: _helpers.WriteOption = None,
132134
retry: retries.Retry = gapic_v1.method.DEFAULT,
133135
timeout: float = None,
134-
) -> Coroutine:
136+
) -> write.WriteResult:
135137
"""Update an existing document in the Firestore database.
136138
137139
By default, this method verifies that the document exists on the
@@ -287,7 +289,7 @@ async def delete(
287289
option: _helpers.WriteOption = None,
288290
retry: retries.Retry = gapic_v1.method.DEFAULT,
289291
timeout: float = None,
290-
) -> Coroutine:
292+
) -> timestamp_pb2.Timestamp:
291293
"""Delete the current document in the Firestore database.
292294
293295
Args:

‎google/cloud/firestore_v1/async_transaction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ async def get_all(
153153
references: list,
154154
retry: retries.Retry = gapic_v1.method.DEFAULT,
155155
timeout: float = None,
156-
) -> Coroutine:
156+
) -> AsyncGenerator[DocumentSnapshot, Any]:
157157
"""Retrieves multiple documents from Firestore.
158158
159159
Args:

‎google/cloud/firestore_v1/base_client.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def _firestore_api_helper(self, transport, client_class, client_module) -> Any:
166166

167167
return self._firestore_api_internal
168168

169-
def _target_helper(self, client_class) -> Any:
169+
def _target_helper(self, client_class) -> str:
170170
"""Return the target (where the API is).
171171
Eg. "firestore.googleapis.com"
172172
@@ -273,7 +273,7 @@ def _document_path_helper(self, *document_path) -> List[str]:
273273
return joined_path.split(_helpers.DOCUMENT_PATH_DELIMITER)
274274

275275
@staticmethod
276-
def field_path(*field_names: Tuple[str]) -> Any:
276+
def field_path(*field_names: Tuple[str]) -> str:
277277
"""Create a **field path** from a list of nested field names.
278278
279279
A **field path** is a ``.``-delimited concatenation of the field
@@ -438,7 +438,7 @@ def _reference_info(references: list) -> Tuple[list, dict]:
438438
return document_paths, reference_map
439439

440440

441-
def _get_reference(document_path: str, reference_map: dict) -> Any:
441+
def _get_reference(document_path: str, reference_map: dict) -> BaseDocumentReference:
442442
"""Get a document reference from a dictionary.
443443
444444
This just wraps a simple dictionary look-up with a helpful error that is
@@ -536,7 +536,18 @@ def _get_doc_mask(field_paths: Iterable[str]) -> Optional[types.common.DocumentM
536536
return types.DocumentMask(field_paths=field_paths)
537537

538538

539-
def _path_helper(path: tuple) -> Any:
539+
def _item_to_collection_ref(iterator, item: str) -> BaseCollectionReference:
540+
"""Convert collection ID to collection ref.
541+
542+
Args:
543+
iterator (google.api_core.page_iterator.GRPCIterator):
544+
iterator response
545+
item (str): ID of the collection
546+
"""
547+
return iterator.client.collection(item)
548+
549+
550+
def _path_helper(path: tuple) -> Tuple[str]:
540551
"""Standardize path into a tuple of path segments.
541552
542553
Args:

‎google/cloud/firestore_v1/base_collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def parent(self):
107107
def _query(self) -> BaseQuery:
108108
raise NotImplementedError
109109

110-
def document(self, document_id: str = None) -> Any:
110+
def document(self, document_id: str = None) -> DocumentReference:
111111
"""Create a sub-document underneath the current collection.
112112
113113
Args:

‎google/cloud/firestore_v1/base_document.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
from google.cloud.firestore_v1 import field_path as field_path_module
2323
from google.cloud.firestore_v1.types import common
2424

25-
from typing import Any
26-
from typing import Iterable
27-
from typing import NoReturn
28-
from typing import Tuple
25+
# Types needed only for Type Hints
26+
from google.cloud.firestore_v1.types import firestore
27+
from google.cloud.firestore_v1.types import write
28+
from typing import Any, Dict, Iterable, NoReturn, Union, Tuple
2929

3030

3131
class BaseDocumentReference(object):
@@ -475,7 +475,7 @@ def get(self, field_path: str) -> Any:
475475
nested_data = field_path_module.get_nested_value(field_path, self._data)
476476
return copy.deepcopy(nested_data)
477477

478-
def to_dict(self) -> Any:
478+
def to_dict(self) -> Union[Dict[str, Any], None]:
479479
"""Retrieve the data contained in this snapshot.
480480
481481
A copy is returned since the data may contain mutable values,
@@ -512,7 +512,7 @@ def _get_document_path(client, path: Tuple[str]) -> str:
512512
return _helpers.DOCUMENT_PATH_DELIMITER.join(parts)
513513

514514

515-
def _consume_single_get(response_iterator) -> Any:
515+
def _consume_single_get(response_iterator) -> firestore.BatchGetDocumentsResponse:
516516
"""Consume a gRPC stream that should contain a single response.
517517
518518
The stream will correspond to a ``BatchGetDocuments`` request made
@@ -543,7 +543,7 @@ def _consume_single_get(response_iterator) -> Any:
543543
return all_responses[0]
544544

545545

546-
def _first_write_result(write_results: list) -> Any:
546+
def _first_write_result(write_results: list) -> write.WriteResult:
547547
"""Get first write result from list.
548548
549549
For cases where ``len(write_results) > 1``, this assumes the writes

‎google/cloud/firestore_v1/base_query.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def where(self, field_path: str, op_string: str, value) -> "BaseQuery":
314314
)
315315

316316
@staticmethod
317-
def _make_order(field_path, direction) -> Any:
317+
def _make_order(field_path, direction) -> StructuredQuery.Order:
318318
"""Helper for :meth:`order_by`."""
319319
return query.StructuredQuery.Order(
320320
field=query.StructuredQuery.FieldReference(field_path=field_path),
@@ -394,7 +394,7 @@ def limit(self, count: int) -> "BaseQuery":
394394
all_descendants=self._all_descendants,
395395
)
396396

397-
def limit_to_last(self, count: int):
397+
def limit_to_last(self, count: int) -> "BaseQuery":
398398
"""Limit a query to return the last `count` matching results.
399399
If the current query already has a `limit_to_last`
400400
set, this will override it.
@@ -651,7 +651,7 @@ def end_at(
651651
document_fields_or_snapshot, before=False, start=False
652652
)
653653

654-
def _filters_pb(self) -> Any:
654+
def _filters_pb(self) -> StructuredQuery.Filter:
655655
"""Convert all the filters into a single generic Filter protobuf.
656656
657657
This may be a lone field filter or unary filter, may be a composite
@@ -674,7 +674,7 @@ def _filters_pb(self) -> Any:
674674
return query.StructuredQuery.Filter(composite_filter=composite_filter)
675675

676676
@staticmethod
677-
def _normalize_projection(projection) -> Any:
677+
def _normalize_projection(projection) -> StructuredQuery.Projection:
678678
"""Helper: convert field paths to message."""
679679
if projection is not None:
680680

@@ -836,7 +836,7 @@ def stream(
836836
def on_snapshot(self, callback) -> NoReturn:
837837
raise NotImplementedError
838838

839-
def _comparator(self, doc1, doc2) -> Any:
839+
def _comparator(self, doc1, doc2) -> int:
840840
_orders = self._orders
841841

842842
# Add implicit sorting by name, using the last specified direction.
@@ -883,7 +883,7 @@ def _comparator(self, doc1, doc2) -> Any:
883883
return 0
884884

885885

886-
def _enum_from_op_string(op_string: str) -> Any:
886+
def _enum_from_op_string(op_string: str) -> int:
887887
"""Convert a string representation of a binary operator to an enum.
888888
889889
These enums come from the protobuf message definition
@@ -926,7 +926,7 @@ def _isnan(value) -> bool:
926926
return False
927927

928928

929-
def _enum_from_direction(direction: str) -> Any:
929+
def _enum_from_direction(direction: str) -> int:
930930
"""Convert a string representation of a direction to an enum.
931931
932932
Args:
@@ -954,7 +954,7 @@ def _enum_from_direction(direction: str) -> Any:
954954
raise ValueError(msg)
955955

956956

957-
def _filter_pb(field_or_unary) -> Any:
957+
def _filter_pb(field_or_unary) -> StructuredQuery.Filter:
958958
"""Convert a specific protobuf filter to the generic filter type.
959959
960960
Args:

‎google/cloud/firestore_v1/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@
4646
)
4747
from typing import Any, Generator, Iterable, Tuple
4848

49+
# Types needed only for Type Hints
50+
from google.cloud.firestore_v1.base_document import DocumentSnapshot
51+
4952

5053
class Client(BaseClient):
5154
"""Client for interacting with Google Cloud Firestore API.
@@ -209,7 +212,7 @@ def get_all(
209212
transaction: Transaction = None,
210213
retry: retries.Retry = gapic_v1.method.DEFAULT,
211214
timeout: float = None,
212-
) -> Generator[Any, Any, None]:
215+
) -> Generator[DocumentSnapshot, Any, None]:
213216
"""Retrieve a batch of documents.
214217
215218
.. note::

‎google/cloud/firestore_v1/document.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525

2626
from google.api_core import exceptions # type: ignore
2727
from google.cloud.firestore_v1 import _helpers
28+
from google.cloud.firestore_v1.types import write
2829
from google.cloud.firestore_v1.watch import Watch
30+
from google.protobuf import timestamp_pb2
2931
from typing import Any, Callable, Generator, Iterable
3032

3133

@@ -62,7 +64,7 @@ def create(
6264
document_data: dict,
6365
retry: retries.Retry = gapic_v1.method.DEFAULT,
6466
timeout: float = None,
65-
) -> Any:
67+
) -> write.WriteResult:
6668
"""Create the current document in the Firestore database.
6769
6870
Args:
@@ -92,7 +94,7 @@ def set(
9294
merge: bool = False,
9395
retry: retries.Retry = gapic_v1.method.DEFAULT,
9496
timeout: float = None,
95-
) -> Any:
97+
) -> write.WriteResult:
9698
"""Replace the current document in the Firestore database.
9799
98100
A write ``option`` can be specified to indicate preconditions of
@@ -132,7 +134,7 @@ def update(
132134
option: _helpers.WriteOption = None,
133135
retry: retries.Retry = gapic_v1.method.DEFAULT,
134136
timeout: float = None,
135-
) -> Any:
137+
) -> write.WriteResult:
136138
"""Update an existing document in the Firestore database.
137139
138140
By default, this method verifies that the document exists on the
@@ -288,7 +290,7 @@ def delete(
288290
option: _helpers.WriteOption = None,
289291
retry: retries.Retry = gapic_v1.method.DEFAULT,
290292
timeout: float = None,
291-
) -> Any:
293+
) -> timestamp_pb2.Timestamp:
292294
"""Delete the current document in the Firestore database.
293295
294296
Args:
@@ -339,7 +341,7 @@ def get(
339341
transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
340342
An existing transaction that this reference
341343
will be retrieved in.
342-
retry (google.api_core.retry.Retry): Designation of what errors, if any,
344+
retry (google.api_core.retry.Retry): Designation of what errors, if an y,
343345
should be retried. Defaults to a system-specified policy.
344346
timeout (float): The timeout for this request. Defaults to a
345347
system-specified value.

0 commit comments

Comments
 (0)