Module astrapy.data.utils.table_converters
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import datetime
import decimal
import hashlib
import ipaddress
import json
import logging
import math
from typing import Any, Callable, Dict, Generic, cast
from astrapy.constants import ROW, MapEncodingMode
from astrapy.data.info.table_descriptor.table_columns import (
TableColumnTypeDescriptor,
TableKeyValuedColumnTypeDescriptor,
TablePassthroughColumnTypeDescriptor,
TableScalarColumnTypeDescriptor,
TableUDTColumnDescriptor,
TableUnsupportedColumnTypeDescriptor,
TableValuedColumnTypeDescriptor,
TableVectorColumnTypeDescriptor,
)
from astrapy.data.utils.extended_json_converters import (
convert_ejson_binary_object_to_bytes,
convert_to_ejson_bytes,
)
from astrapy.data.utils.table_types import (
ColumnType,
TableKeyValuedColumnType,
TableValuedColumnType,
TableVectorColumnType,
)
from astrapy.data.utils.vector_coercion import ensure_unrolled_if_iterable
from astrapy.data_types import (
DataAPIDate,
DataAPIDictUDT,
DataAPIDuration,
DataAPIMap,
DataAPISet,
DataAPITime,
DataAPITimestamp,
DataAPIVector,
)
from astrapy.ids import UUID, ObjectId
from astrapy.settings.error_messages import CANNOT_ENCODE_NAIVE_DATETIME_ERROR_MESSAGE
from astrapy.utils.api_options import FullSerdesOptions
from astrapy.utils.date_utils import _get_datetime_offset
NAN_FLOAT_STRING_REPRESENTATION = "NaN"
PLUS_INFINITY_FLOAT_STRING_REPRESENTATION = "Infinity"
MINUS_INFINITY_FLOAT_STRING_REPRESENTATION = "-Infinity"
DATETIME_TIME_FORMAT = "%H:%M:%S.%f"
DATETIME_DATE_FORMAT = "%Y-%m-%d"
DATETIME_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f"
logger = logging.getLogger(__name__)
def _create_scalar_tpostprocessor(
column_type: ColumnType,
options: FullSerdesOptions,
) -> Callable[[Any], Any]:
if column_type in {
ColumnType.TEXT,
ColumnType.ASCII,
}:
def _tpostprocessor_text(raw_value: Any) -> str | None:
return raw_value # type: ignore[no-any-return]
return _tpostprocessor_text
elif column_type == ColumnType.BOOLEAN:
def _tpostprocessor_bool(raw_value: Any) -> bool | None:
return raw_value # type: ignore[no-any-return]
return _tpostprocessor_bool
elif column_type in {
ColumnType.INT,
ColumnType.VARINT,
ColumnType.BIGINT,
ColumnType.SMALLINT,
ColumnType.TINYINT,
ColumnType.COUNTER,
}:
def _tpostprocessor_int(raw_value: Any) -> int | None:
if raw_value is None:
return None
            # the 'int(...)' handles Decimals
return int(raw_value)
return _tpostprocessor_int
elif column_type in {
ColumnType.FLOAT,
ColumnType.DOUBLE,
}:
def _tpostprocessor_float(raw_value: Any) -> float | None:
if raw_value is None:
return None
elif isinstance(raw_value, (str, decimal.Decimal)):
return float(raw_value)
# just a float already
return cast(float, raw_value)
return _tpostprocessor_float
elif column_type == ColumnType.BLOB:
def _tpostprocessor_bytes(raw_value: Any) -> bytes | None:
if raw_value is None:
return None
if isinstance(raw_value, dict):
# {"$binary": ...}
return convert_ejson_binary_object_to_bytes(raw_value)
elif isinstance(raw_value, str):
# within PKSchema, a bare string (e.g. "q83vASNFZ4k=") is encountered
return convert_ejson_binary_object_to_bytes({"$binary": raw_value})
else:
raise ValueError(
f"Unexpected value type encountered for a blob column: {column_type}"
)
return _tpostprocessor_bytes
elif column_type in {ColumnType.UUID, ColumnType.TIMEUUID}:
def _tpostprocessor_uuid(raw_value: Any) -> UUID | None:
if raw_value is None:
return None
return UUID(raw_value)
return _tpostprocessor_uuid
elif column_type == ColumnType.DATE:
if options.custom_datatypes_in_reading:
def _tpostprocessor_date(raw_value: Any) -> DataAPIDate | None:
if raw_value is None:
return None
return DataAPIDate.from_string(raw_value)
return _tpostprocessor_date
else:
def _tpostprocessor_date_stdlib(raw_value: Any) -> datetime.date | None:
if raw_value is None:
return None
return DataAPIDate.from_string(raw_value).to_date()
return _tpostprocessor_date_stdlib
elif column_type == ColumnType.TIME:
if options.custom_datatypes_in_reading:
def _tpostprocessor_time(raw_value: Any) -> DataAPITime | None:
if raw_value is None:
return None
return DataAPITime.from_string(raw_value)
return _tpostprocessor_time
else:
def _tpostprocessor_time_stdlib(raw_value: Any) -> datetime.time | None:
if raw_value is None:
return None
return DataAPITime.from_string(raw_value).to_time()
return _tpostprocessor_time_stdlib
elif column_type == ColumnType.TIMESTAMP:
if options.custom_datatypes_in_reading:
def _tpostprocessor_timestamp(raw_value: Any) -> DataAPITimestamp | None:
if raw_value is None:
return None
return DataAPITimestamp.from_string(raw_value)
return _tpostprocessor_timestamp
else:
def _tpostprocessor_timestamp_stdlib(
raw_value: Any,
) -> datetime.datetime | None:
if raw_value is None:
return None
da_timestamp = DataAPITimestamp.from_string(raw_value)
return da_timestamp.to_datetime(tz=options.datetime_tzinfo)
return _tpostprocessor_timestamp_stdlib
elif column_type == ColumnType.DURATION:
if options.custom_datatypes_in_reading:
def _tpostprocessor_duration(raw_value: Any) -> DataAPIDuration | None:
if raw_value is None:
return None
return DataAPIDuration.from_string(raw_value)
return _tpostprocessor_duration
else:
def _tpostprocessor_duration_stdlib(
raw_value: Any,
) -> datetime.timedelta | None:
if raw_value is None:
return None
return DataAPIDuration.from_string(raw_value).to_timedelta()
return _tpostprocessor_duration_stdlib
elif column_type == ColumnType.INET:
def _tpostprocessor_inet(
raw_value: Any,
) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
if raw_value is None:
return None
return ipaddress.ip_address(raw_value)
return _tpostprocessor_inet
elif column_type == ColumnType.DECIMAL:
def _tpostprocessor_decimal(raw_value: Any) -> decimal.Decimal | None:
if raw_value is None:
return None
elif isinstance(raw_value, decimal.Decimal):
return raw_value
# else: it is "NaN", "-Infinity" or "Infinity"
return decimal.Decimal(f"{raw_value}")
return _tpostprocessor_decimal
else:
raise ValueError(f"Unrecognized scalar type for reads: {column_type}")
def _create_unsupported_tpostprocessor(
col_definition: TableUnsupportedColumnTypeDescriptor,
options: FullSerdesOptions,
) -> Callable[[Any], Any]:
w_msg = (
"An 'UNSUPPORTED' column definition was encountered, unexpectedly, in the "
"schema information accompanying table read results. The values for the "
"column will be returned as the API provides them (full definition: "
f"{str(col_definition.as_dict())})."
)
logger.warning(w_msg)
def _tpostprocessor_unsupported(raw_value: Any) -> Any:
return raw_value
return _tpostprocessor_unsupported
def _create_passthrough_tpostprocessor(
col_definition: TablePassthroughColumnTypeDescriptor,
options: FullSerdesOptions,
) -> Callable[[Any], Any]:
w_msg = (
"The schema information, accompanying table read results, contains a column "
"definition that the client cannot properly parse. The values for the "
"column will be returned as the API provides them (full definition: "
f"{str(col_definition.as_dict())})."
)
logger.warning(w_msg)
def _tpostprocessor_passthrough(raw_value: Any) -> Any:
return raw_value
return _tpostprocessor_passthrough
def _column_filler_value(
col_def: TableColumnTypeDescriptor,
) -> Any:
"""
Prepare a 'filler' for omitted columns. Usually a None, but not always,
e.g. for a list it is [].
This is used before any options-related choice is made. As such, regardless
of the serdes settings, it always uses a representation that could have come
from parsing a JSON.
For example it fills map columns with an empty dict (and never DataAPIMap).
"""
if isinstance(col_def, TableScalarColumnTypeDescriptor):
return None
elif isinstance(col_def, TableVectorColumnTypeDescriptor):
if col_def.column_type == TableVectorColumnType.VECTOR:
return None
else:
raise ValueError(
f"Unrecognized table vector-column descriptor for reads: {col_def.as_dict()}"
)
elif isinstance(col_def, TableValuedColumnTypeDescriptor):
if col_def.column_type == TableValuedColumnType.LIST:
return []
        elif col_def.column_type == TableValuedColumnType.SET:
return set()
else:
raise ValueError(
f"Unrecognized table valued-column descriptor for reads: {col_def.as_dict()}"
)
elif isinstance(col_def, TableKeyValuedColumnTypeDescriptor):
if col_def.column_type == TableKeyValuedColumnType.MAP:
return {}
else:
raise ValueError(
f"Unrecognized table key-valued-column descriptor for reads: {col_def.as_dict()}"
)
elif isinstance(col_def, TableUDTColumnDescriptor):
if col_def.definition is None:
raise ValueError(
"Read-path: received a UDT column schema without 'definition'."
)
filler_dict = {
fld_name: _column_filler_value(fld_def)
for fld_name, fld_def in col_def.definition.fields.items()
}
return filler_dict
elif isinstance(col_def, TableUnsupportedColumnTypeDescriptor):
# For lack of better information, the filler is a None:
return None
elif isinstance(col_def, TablePassthroughColumnTypeDescriptor):
# Given the missing information, must fill with None:
return None
else:
raise ValueError(
f"Unrecognized table column descriptor for reads: {col_def.as_dict()}"
)
def _create_column_tpostprocessor(
col_def: TableColumnTypeDescriptor,
options: FullSerdesOptions,
) -> Callable[[Any], Any]:
if isinstance(col_def, TableScalarColumnTypeDescriptor):
return _create_scalar_tpostprocessor(col_def.column_type, options=options)
elif isinstance(col_def, TableVectorColumnTypeDescriptor):
if col_def.column_type == TableVectorColumnType.VECTOR:
value_tpostprocessor = _create_scalar_tpostprocessor(
ColumnType.FLOAT,
options=options,
)
if options.custom_datatypes_in_reading:
def _tpostprocessor_vector(
raw_items: list[float] | dict[str, str] | None,
) -> DataAPIVector | None:
if raw_items is None:
return None
elif isinstance(raw_items, dict):
# {"$binary": ...}
return DataAPIVector.from_bytes(
convert_ejson_binary_object_to_bytes(raw_items)
)
return DataAPIVector(
[value_tpostprocessor(item) for item in raw_items]
)
return _tpostprocessor_vector
else:
def _tpostprocessor_vector_as_list(
raw_items: list[float] | dict[str, str] | None,
) -> list[float] | None:
if raw_items is None:
return None
elif isinstance(raw_items, dict):
# {"$binary": ...}
return DataAPIVector.from_bytes(
convert_ejson_binary_object_to_bytes(raw_items)
).data
return [value_tpostprocessor(item) for item in raw_items]
return _tpostprocessor_vector_as_list
else:
raise ValueError(
f"Unrecognized table vector-column descriptor for reads: {col_def.as_dict()}"
)
elif isinstance(col_def, TableValuedColumnTypeDescriptor):
if col_def.column_type == TableValuedColumnType.LIST:
value_tpostprocessor = _create_column_tpostprocessor(
col_def.value_type, options=options
)
def _tpostprocessor_list(raw_items: list[Any] | None) -> list[Any] | None:
if raw_items is None:
return None
return [value_tpostprocessor(item) for item in raw_items]
return _tpostprocessor_list
        elif col_def.column_type == TableValuedColumnType.SET:
value_tpostprocessor = _create_column_tpostprocessor(
col_def.value_type, options=options
)
if options.custom_datatypes_in_reading:
def _tpostprocessor_dataapiset(
raw_items: set[Any] | None,
) -> DataAPISet[Any] | None:
if raw_items is None:
return None
return DataAPISet(value_tpostprocessor(item) for item in raw_items)
return _tpostprocessor_dataapiset
else:
def _tpostprocessor_dataapiset_as_set(
raw_items: set[Any] | None,
) -> set[Any] | None:
if raw_items is None:
return None
return {value_tpostprocessor(item) for item in raw_items}
return _tpostprocessor_dataapiset_as_set
else:
raise ValueError(
f"Unrecognized table valued-column descriptor for reads: {col_def.as_dict()}"
)
elif isinstance(col_def, TableKeyValuedColumnTypeDescriptor):
if col_def.column_type == TableKeyValuedColumnType.MAP:
key_tpostprocessor = _create_column_tpostprocessor(
col_def.key_type, options=options
)
value_tpostprocessor = _create_column_tpostprocessor(
col_def.value_type, options=options
)
if options.custom_datatypes_in_reading:
def _tpostprocessor_dataapimap(
raw_items: dict[Any, Any] | list[list[Any]] | None,
) -> DataAPIMap[Any, Any] | None:
if raw_items is None:
return None
if isinstance(raw_items, dict):
return DataAPIMap(
(key_tpostprocessor(k), value_tpostprocessor(v))
for k, v in raw_items.items()
)
# it's a list-of-2tuples
return DataAPIMap(
(key_tpostprocessor(k), value_tpostprocessor(v))
for k, v in raw_items
)
return _tpostprocessor_dataapimap
else:
def _tpostprocessor_dataapimap_as_dict(
raw_items: dict[Any, Any] | list[list[Any]] | None,
) -> dict[Any, Any] | None:
if raw_items is None:
return None
if isinstance(raw_items, dict):
return {
key_tpostprocessor(k): value_tpostprocessor(v)
for k, v in raw_items.items()
}
# it's a list-of-2tuples
return {
key_tpostprocessor(k): value_tpostprocessor(v)
for k, v in raw_items
}
return _tpostprocessor_dataapimap_as_dict
else:
raise ValueError(
f"Unrecognized table key-valued-column descriptor for reads: {col_def.as_dict()}"
)
elif isinstance(col_def, TableUDTColumnDescriptor):
if col_def.definition is None:
raise ValueError(
f"Schema information lacks 'definition' field for {col_def.udt_name}."
)
# first the incoming dictionary must be deserialized in its types,
# then the ("udt-level") deserializer -- custom or default -- is invoked
values_tpostprocessor_map = {
k_fieldname: _create_column_tpostprocessor(k_fieldtype, options=options)
for k_fieldname, k_fieldtype in col_def.definition.fields.items()
}
# if there is a registered deserializer, no matter whether 'use custom dtypes':
deserializer_for_udt = options.deserializer_by_udt.get(col_def.udt_name)
# common to all settings: normalize (null/partials) to deserialized, full dicts:
null_filler_udt = _column_filler_value(col_def)
def _tpostprocessor_udt_baredict(
raw_items: dict[Any, Any] | None,
) -> dict[Any, Any]:
# convert nulls to dicts and apply postprocessor to fields
udt_deserialized_dict = {
k_fieldname: values_tpostprocessor_map[k_fieldname](v_fieldraw)
for k_fieldname, v_fieldraw in (raw_items or {}).items()
}
# complete using fillers for missing fields
return {
**null_filler_udt,
**udt_deserialized_dict,
}
if deserializer_for_udt:
def _tpostprocessor_udt_w_deser(raw_items: dict[Any, Any] | None) -> Any:
# further wrap with the desired configured deserializer
return deserializer_for_udt(
_tpostprocessor_udt_baredict(raw_items),
col_def.definition,
)
return _tpostprocessor_udt_w_deser
elif options.custom_datatypes_in_reading:
def _tpostprocessor_udt_defdeser(raw_items: dict[Any, Any] | None) -> Any:
# further wrap as DataAPIDictUDT
return DataAPIDictUDT(_tpostprocessor_udt_baredict(raw_items))
return _tpostprocessor_udt_defdeser
else:
return _tpostprocessor_udt_baredict
elif isinstance(col_def, TableUnsupportedColumnTypeDescriptor):
        # 'Unsupported' columns (marked as such by the API) should never be
        # returned in reading; they are nevertheless handled here, should they appear.
return _create_unsupported_tpostprocessor(col_def, options=options)
elif isinstance(col_def, TablePassthroughColumnTypeDescriptor):
# 'passthrough' columns (i.e. those whose schema the client cannot parse)
return _create_passthrough_tpostprocessor(col_def, options=options)
else:
raise ValueError(
f"Unrecognized table column descriptor for reads: {col_def.as_dict()}"
)
def create_row_tpostprocessor(
columns: dict[str, TableColumnTypeDescriptor],
options: FullSerdesOptions,
similarity_pseudocolumn: str | None,
) -> Callable[[dict[str, Any]], dict[str, Any]]:
tpostprocessor_map = {
col_name: _create_column_tpostprocessor(col_definition, options=options)
for col_name, col_definition in columns.items()
}
tfiller_map = {
col_name: _column_filler_value(col_definition)
for col_name, col_definition in columns.items()
}
if similarity_pseudocolumn is not None:
        # whatever is in the passed schema, requiring similarity overrides that 'column':
tpostprocessor_map[similarity_pseudocolumn] = _create_scalar_tpostprocessor(
column_type=ColumnType.FLOAT, options=options
)
tfiller_map[similarity_pseudocolumn] = None
column_name_set = set(tpostprocessor_map.keys())
def _tpostprocessor(raw_dict: dict[str, Any]) -> dict[str, Any]:
extra_fields = set(raw_dict.keys()) - column_name_set
if extra_fields:
xf_desc = ", ".join(f'"{f}"' for f in sorted(extra_fields))
raise ValueError(f"Returned row has unexpected fields: {xf_desc}")
return {
col_name: (
# making a copy here, since the user may mutate e.g. a map:
tpostprocessor(copy.copy(tfiller_map[col_name]))
if col_name not in raw_dict
else tpostprocessor(raw_dict[col_name])
)
for col_name, tpostprocessor in tpostprocessor_map.items()
}
return _tpostprocessor
def create_key_ktpostprocessor(
primary_key_schema: dict[str, TableColumnTypeDescriptor],
options: FullSerdesOptions,
) -> Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]:
ktpostprocessor_list: list[tuple[str, Callable[[Any], Any]]] = [
(col_name, _create_column_tpostprocessor(col_definition, options=options))
for col_name, col_definition in primary_key_schema.items()
]
def _ktpostprocessor(
primary_key_list: list[Any],
) -> tuple[tuple[Any, ...], dict[str, Any]]:
if len(primary_key_list) != len(ktpostprocessor_list):
raise ValueError(
"Primary key list length / schema mismatch "
f"(expected {len(ktpostprocessor_list)}, "
f"received {len(primary_key_list)} fields)"
)
k_tuple = tuple(
[
ktpostprocessor(pk_col_value)
for pk_col_value, (_, ktpostprocessor) in zip(
primary_key_list,
ktpostprocessor_list,
)
]
)
k_dict = {
pk_col_name: pk_processed_value
for pk_processed_value, (pk_col_name, _) in zip(
k_tuple,
ktpostprocessor_list,
)
}
return k_tuple, k_dict
return _ktpostprocessor
def preprocess_table_payload_value(
path: list[str],
value: Any,
options: FullSerdesOptions,
map2tuple_checker: Callable[[list[str]], bool] | None,
) -> Any:
"""
    Walk a payload for Tables and apply the conversions required
    to make it into a ready-to-jsondumps object.
"""
# The check for UDT dict-wrapper must come before the "plain dict" check
if isinstance(value, DataAPIDictUDT):
# field-wise serialize and return as (JSON-ready) map:
udt_dict = dict(value)
return {
udt_k: preprocess_table_payload_value(
path + [udt_k],
udt_v,
options=options,
map2tuple_checker=map2tuple_checker,
)
for udt_k, udt_v in udt_dict.items()
}
elif isinstance(value, (dict, DataAPIMap)):
# This is a nesting structure (but not the dict-wrapper for UDTs)
maps_can_become_tuples: bool
if options.encode_maps_as_lists_in_tables == MapEncodingMode.NEVER:
maps_can_become_tuples = False
elif options.encode_maps_as_lists_in_tables == MapEncodingMode.DATAAPIMAPS:
maps_can_become_tuples = isinstance(value, DataAPIMap)
else:
# 'ALWAYS' setting
maps_can_become_tuples = True
maps_become_tuples: bool
if maps_can_become_tuples:
if map2tuple_checker is None:
maps_become_tuples = False
else:
maps_become_tuples = map2tuple_checker(path)
else:
maps_become_tuples = False
# empty maps must always be encoded as `{}`, never as `[]` (#2005)
if maps_become_tuples and value:
return [
[
preprocess_table_payload_value(
path,
k,
options=options,
map2tuple_checker=map2tuple_checker,
),
preprocess_table_payload_value(
path + [k],
v,
options=options,
map2tuple_checker=map2tuple_checker,
),
]
for k, v in value.items()
]
return {
preprocess_table_payload_value(
path, k, options=options, map2tuple_checker=map2tuple_checker
): preprocess_table_payload_value(
path + [k], v, options=options, map2tuple_checker=map2tuple_checker
)
for k, v in value.items()
}
elif isinstance(value, (list, set, DataAPISet)):
return [
preprocess_table_payload_value(
path + [""], v, options=options, map2tuple_checker=map2tuple_checker
)
for v in value
]
# it's a scalar of some kind (which includes DataAPIVector)
if isinstance(value, float):
# Non-numbers must be manually made into a string
if math.isnan(value):
return NAN_FLOAT_STRING_REPRESENTATION
elif math.isinf(value):
if value > 0:
return PLUS_INFINITY_FLOAT_STRING_REPRESENTATION
else:
return MINUS_INFINITY_FLOAT_STRING_REPRESENTATION
return value
elif isinstance(value, bytes):
return convert_to_ejson_bytes(value)
elif isinstance(value, DataAPIVector):
if options.binary_encode_vectors:
return convert_to_ejson_bytes(value.to_bytes())
else:
# regular list of floats - which can contain non-numbers:
return [
preprocess_table_payload_value(
path + [""],
fval,
options=options,
map2tuple_checker=map2tuple_checker,
)
for fval in value.data
]
elif isinstance(value, DataAPITimestamp):
return value.to_string()
elif isinstance(value, DataAPIDate):
return value.to_string()
elif isinstance(value, DataAPITime):
return value.to_string()
elif isinstance(value, datetime.datetime):
# encoding in two steps (that's because the '%:z' strftime directive
# is not in all supported Python versions).
offset_tuple = _get_datetime_offset(value)
if offset_tuple is None:
if options.accept_naive_datetimes:
return DataAPITimestamp(int(value.timestamp() * 1000)).to_string()
raise ValueError(CANNOT_ENCODE_NAIVE_DATETIME_ERROR_MESSAGE)
date_part_str = value.strftime(DATETIME_DATETIME_FORMAT)
offset_h, offset_m = offset_tuple
offset_part_str = f"{offset_h:+03}:{offset_m:02}"
return f"{date_part_str}{offset_part_str}"
elif isinstance(value, datetime.date):
# there's no format to specify - and this is compliant anyway:
return value.strftime(DATETIME_DATE_FORMAT)
elif isinstance(value, datetime.time):
return value.strftime(DATETIME_TIME_FORMAT)
elif isinstance(value, decimal.Decimal):
# Non-numbers must be manually made into a string, just like floats
if math.isnan(value):
return NAN_FLOAT_STRING_REPRESENTATION
elif math.isinf(value):
if value > 0:
return PLUS_INFINITY_FLOAT_STRING_REPRESENTATION
else:
return MINUS_INFINITY_FLOAT_STRING_REPRESENTATION
# actually-numeric decimals: leave them as they are for the encoding step,
# which will apply the nasty trick to ensure all digits get there.
return value
elif isinstance(value, DataAPIDuration):
# using to_c_string over to_string until the ISO-format parsing can
# cope with subsecond fractions:
return value.to_c_string()
elif isinstance(value, UUID):
return str(value)
elif isinstance(value, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
return str(value)
elif isinstance(value, datetime.timedelta):
return DataAPIDuration.from_timedelta(value).to_c_string()
elif isinstance(value, ObjectId):
raise ValueError(
"Values of type ObjectId are not supported. Consider switching to "
"using UUID-based identifiers instead."
)
# try to unroll if applicable and then preprocess known types:
_uvalue: Any
if options.unroll_iterables_to_lists:
_uvalue = ensure_unrolled_if_iterable(value)
else:
_uvalue = value
    # process the (possibly unrolled) value:
if isinstance(_uvalue, list):
return [
preprocess_table_payload_value(
path + [""], v, options=options, map2tuple_checker=map2tuple_checker
)
for v in _uvalue
]
# is it a well-known, natively-JSON-serializable type:
if isinstance(_uvalue, (str, int, float, bool, type(None))):
return _uvalue
# check whether instance of a class with a registered serializer:
for k_cls, k_serializer in options.serializer_by_class.items():
if isinstance(_uvalue, k_cls) and k_serializer is not None:
udt_dict_form = k_serializer(_uvalue)
return {
udt_k: preprocess_table_payload_value(
path + [udt_k],
udt_v,
options=options,
map2tuple_checker=map2tuple_checker,
)
for udt_k, udt_v in udt_dict_form.items()
}
    # this is a last-ditch attempt; it likely results in a "not JSON serializable" error:
return _uvalue
def preprocess_table_payload(
payload: dict[str, Any] | None,
options: FullSerdesOptions,
map2tuple_checker: Callable[[list[str]], bool] | None,
) -> dict[str, Any] | None:
"""
Normalize a payload for API calls.
    This includes e.g. ensuring values for the "$vector" key
are made into plain lists of floats.
Args:
payload (dict[str, Any]): A dict expressing a payload for an API call
options: a FullSerdesOptions setting the preprocessing configuration
map2tuple_checker: a boolean function of a path in the doc, that returns
True for "doc-like" portions of a payload, i.e. whose maps/DataAPIMaps
can be converted into association lists, if such autoconversion is
turned on. If this parameter is None, no paths are autoconverted.
Returns:
dict[str, Any]: a payload dict, pre-processed, ready for HTTP requests.
"""
if payload:
return cast(
Dict[str, Any],
preprocess_table_payload_value(
[],
payload,
options=options,
map2tuple_checker=map2tuple_checker,
),
)
else:
return payload
class _DecimalCleaner(json.JSONEncoder):
"""
    This class cleans decimals (coming from decimal-oriented parsing of responses)
    so that the schema can be made into a string, hashed, and safely used as a key
    to the converters cache.
"""
def default(self, obj: object) -> Any:
if isinstance(obj, decimal.Decimal):
return float(obj)
return super().default(obj)
class _TableConverterAgent(Generic[ROW]):
options: FullSerdesOptions
row_postprocessors: dict[
tuple[str, str | None], Callable[[dict[str, Any]], dict[str, Any]]
]
key_postprocessors: dict[
str, Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]
]
def __init__(self, *, options: FullSerdesOptions) -> None:
self.options = options
self.row_postprocessors = {}
self.key_postprocessors = {}
@staticmethod
def _hash_dict(input_dict: dict[str, Any]) -> str:
return hashlib.md5(
json.dumps(
input_dict,
sort_keys=True,
separators=(",", ":"),
cls=_DecimalCleaner,
).encode()
).hexdigest()
def _get_key_postprocessor(
self, primary_key_schema_dict: dict[str, Any]
) -> Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]:
schema_hash = self._hash_dict(primary_key_schema_dict)
if schema_hash not in self.key_postprocessors:
primary_key_schema: dict[str, TableColumnTypeDescriptor] = {
col_name: TableColumnTypeDescriptor.coerce(col_dict)
for col_name, col_dict in primary_key_schema_dict.items()
}
self.key_postprocessors[schema_hash] = create_key_ktpostprocessor(
primary_key_schema=primary_key_schema,
options=self.options,
)
return self.key_postprocessors[schema_hash]
def _get_row_postprocessor(
self,
columns_dict: dict[str, Any],
similarity_pseudocolumn: str | None,
) -> Callable[[dict[str, Any]], dict[str, Any]]:
schema_cache_key = (self._hash_dict(columns_dict), similarity_pseudocolumn)
if schema_cache_key not in self.row_postprocessors:
columns: dict[str, TableColumnTypeDescriptor] = {
col_name: TableColumnTypeDescriptor.coerce(col_dict)
for col_name, col_dict in columns_dict.items()
}
self.row_postprocessors[schema_cache_key] = create_row_tpostprocessor(
columns=columns,
options=self.options,
similarity_pseudocolumn=similarity_pseudocolumn,
)
return self.row_postprocessors[schema_cache_key]
def preprocess_payload(
self,
payload: dict[str, Any] | None,
map2tuple_checker: Callable[[list[str]], bool] | None,
) -> dict[str, Any] | None:
return preprocess_table_payload(
payload,
options=self.options,
map2tuple_checker=map2tuple_checker,
)
def postprocess_key(
self, primary_key_list: list[Any], *, primary_key_schema_dict: dict[str, Any]
) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""
The primary key schema is not coerced here, just parsed from its json
"""
return self._get_key_postprocessor(
primary_key_schema_dict=primary_key_schema_dict
)(primary_key_list)
def postprocess_keys(
self,
primary_key_lists: list[list[Any]],
*,
primary_key_schema_dict: dict[str, Any],
) -> list[tuple[tuple[Any, ...], dict[str, Any]]]:
"""
The primary key schema is not coerced here, just parsed from its json
"""
if primary_key_lists:
_k_postprocessor = self._get_key_postprocessor(
primary_key_schema_dict=primary_key_schema_dict
)
return [
_k_postprocessor(primary_key_list)
for primary_key_list in primary_key_lists
]
else:
return []
def postprocess_row(
self,
raw_dict: dict[str, Any],
*,
columns_dict: dict[str, Any],
similarity_pseudocolumn: str | None,
) -> ROW:
"""
The columns schema is not coerced here, just parsed from its json
"""
return self._get_row_postprocessor(
columns_dict=columns_dict, similarity_pseudocolumn=similarity_pseudocolumn
)(raw_dict) # type: ignore[return-value]
def postprocess_rows(
self,
raw_dicts: list[dict[str, Any]],
*,
columns_dict: dict[str, Any],
similarity_pseudocolumn: str | None,
) -> list[ROW]:
"""
The columns schema is not coerced here, just parsed from its json
"""
if raw_dicts:
_r_postprocessor = self._get_row_postprocessor(
columns_dict=columns_dict,
similarity_pseudocolumn=similarity_pseudocolumn,
)
return [cast(ROW, _r_postprocessor(raw_dict)) for raw_dict in raw_dicts]
else:
return []
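For orientation, a minimal usage sketch of the converter machinery above (a hedged illustration: `serdes_options` is assumed to be an already-available FullSerdesOptions instance, and the column schema shown is hypothetical):

agent = _TableConverterAgent(options=serdes_options)
# write path: make a command payload JSON-ready before issuing the HTTP request
wire_payload = agent.preprocess_payload(
    {"findOne": {"filter": {"id": 101}}},
    map2tuple_checker=None,
)
# read path: convert a raw response row into typed values; the row postprocessor
# is cached, keyed on a hash of the (JSON-form) column schema
typed_row = agent.postprocess_row(
    {"id": 101, "when": "2024-01-31T12:34:56.000Z"},
    columns_dict={"id": {"type": "int"}, "when": {"type": "timestamp"}},
    similarity_pseudocolumn=None,
)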
Functions
def create_key_ktpostprocessor(primary_key_schema: dict[str, TableColumnTypeDescriptor], options: FullSerdesOptions) ‑> Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]
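A hedged usage sketch (the primary-key schema is hypothetical and `serdes_options` is assumed to be a FullSerdesOptions instance obtained elsewhere):

from astrapy.data.info.table_descriptor.table_columns import TableColumnTypeDescriptor
from astrapy.data.utils.table_converters import create_key_ktpostprocessor

# a two-column primary key schema, coerced from its API (dict) form:
pk_schema = {
    "part": TableColumnTypeDescriptor.coerce({"type": "text"}),
    "seq": TableColumnTypeDescriptor.coerce({"type": "int"}),
}
ktpostprocessor = create_key_ktpostprocessor(
    primary_key_schema=pk_schema,
    options=serdes_options,
)
# the API returns primary keys as positional lists; both forms are produced:
k_tuple, k_dict = ktpostprocessor(["chapter-1", 3])
# k_tuple == ("chapter-1", 3); k_dict == {"part": "chapter-1", "seq": 3}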
def create_row_tpostprocessor(columns: dict[str, TableColumnTypeDescriptor], options: FullSerdesOptions, similarity_pseudocolumn: str | None) ‑> Callable[[dict[str, Any]], dict[str, Any]]
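Similarly, a hedged sketch of row postprocessing (hypothetical column schema, `serdes_options` assumed available):

from astrapy.data.info.table_descriptor.table_columns import TableColumnTypeDescriptor
from astrapy.data.utils.table_converters import create_row_tpostprocessor

columns = {
    "id": TableColumnTypeDescriptor.coerce({"type": "text"}),
    "when": TableColumnTypeDescriptor.coerce({"type": "timestamp"}),
}
row_postprocessor = create_row_tpostprocessor(
    columns=columns,
    options=serdes_options,
    similarity_pseudocolumn=None,
)
typed_row = row_postprocessor({"id": "a1", "when": "2024-01-31T12:34:56.000Z"})
# "when" becomes a DataAPITimestamp (or a timezone-aware datetime.datetime if
# custom_datatypes_in_reading is disabled); columns omitted from the raw row
# are filled in with their per-type fillers (None, [], {}, ...).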
def preprocess_table_payload(payload: dict[str, Any] | None, options: FullSerdesOptions, map2tuple_checker: Callable[[list[str]], bool] | None) ‑> dict[str, typing.Any] | None
- Normalize a payload for API calls. This includes e.g. ensuring values for the "$vector" key are made into plain lists of floats.
Args
payload (dict[str, Any])
- A dict expressing a payload for an API call
options
- a FullSerdesOptions setting the preprocessing configuration
map2tuple_checker
- a boolean function of a path in the doc, that returns True for "doc-like" portions of a payload, i.e. whose maps/DataAPIMaps can be converted into association lists, if such autoconversion is turned on. If this parameter is None, no paths are autoconverted.
Returns
dict[str, Any]
- a payload dict, pre-processed, ready for HTTP requests.
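A hedged usage sketch (the command shape is hypothetical and `serdes_options` is assumed to be a FullSerdesOptions instance):

import datetime

from astrapy.data.utils.table_converters import preprocess_table_payload

raw_payload = {
    "insertOne": {
        "document": {
            "name": "reading",
            "taken_at": datetime.datetime(
                2024, 1, 31, 12, 0, tzinfo=datetime.timezone.utc
            ),
            "raw_bytes": b"\x01\x02",
        }
    }
}
json_ready = preprocess_table_payload(
    raw_payload,
    options=serdes_options,
    map2tuple_checker=None,  # None: no map-to-association-list conversion anywhere
)
# timezone-aware datetimes become offset-suffixed strings and bytes become
# {"$binary": ...} objects, so json.dumps(json_ready) now succeeds.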
def preprocess_table_payload_value(path: list[str], value: Any, options: FullSerdesOptions, map2tuple_checker: Callable[[list[str]], bool] | None) ‑> Any
- Walk a payload for Tables and apply the conversions required to make it into a ready-to-jsondumps object.
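A few single-value sketches of the scalar conversions performed above (again assuming a FullSerdesOptions instance `serdes_options`):

from astrapy.data.utils.table_converters import preprocess_table_payload_value
from astrapy.ids import UUID

# non-finite floats become the string representations the API expects:
preprocess_table_payload_value([], float("nan"), options=serdes_options, map2tuple_checker=None)
# -> "NaN"
# bytes become their extended-JSON binary form:
preprocess_table_payload_value([], b"\xff", options=serdes_options, map2tuple_checker=None)
# -> {"$binary": "..."} (base64-encoded content)
# UUIDs are stringified:
preprocess_table_payload_value(
    [], UUID("00000000-0000-0000-0000-000000000000"),
    options=serdes_options, map2tuple_checker=None,
)
# -> "00000000-0000-0000-0000-000000000000"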