
Commit 623f32a

feat: Add support for Proto and Enum types
1 parent 451fd97 commit 623f32a

17 files changed: 1293 additions & 64 deletions

google/cloud/bigtable/data/_async/client.py

Lines changed: 9 additions & 0 deletions

@@ -657,6 +657,7 @@ async def execute_query(
            DeadlineExceeded,
            ServiceUnavailable,
        ),
+        column_info: dict[str, Any] | None = None,
    ) -> "ExecuteQueryIteratorAsync":
        """
        Executes an SQL query on an instance.
@@ -705,6 +706,13 @@ async def execute_query(
                If None, defaults to prepare_operation_timeout.
            prepare_retryable_errors: a list of errors that will be retried if encountered during prepareQuery.
                Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable)
+            column_info: Dictionary mapping column names to additional column information.
+                Keys are column names and values are custom objects used to deserialize
+                the corresponding column data. This is mainly useful for types such as
+                protobuf, where the deserialization logic lives in user code. When an
+                entry is provided, it is used to deserialize the column data returned
+                by the backend. If not provided, Proto Message columns remain serialized
+                as bytes and Proto Enum columns as integers.
        Returns:
            ExecuteQueryIteratorAsync: an asynchronous iterator that yields rows returned by the query
        Raises:
@@ -771,6 +779,7 @@ async def execute_query(
            attempt_timeout,
            operation_timeout,
            retryable_excs=retryable_excs,
+            column_info=column_info,
        )

    @CrossSync.convert(sync_name="__enter__")
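
For orientation, a minimal usage sketch (not taken from this commit) of how a caller might pass the new column_info argument to the async client. The protobuf module singer_pb2, its Singer message and Genre enum, and the table/column names are hypothetical; the call shape follows the signature added above.

    from google.cloud.bigtable.data import BigtableDataClientAsync
    import singer_pb2  # hypothetical user-generated protobuf module

    async def query_singers() -> None:
        async with BigtableDataClientAsync(project="my-project") as client:
            results = await client.execute_query(
                "SELECT singer_proto, genre_enum FROM singers_table",
                instance_id="my-instance",
                # Message instance for Proto columns, EnumTypeWrapper for Enum columns;
                # columns without an entry stay as raw bytes / int.
                column_info={
                    "singer_proto": singer_pb2.Singer(),
                    "genre_enum": singer_pb2.Genre,
                },
            )
            async for row in results:
                singer = row["singer_proto"]  # deserialized singer_pb2.Singer message
                genre = row["genre_enum"]     # enum value name, e.g. "GENRE_POP"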

google/cloud/bigtable/data/_sync_autogen/client.py

Lines changed: 9 additions & 0 deletions

@@ -485,6 +485,7 @@ def execute_query(
            DeadlineExceeded,
            ServiceUnavailable,
        ),
+        column_info: dict[str, Any] | None = None,
    ) -> "ExecuteQueryIterator":
        """Executes an SQL query on an instance.
        Returns an iterator to asynchronously stream back columns from selected rows.
@@ -532,6 +533,13 @@ def execute_query(
                If None, defaults to prepare_operation_timeout.
            prepare_retryable_errors: a list of errors that will be retried if encountered during prepareQuery.
                Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable)
+            column_info: Dictionary mapping column names to additional column information.
+                Keys are column names and values are custom objects used to deserialize
+                the corresponding column data. This is mainly useful for types such as
+                protobuf, where the deserialization logic lives in user code. When an
+                entry is provided, it is used to deserialize the column data returned
+                by the backend. If not provided, Proto Message columns remain serialized
+                as bytes and Proto Enum columns as integers.
        Returns:
            ExecuteQueryIterator: an asynchronous iterator that yields rows returned by the query
        Raises:
@@ -592,6 +600,7 @@ def execute_query(
            attempt_timeout,
            operation_timeout,
            retryable_excs=retryable_excs,
+            column_info=column_info,
        )

    def __enter__(self):
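
The autogenerated sync client takes the same argument; a minimal sketch using the same hypothetical names as the async example above:

    from google.cloud.bigtable.data import BigtableDataClient
    import singer_pb2  # hypothetical user-generated protobuf module

    with BigtableDataClient(project="my-project") as client:
        results = client.execute_query(
            "SELECT singer_proto, genre_enum FROM singers_table",
            instance_id="my-instance",
            column_info={
                "singer_proto": singer_pb2.Singer(),  # Proto column -> Message
                "genre_enum": singer_pb2.Genre,       # Enum column -> name string
            },
        )
        for row in results:
            print(row["singer_proto"], row["genre_enum"])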

google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py

Lines changed: 7 additions & 1 deletion

@@ -87,6 +87,7 @@ def __init__(
        operation_timeout: float,
        req_metadata: Sequence[Tuple[str, str]] = (),
        retryable_excs: Sequence[type[Exception]] = (),
+        column_info: Dict[str, Any] | None = None,
    ) -> None:
        """
        Collects responses from ExecuteQuery requests and parses them into QueryResultRows.
@@ -107,6 +108,8 @@ def __init__(
                Failed requests will be retried within the budget
            req_metadata: metadata used while sending the gRPC request
            retryable_excs: a list of errors that will be retried if encountered.
+            column_info: dict with mappings between column names and additional column information
+                for protobuf deserialization.
        Raises:
            {NO_LOOP}
            :class:`ValueError <exceptions.ValueError>` as a safeguard if data is processed in an unexpected state
@@ -135,6 +138,7 @@ def __init__(
            exception_factory=_retry_exception_factory,
        )
        self._req_metadata = req_metadata
+        self._column_info = column_info
        try:
            self._register_instance_task = CrossSync.create_task(
                self._client._register_instance,
@@ -202,7 +206,9 @@ async def _next_impl(self) -> CrossSync.Iterator[QueryResultRow]:
            raise ValueError(
                "Error parsing response before finalizing metadata"
            )
-            results = self._reader.consume(batches_to_parse, self.metadata)
+            results = self._reader.consume(
+                batches_to_parse, self.metadata, self._column_info
+            )
            if results is None:
                continue

google/cloud/bigtable/data/execute_query/_query_result_parsing_utils.py

Lines changed: 93 additions & 11 deletions

@@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations

from typing import Any, Callable, Dict, Type
+
+from google.protobuf.message import Message
+from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper
from google.cloud.bigtable.data.execute_query.values import Struct
from google.cloud.bigtable.data.execute_query.metadata import SqlType
from google.cloud.bigtable_v2 import Value as PBValue
@@ -30,24 +34,36 @@
    SqlType.Struct: "array_value",
    SqlType.Array: "array_value",
    SqlType.Map: "array_value",
+    SqlType.Proto: "bytes_value",
+    SqlType.Enum: "int_value",
}


-def _parse_array_type(value: PBValue, metadata_type: SqlType.Array) -> Any:
+def _parse_array_type(
+    value: PBValue,
+    metadata_type: SqlType.Array,
+    column_name: str,
+    column_info: dict[str, Any],
+) -> Any:
    """
    used for parsing an array represented as a protobuf to a python list.
    """
    return list(
        map(
            lambda val: _parse_pb_value_to_python_value(
-                val, metadata_type.element_type
+                val, metadata_type.element_type, column_name, column_info
            ),
            value.array_value.values,
        )
    )


-def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> Any:
+def _parse_map_type(
+    value: PBValue,
+    metadata_type: SqlType.Map,
+    column_name: str,
+    column_info: dict[str, Any],
+) -> Any:
    """
    used for parsing a map represented as a protobuf to a python dict.

@@ -64,10 +80,16 @@ def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> Any:
        map(
            lambda map_entry: (
                _parse_pb_value_to_python_value(
-                    map_entry.array_value.values[0], metadata_type.key_type
+                    map_entry.array_value.values[0],
+                    metadata_type.key_type,
+                    f"{column_name}.key",
+                    column_info,
                ),
                _parse_pb_value_to_python_value(
-                    map_entry.array_value.values[1], metadata_type.value_type
+                    map_entry.array_value.values[1],
+                    metadata_type.value_type,
+                    f"{column_name}.value",
+                    column_info,
                ),
            ),
            value.array_value.values,
@@ -77,7 +99,12 @@ def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> Any:
    raise ValueError("Invalid map entry - less or more than two values.")


-def _parse_struct_type(value: PBValue, metadata_type: SqlType.Struct) -> Struct:
+def _parse_struct_type(
+    value: PBValue,
+    metadata_type: SqlType.Struct,
+    column_name: str,
+    column_info: dict[str, Any],
+) -> Struct:
    """
    used for parsing a struct represented as a protobuf to a
    google.cloud.bigtable.data.execute_query.Struct
@@ -88,29 +115,84 @@ def _parse_struct_type(value: PBValue, metadata_type: SqlType.Struct) -> Struct:
    struct = Struct()
    for value, field in zip(value.array_value.values, metadata_type.fields):
        field_name, field_type = field
-        struct.add_field(field_name, _parse_pb_value_to_python_value(value, field_type))
+        # qualify the column name for nested lookups
+        nested_column_name = (
+            f"{column_name}.{field_name}" if field_name else column_name
+        )
+        struct.add_field(
+            field_name,
+            _parse_pb_value_to_python_value(
+                value, field_type, nested_column_name, column_info
+            ),
+        )

    return struct


def _parse_timestamp_type(
-    value: PBValue, metadata_type: SqlType.Timestamp
+    value: PBValue,
+    metadata_type: SqlType.Timestamp,
+    column_name: str,
+    column_info: dict[str, Any],
) -> DatetimeWithNanoseconds:
    """
    used for parsing a timestamp represented as a protobuf to DatetimeWithNanoseconds
    """
    return DatetimeWithNanoseconds.from_timestamp_pb(value.timestamp_value)


-_TYPE_PARSERS: Dict[Type[SqlType.Type], Callable[[PBValue, Any], Any]] = {
+def _parse_proto_type(
+    value: PBValue,
+    metadata_type: SqlType.Proto,
+    column_name: str,
+    column_info: dict[str, Any],
+) -> Message | bytes:
+    """
+    Parses a serialized protobuf message into a Message object.
+    """
+    if column_info is not None and column_info.get(column_name) is not None:
+        default_proto_message = column_info.get(column_name)
+        if isinstance(default_proto_message, Message):
+            proto_message = type(default_proto_message)()
+            proto_message.ParseFromString(value.bytes_value)
+            return proto_message
+    return value.bytes_value
+
+
+def _parse_enum_type(
+    value: PBValue,
+    metadata_type: SqlType.Enum,
+    column_name: str,
+    column_info: dict[str, Any],
+) -> int | Any:
+    """
+    Parses an integer value into a Protobuf enum.
+    """
+    if column_info is not None and column_info.get(column_name) is not None:
+        proto_enum = column_info.get(column_name)
+        if isinstance(proto_enum, EnumTypeWrapper):
+            return proto_enum.Name(value.int_value)
+    return value.int_value
+
+
+_TYPE_PARSERS: Dict[
+    Type[SqlType.Type], Callable[[PBValue, Any, str, dict[str, Any]], Any]
+] = {
    SqlType.Timestamp: _parse_timestamp_type,
    SqlType.Struct: _parse_struct_type,
    SqlType.Array: _parse_array_type,
    SqlType.Map: _parse_map_type,
+    SqlType.Proto: _parse_proto_type,
+    SqlType.Enum: _parse_enum_type,
}


-def _parse_pb_value_to_python_value(value: PBValue, metadata_type: SqlType.Type) -> Any:
+def _parse_pb_value_to_python_value(
+    value: PBValue,
+    metadata_type: SqlType.Type,
+    column_name: str,
+    column_info: dict[str, Any] = None,
+) -> Any:
    """
    used for converting the value represented as a protobuf to a python object.
    """
@@ -126,7 +208,7 @@ def _parse_pb_value_to_python_value(value: PBValue, metadata_type: SqlType.Type)

    if kind in _TYPE_PARSERS:
        parser = _TYPE_PARSERS[kind]
-        return parser(value, metadata_type)
+        return parser(value, metadata_type, column_name, column_info)
    elif kind in _REQUIRED_PROTO_FIELDS:
        field_name = _REQUIRED_PROTO_FIELDS[kind]
        return getattr(value, field_name)
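
To make the lookup behaviour of the new parsers concrete, here is a small illustrative sketch (not from the commit) that mirrors what _parse_proto_type and _parse_enum_type do, using well-known protobuf types purely as stand-ins. Note that for Map and Struct columns the lookup key is qualified as column.key, column.value, or column.field_name, per _parse_map_type and _parse_struct_type above.

    from google.protobuf import struct_pb2, timestamp_pb2

    # column_info as a user would pass it: a Message instance for a Proto column,
    # an EnumTypeWrapper for an Enum column.
    column_info = {
        "created_at_proto": timestamp_pb2.Timestamp(),
        "null_marker_enum": struct_pb2.NullValue,
    }

    # Proto column: the backend sends bytes; a fresh message of the template's type
    # is created and populated via ParseFromString (as in _parse_proto_type).
    serialized = timestamp_pb2.Timestamp(seconds=1700000000).SerializeToString()
    template = column_info["created_at_proto"]
    message = type(template)()
    message.ParseFromString(serialized)  # -> Timestamp(seconds=1700000000)

    # Enum column: the backend sends an int; EnumTypeWrapper.Name maps it to the
    # symbolic name (as in _parse_enum_type).
    name = column_info["null_marker_enum"].Name(0)  # -> "NULL_VALUE"

    # Columns without a matching column_info entry keep the raw bytes_value / int_value.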

google/cloud/bigtable/data/execute_query/_reader.py

Lines changed: 21 additions & 5 deletions

@@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations

from typing import (
+    Any,
    List,
    TypeVar,
    Generic,
@@ -54,7 +56,10 @@ class _Reader(ABC, Generic[T]):

    @abstractmethod
    def consume(
-        self, batches_to_consume: List[bytes], metadata: Metadata
+        self,
+        batches_to_consume: List[bytes],
+        metadata: Metadata,
+        column_info: dict[str, Any] = None,
    ) -> Optional[Iterable[T]]:
        """This method receives a list of batches of bytes to be parsed as ProtoRows messages.
        It then uses the metadata to group the values in the parsed messages into rows. Returns
@@ -64,6 +69,8 @@ def consume(
                :meth:`google.cloud.bigtable.byte_cursor._ByteCursor.consume`
                method.
            metadata: metadata used to transform values to rows
+            column_info: (Optional) dict with mappings between column names and additional column information
+                for protobuf deserialization.

        Returns:
            Iterable[T] or None: Iterable if gathered values can form one or more instances of T,
@@ -89,7 +96,7 @@ def _parse_proto_rows(self, bytes_to_parse: bytes) -> Iterable[PBValue]:
        return proto_rows.values

    def _construct_query_result_row(
-        self, values: Sequence[PBValue], metadata: Metadata
+        self, values: Sequence[PBValue], metadata: Metadata, column_info: dict[str, Any]
    ) -> QueryResultRow:
        result = QueryResultRow()
        columns = metadata.columns
@@ -99,20 +106,29 @@ def _construct_query_result_row(
        ), "This function should be called only when count of values matches count of columns."

        for column, value in zip(columns, values):
-            parsed_value = _parse_pb_value_to_python_value(value, column.column_type)
+            parsed_value = _parse_pb_value_to_python_value(
+                value, column.column_type, column.column_name, column_info
+            )
            result.add_field(column.column_name, parsed_value)
        return result

    def consume(
-        self, batches_to_consume: List[bytes], metadata: Metadata
+        self,
+        batches_to_consume: List[bytes],
+        metadata: Metadata,
+        column_info: dict[str, Any] = None,
    ) -> Optional[Iterable[QueryResultRow]]:
        num_columns = len(metadata.columns)
        rows = []
        for batch_bytes in batches_to_consume:
            values = self._parse_proto_rows(batch_bytes)
            for row_data in batched(values, n=num_columns):
                if len(row_data) == num_columns:
-                    rows.append(self._construct_query_result_row(row_data, metadata))
+                    rows.append(
+                        self._construct_query_result_row(
+                            row_data, metadata, column_info
+                        )
+                    )
                else:
                    raise ValueError(
                        "Unexpected error, received bad number of values. "

google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py

Lines changed: 7 additions & 1 deletion

@@ -63,6 +63,7 @@ def __init__(
        operation_timeout: float,
        req_metadata: Sequence[Tuple[str, str]] = (),
        retryable_excs: Sequence[type[Exception]] = (),
+        column_info: Dict[str, Any] | None = None,
    ) -> None:
        """Collects responses from ExecuteQuery requests and parses them into QueryResultRows.

@@ -82,6 +83,8 @@ def __init__(
                Failed requests will be retried within the budget
            req_metadata: metadata used while sending the gRPC request
            retryable_excs: a list of errors that will be retried if encountered.
+            column_info: dict with mappings between column names and additional column information
+                for protobuf deserialization.
        Raises:
            None
            :class:`ValueError <exceptions.ValueError>` as a safeguard if data is processed in an unexpected state
@@ -110,6 +113,7 @@ def __init__(
            exception_factory=_retry_exception_factory,
        )
        self._req_metadata = req_metadata
+        self._column_info = column_info
        try:
            self._register_instance_task = CrossSync._Sync_Impl.create_task(
                self._client._register_instance,
@@ -164,7 +168,9 @@ def _next_impl(self) -> CrossSync._Sync_Impl.Iterator[QueryResultRow]:
                raise ValueError(
                    "Error parsing response before finalizing metadata"
                )
-                results = self._reader.consume(batches_to_parse, self.metadata)
+                results = self._reader.consume(
+                    batches_to_parse, self.metadata, self._column_info
+                )
                if results is None:
                    continue
            except ValueError as e:
