Module astrapy.data.utils.table_converters
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import datetime
import decimal
import hashlib
import ipaddress
import json
import math
from typing import Any, Callable, Generic, cast
from astrapy.constants import ROW
from astrapy.data.info.table_descriptor.table_columns import (
TableColumnTypeDescriptor,
TableKeyValuedColumnTypeDescriptor,
TableScalarColumnTypeDescriptor,
TableUnsupportedColumnTypeDescriptor,
TableValuedColumnTypeDescriptor,
TableVectorColumnTypeDescriptor,
)
from astrapy.data.utils.extended_json_converters import (
convert_ejson_binary_object_to_bytes,
convert_to_ejson_bytes,
)
from astrapy.data.utils.table_types import (
ColumnType,
TableKeyValuedColumnType,
TableUnsupportedColumnType,
TableValuedColumnType,
TableVectorColumnType,
)
from astrapy.data.utils.vector_coercion import ensure_unrolled_if_iterable
from astrapy.data_types import (
DataAPIDate,
DataAPIDuration,
DataAPIMap,
DataAPISet,
DataAPITime,
DataAPITimestamp,
DataAPIVector,
)
from astrapy.ids import UUID, ObjectId
from astrapy.settings.error_messages import CANNOT_ENCODE_NAIVE_DATETIME_ERROR_MESSAGE
from astrapy.utils.api_options import FullSerdesOptions
from astrapy.utils.date_utils import _get_datetime_offset
NAN_FLOAT_STRING_REPRESENTATION = "NaN"
PLUS_INFINITY_FLOAT_STRING_REPRESENTATION = "Infinity"
MINUS_INFINITY_FLOAT_STRING_REPRESENTATION = "-Infinity"
DATETIME_TIME_FORMAT = "%H:%M:%S.%f"
DATETIME_DATE_FORMAT = "%Y-%m-%d"
DATETIME_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f"
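# Factories below build per-column "postprocessors": callables that convert a
# raw, JSON-parsed value returned by the Data API into its client-side
# representation, honoring the configured serdes options.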
def _create_scalar_tpostprocessor(
column_type: ColumnType,
options: FullSerdesOptions,
) -> Callable[[Any], Any]:
if column_type in {
ColumnType.TEXT,
ColumnType.ASCII,
}:
def _tpostprocessor_text(raw_value: Any) -> str | None:
return raw_value # type: ignore[no-any-return]
return _tpostprocessor_text
elif column_type == ColumnType.BOOLEAN:
def _tpostprocessor_bool(raw_value: Any) -> bool | None:
return raw_value # type: ignore[no-any-return]
return _tpostprocessor_bool
elif column_type in {
ColumnType.INT,
ColumnType.VARINT,
ColumnType.BIGINT,
ColumnType.SMALLINT,
ColumnType.TINYINT,
}:
def _tpostprocessor_int(raw_value: Any) -> int | None:
if raw_value is None:
return None
# the 'int(...)' also handles Decimal values
return int(raw_value)
return _tpostprocessor_int
elif column_type in {
ColumnType.FLOAT,
ColumnType.DOUBLE,
}:
def _tpostprocessor_float(raw_value: Any) -> float | None:
if raw_value is None:
return None
elif isinstance(raw_value, (str, decimal.Decimal)):
return float(raw_value)
# just a float already
return cast(float, raw_value)
return _tpostprocessor_float
elif column_type == ColumnType.BLOB:
def _tpostprocessor_bytes(raw_value: Any) -> bytes | None:
if raw_value is None:
return None
if isinstance(raw_value, dict):
# {"$binary": ...}
return convert_ejson_binary_object_to_bytes(raw_value)
elif isinstance(raw_value, str):
# within PKSchema, a bare string (e.g. "q83vASNFZ4k=") is encountered
return convert_ejson_binary_object_to_bytes({"$binary": raw_value})
else:
raise ValueError(
f"Unexpected value type encountered for a blob column: {column_type}"
)
return _tpostprocessor_bytes
elif column_type == ColumnType.UUID:
def _tpostprocessor_uuid(raw_value: Any) -> UUID | None:
if raw_value is None:
return None
return UUID(raw_value)
return _tpostprocessor_uuid
elif column_type == ColumnType.DATE:
if options.custom_datatypes_in_reading:
def _tpostprocessor_date(raw_value: Any) -> DataAPIDate | None:
if raw_value is None:
return None
return DataAPIDate.from_string(raw_value)
return _tpostprocessor_date
else:
def _tpostprocessor_date_stdlib(raw_value: Any) -> datetime.date | None:
if raw_value is None:
return None
return DataAPIDate.from_string(raw_value).to_date()
return _tpostprocessor_date_stdlib
elif column_type == ColumnType.TIME:
if options.custom_datatypes_in_reading:
def _tpostprocessor_time(raw_value: Any) -> DataAPITime | None:
if raw_value is None:
return None
return DataAPITime.from_string(raw_value)
return _tpostprocessor_time
else:
def _tpostprocessor_time_stdlib(raw_value: Any) -> datetime.time | None:
if raw_value is None:
return None
return DataAPITime.from_string(raw_value).to_time()
return _tpostprocessor_time_stdlib
elif column_type == ColumnType.TIMESTAMP:
if options.custom_datatypes_in_reading:
def _tpostprocessor_timestamp(raw_value: Any) -> DataAPITimestamp | None:
if raw_value is None:
return None
return DataAPITimestamp.from_string(raw_value)
return _tpostprocessor_timestamp
else:
def _tpostprocessor_timestamp_stdlib(
raw_value: Any,
) -> datetime.datetime | None:
if raw_value is None:
return None
da_timestamp = DataAPITimestamp.from_string(raw_value)
return da_timestamp.to_datetime(tz=options.datetime_tzinfo)
return _tpostprocessor_timestamp_stdlib
elif column_type == ColumnType.DURATION:
if options.custom_datatypes_in_reading:
def _tpostprocessor_duration(raw_value: Any) -> DataAPIDuration | None:
if raw_value is None:
return None
return DataAPIDuration.from_string(raw_value)
return _tpostprocessor_duration
else:
def _tpostprocessor_duration_stdlib(
raw_value: Any,
) -> datetime.timedelta | None:
if raw_value is None:
return None
return DataAPIDuration.from_string(raw_value).to_timedelta()
return _tpostprocessor_duration_stdlib
elif column_type == ColumnType.INET:
def _tpostprocessor_inet(
raw_value: Any,
) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
if raw_value is None:
return None
return ipaddress.ip_address(raw_value)
return _tpostprocessor_inet
elif column_type == ColumnType.DECIMAL:
def _tpostprocessor_decimal(raw_value: Any) -> decimal.Decimal | None:
if raw_value is None:
return None
elif isinstance(raw_value, decimal.Decimal):
return raw_value
# else: it is "NaN", "-Infinity" or "Infinity"
return decimal.Decimal(f"{raw_value}")
return _tpostprocessor_decimal
else:
raise ValueError(f"Unrecognized scalar type for reads: {column_type}")
def _create_unsupported_tpostprocessor(
cql_definition: str,
options: FullSerdesOptions,
) -> Callable[[Any], Any]:
if cql_definition == "counter":
return _create_scalar_tpostprocessor(
column_type=ColumnType.INT, options=options
)
elif cql_definition == "varchar":
return _create_scalar_tpostprocessor(
column_type=ColumnType.TEXT, options=options
)
elif cql_definition == "timeuuid":
return _create_scalar_tpostprocessor(
column_type=ColumnType.UUID, options=options
)
else:
raise ValueError(
f"Unrecognized table unsupported-column cqlDefinition for reads: {cql_definition}"
)
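# Filler values replace columns that appear in the table schema but are absent
# from a returned row: None for scalars, an empty collection otherwise.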
def _column_filler_value(
col_def: TableColumnTypeDescriptor,
options: FullSerdesOptions,
) -> Any:
if isinstance(col_def, TableScalarColumnTypeDescriptor):
return None
elif isinstance(col_def, TableVectorColumnTypeDescriptor):
if col_def.column_type == TableVectorColumnType.VECTOR:
return None
else:
raise ValueError(
f"Unrecognized table vector-column descriptor for reads: {col_def.as_dict()}"
)
elif isinstance(col_def, TableValuedColumnTypeDescriptor):
if col_def.column_type == TableValuedColumnType.LIST:
return []
elif col_def.column_type == TableValuedColumnType.SET:
if options.custom_datatypes_in_reading:
return DataAPISet()
else:
return set()
else:
raise ValueError(
f"Unrecognized table valued-column descriptor for reads: {col_def.as_dict()}"
)
elif isinstance(col_def, TableKeyValuedColumnTypeDescriptor):
if col_def.column_type == TableKeyValuedColumnType.MAP:
if options.custom_datatypes_in_reading:
return DataAPIMap()
else:
return {}
else:
raise ValueError(
f"Unrecognized table key-valued-column descriptor for reads: {col_def.as_dict()}"
)
elif isinstance(col_def, TableUnsupportedColumnTypeDescriptor):
# For lack of better information,
# the filler for unreported unsupported columns is None:
return None
else:
raise ValueError(
f"Unrecognized table column descriptor for reads: {col_def.as_dict()}"
)
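# Dispatch on the column descriptor class (scalar, vector, valued, key-valued,
# unsupported) and build the matching per-column postprocessor.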
def _create_column_tpostprocessor(
col_def: TableColumnTypeDescriptor,
options: FullSerdesOptions,
) -> Callable[[Any], Any]:
if isinstance(col_def, TableScalarColumnTypeDescriptor):
return _create_scalar_tpostprocessor(col_def.column_type, options=options)
elif isinstance(col_def, TableVectorColumnTypeDescriptor):
if col_def.column_type == TableVectorColumnType.VECTOR:
value_tpostprocessor = _create_scalar_tpostprocessor(
ColumnType.FLOAT,
options=options,
)
if options.custom_datatypes_in_reading:
def _tpostprocessor_vector(
raw_items: list[float] | dict[str, str] | None,
) -> DataAPIVector | None:
if raw_items is None:
return None
elif isinstance(raw_items, dict):
# {"$binary": ...}
return DataAPIVector.from_bytes(
convert_ejson_binary_object_to_bytes(raw_items)
)
return DataAPIVector(
[value_tpostprocessor(item) for item in raw_items]
)
return _tpostprocessor_vector
else:
def _tpostprocessor_vector_as_list(
raw_items: list[float] | dict[str, str] | None,
) -> list[float] | None:
if raw_items is None:
return None
elif isinstance(raw_items, dict):
# {"$binary": ...}
return DataAPIVector.from_bytes(
convert_ejson_binary_object_to_bytes(raw_items)
).data
return [value_tpostprocessor(item) for item in raw_items]
return _tpostprocessor_vector_as_list
else:
raise ValueError(
f"Unrecognized table vector-column descriptor for reads: {col_def.as_dict()}"
)
elif isinstance(col_def, TableValuedColumnTypeDescriptor):
if col_def.column_type == TableValuedColumnType.LIST:
value_tpostprocessor = _create_scalar_tpostprocessor(
col_def.value_type, options=options
)
def _tpostprocessor_list(raw_items: list[Any] | None) -> list[Any] | None:
if raw_items is None:
return None
return [value_tpostprocessor(item) for item in raw_items]
return _tpostprocessor_list
elif col_def.column_type == TableValuedColumnType.SET:
value_tpostprocessor = _create_scalar_tpostprocessor(
col_def.value_type, options=options
)
if options.custom_datatypes_in_reading:
def _tpostprocessor_dataapiset(
raw_items: set[Any] | None,
) -> DataAPISet[Any] | None:
if raw_items is None:
return None
return DataAPISet(value_tpostprocessor(item) for item in raw_items)
return _tpostprocessor_dataapiset
else:
def _tpostprocessor_dataapiset_as_set(
raw_items: set[Any] | None,
) -> set[Any] | None:
if raw_items is None:
return None
return {value_tpostprocessor(item) for item in raw_items}
return _tpostprocessor_dataapiset_as_set
else:
raise ValueError(
f"Unrecognized table valued-column descriptor for reads: {col_def.as_dict()}"
)
elif isinstance(col_def, TableKeyValuedColumnTypeDescriptor):
if col_def.column_type == TableKeyValuedColumnType.MAP:
key_tpostprocessor = _create_scalar_tpostprocessor(
col_def.key_type, options=options
)
value_tpostprocessor = _create_scalar_tpostprocessor(
col_def.value_type, options=options
)
if options.custom_datatypes_in_reading:
def _tpostprocessor_dataapimap(
raw_items: dict[Any, Any] | None,
) -> DataAPIMap[Any, Any] | None:
if raw_items is None:
return None
return DataAPIMap(
(key_tpostprocessor(k), value_tpostprocessor(v))
for k, v in raw_items.items()
)
return _tpostprocessor_dataapimap
else:
def _tpostprocessor_dataapimap_as_dict(
raw_items: dict[Any, Any] | None,
) -> dict[Any, Any] | None:
if raw_items is None:
return None
return {
key_tpostprocessor(k): value_tpostprocessor(v)
for k, v in raw_items.items()
}
return _tpostprocessor_dataapimap_as_dict
else:
raise ValueError(
f"Unrecognized table key-valued-column descriptor for reads: {col_def.as_dict()}"
)
elif isinstance(col_def, TableUnsupportedColumnTypeDescriptor):
if col_def.column_type == TableUnsupportedColumnType.UNSUPPORTED:
# for UNSUPPORTED columns, the effective type is taken from the CQL definition in apiSupport:
return _create_unsupported_tpostprocessor(
cql_definition=col_def.api_support.cql_definition,
options=options,
)
else:
raise ValueError(
f"Unrecognized table unsupported-column descriptor for reads: {col_def.as_dict()}"
)
else:
raise ValueError(
f"Unrecognized table column descriptor for reads: {col_def.as_dict()}"
)
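# Row-level converter: applies the per-column postprocessors to a raw row dict,
# rejecting unexpected fields and filling in columns missing from the response.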
def create_row_tpostprocessor(
columns: dict[str, TableColumnTypeDescriptor],
options: FullSerdesOptions,
similarity_pseudocolumn: str | None,
) -> Callable[[dict[str, Any]], dict[str, Any]]:
tpostprocessor_map = {
col_name: _create_column_tpostprocessor(col_definition, options=options)
for col_name, col_definition in columns.items()
}
tfiller_map = {
col_name: _column_filler_value(col_definition, options=options)
for col_name, col_definition in columns.items()
}
if similarity_pseudocolumn is not None:
# whatever the passed schema says, requesting similarity overrides that 'column':
tpostprocessor_map[similarity_pseudocolumn] = _create_scalar_tpostprocessor(
column_type=ColumnType.FLOAT, options=options
)
tfiller_map[similarity_pseudocolumn] = None
column_name_set = set(tpostprocessor_map.keys())
def _tpostprocessor(raw_dict: dict[str, Any]) -> dict[str, Any]:
extra_fields = set(raw_dict.keys()) - column_name_set
if extra_fields:
xf_desc = ", ".join(f'"{f}"' for f in sorted(extra_fields))
raise ValueError(f"Returned row has unexpected fields: {xf_desc}")
return {
col_name: (
# making a copy here, since the user may mutate e.g. a map:
copy.copy(tfiller_map[col_name])
if col_name not in raw_dict
else tpostprocessor(raw_dict[col_name])
)
for col_name, tpostprocessor in tpostprocessor_map.items()
}
return _tpostprocessor
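# Primary keys come back from the API as positional lists of values; the
# converter built here returns both a tuple and a column-name-keyed dict.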
def create_key_ktpostprocessor(
primary_key_schema: dict[str, TableColumnTypeDescriptor],
options: FullSerdesOptions,
) -> Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]:
ktpostprocessor_list: list[tuple[str, Callable[[Any], Any]]] = [
(col_name, _create_column_tpostprocessor(col_definition, options=options))
for col_name, col_definition in primary_key_schema.items()
]
def _ktpostprocessor(
primary_key_list: list[Any],
) -> tuple[tuple[Any, ...], dict[str, Any]]:
if len(primary_key_list) != len(ktpostprocessor_list):
raise ValueError(
"Primary key list length / schema mismatch "
f"(expected {len(ktpostprocessor_list)}, "
f"received {len(primary_key_list)} fields)"
)
k_tuple = tuple(
[
ktpostprocessor(pk_col_value)
for pk_col_value, (_, ktpostprocessor) in zip(
primary_key_list,
ktpostprocessor_list,
)
]
)
k_dict = {
pk_col_name: pk_processed_value
for pk_processed_value, (pk_col_name, _) in zip(
k_tuple,
ktpostprocessor_list,
)
}
return k_tuple, k_dict
return _ktpostprocessor
def preprocess_table_payload_value(
path: list[str], value: Any, options: FullSerdesOptions
) -> Any:
"""
Walk a payload for Tables and apply the necessary and required conversions
to make it into a ready-to-jsondumps object.
"""
# is this a nesting structure?
if isinstance(value, (dict, DataAPIMap)):
return {
preprocess_table_payload_value(
path, k, options=options
): preprocess_table_payload_value(path + [k], v, options=options)
for k, v in value.items()
}
elif isinstance(value, (list, set, DataAPISet)):
return [
preprocess_table_payload_value(path + [""], v, options=options)
for v in value
]
# it's a scalar of some kind (which includes DataAPIVector)
if isinstance(value, float):
# Non-numbers must be manually made into a string
if math.isnan(value):
return NAN_FLOAT_STRING_REPRESENTATION
elif math.isinf(value):
if value > 0:
return PLUS_INFINITY_FLOAT_STRING_REPRESENTATION
else:
return MINUS_INFINITY_FLOAT_STRING_REPRESENTATION
return value
elif isinstance(value, bytes):
return convert_to_ejson_bytes(value)
elif isinstance(value, DataAPIVector):
if options.binary_encode_vectors:
return convert_to_ejson_bytes(value.to_bytes())
else:
# regular list of floats - which can contain non-numbers:
return [
preprocess_table_payload_value(path + [""], fval, options=options)
for fval in value.data
]
elif isinstance(value, DataAPITimestamp):
return value.to_string()
elif isinstance(value, DataAPIDate):
return value.to_string()
elif isinstance(value, DataAPITime):
return value.to_string()
elif isinstance(value, datetime.datetime):
# encoding in two steps (that's because the '%:z' strftime directive
# is not in all supported Python versions).
offset_tuple = _get_datetime_offset(value)
if offset_tuple is None:
if options.accept_naive_datetimes:
return DataAPITimestamp(int(value.timestamp() * 1000)).to_string()
raise ValueError(CANNOT_ENCODE_NAIVE_DATETIME_ERROR_MESSAGE)
date_part_str = value.strftime(DATETIME_DATETIME_FORMAT)
offset_h, offset_m = offset_tuple
offset_part_str = f"{offset_h:+03}:{offset_m:02}"
return f"{date_part_str}{offset_part_str}"
elif isinstance(value, datetime.date):
# there's no format to specify - and this is compliant anyway:
return value.strftime(DATETIME_DATE_FORMAT)
elif isinstance(value, datetime.time):
return value.strftime(DATETIME_TIME_FORMAT)
elif isinstance(value, decimal.Decimal):
# Non-numbers must be manually made into a string, just like floats
if math.isnan(value):
return NAN_FLOAT_STRING_REPRESENTATION
elif math.isinf(value):
if value > 0:
return PLUS_INFINITY_FLOAT_STRING_REPRESENTATION
else:
return MINUS_INFINITY_FLOAT_STRING_REPRESENTATION
# actually-numeric decimals: leave them as they are for the encoding step,
# which will apply the nasty trick to ensure all digits get there.
return value
elif isinstance(value, DataAPIDuration):
# using to_c_string over to_string until the ISO-format parsing can
# cope with subsecond fractions:
return value.to_c_string()
elif isinstance(value, UUID):
return str(value)
elif isinstance(value, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
return str(value)
elif isinstance(value, datetime.timedelta):
return DataAPIDuration.from_timedelta(value).to_c_string()
elif isinstance(value, ObjectId):
raise ValueError(
"Values of type ObjectId are not supported. Consider switching to "
"using UUID-based identifiers instead."
)
# At this point, the value is either a generator-like object or a "safe" scalar.
# It may still be something that must be unrolled into a list:
if options.unroll_iterables_to_lists:
_value = ensure_unrolled_if_iterable(value)
# if it was unrolled into a list, process its items recursively:
if isinstance(_value, list):
return [
preprocess_table_payload_value(path + [""], v, options=options)
for v in _value
]
return _value
# all options are exhausted save for str, int, bool, None:
return value
def preprocess_table_payload(
payload: dict[str, Any] | None, options: FullSerdesOptions
) -> dict[str, Any] | None:
"""
Normalize a payload for API calls.
This includes e.g. ensuring values for "$vector" key
are made into plain lists of floats.
Args:
payload (dict[str, Any]): A dict expressing a payload for an API call
Returns:
dict[str, Any]: a payload dict, pre-processed, ready for HTTP requests.
"""
if payload:
return cast(
dict[str, Any],
preprocess_table_payload_value([], payload, options=options),
)
else:
return payload
class _DecimalCleaner(json.JSONEncoder):
"""
This class cleans decimal (coming from decimal-oriented parsing of responses)
so that the schema can be made into a string, hashed, and used as key to the
converters cache safely.
"""
def default(self, obj: object) -> Any:
if isinstance(obj, decimal.Decimal):
return float(obj)
return super().default(obj)
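# The converter agent caches row- and key-postprocessors, keyed by a hash of
# the (JSON-serialized) schema, so repeated responses for the same table reuse
# the already-built converters.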
class _TableConverterAgent(Generic[ROW]):
options: FullSerdesOptions
row_postprocessors: dict[
tuple[str, str | None], Callable[[dict[str, Any]], dict[str, Any]]
]
key_postprocessors: dict[
str, Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]
]
def __init__(self, *, options: FullSerdesOptions) -> None:
self.options = options
self.row_postprocessors = {}
self.key_postprocessors = {}
@staticmethod
def _hash_dict(input_dict: dict[str, Any]) -> str:
return hashlib.md5(
json.dumps(
input_dict,
sort_keys=True,
separators=(",", ":"),
cls=_DecimalCleaner,
).encode()
).hexdigest()
def _get_key_postprocessor(
self, primary_key_schema_dict: dict[str, Any]
) -> Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]:
schema_hash = self._hash_dict(primary_key_schema_dict)
if schema_hash not in self.key_postprocessors:
primary_key_schema: dict[str, TableColumnTypeDescriptor] = {
col_name: TableColumnTypeDescriptor.coerce(col_dict)
for col_name, col_dict in primary_key_schema_dict.items()
}
self.key_postprocessors[schema_hash] = create_key_ktpostprocessor(
primary_key_schema=primary_key_schema,
options=self.options,
)
return self.key_postprocessors[schema_hash]
def _get_row_postprocessor(
self,
columns_dict: dict[str, Any],
similarity_pseudocolumn: str | None,
) -> Callable[[dict[str, Any]], dict[str, Any]]:
schema_cache_key = (self._hash_dict(columns_dict), similarity_pseudocolumn)
if schema_cache_key not in self.row_postprocessors:
columns: dict[str, TableColumnTypeDescriptor] = {
col_name: TableColumnTypeDescriptor.coerce(col_dict)
for col_name, col_dict in columns_dict.items()
}
self.row_postprocessors[schema_cache_key] = create_row_tpostprocessor(
columns=columns,
options=self.options,
similarity_pseudocolumn=similarity_pseudocolumn,
)
return self.row_postprocessors[schema_cache_key]
def preprocess_payload(
self, payload: dict[str, Any] | None
) -> dict[str, Any] | None:
return preprocess_table_payload(payload, options=self.options)
def postprocess_key(
self, primary_key_list: list[Any], *, primary_key_schema_dict: dict[str, Any]
) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""
The primary key schema is not coerced here, just parsed from its json
"""
return self._get_key_postprocessor(
primary_key_schema_dict=primary_key_schema_dict
)(primary_key_list)
def postprocess_keys(
self,
primary_key_lists: list[list[Any]],
*,
primary_key_schema_dict: dict[str, Any],
) -> list[tuple[tuple[Any, ...], dict[str, Any]]]:
"""
The primary key schema is not coerced here, just parsed from its json
"""
if primary_key_lists:
_k_postprocessor = self._get_key_postprocessor(
primary_key_schema_dict=primary_key_schema_dict
)
return [
_k_postprocessor(primary_key_list)
for primary_key_list in primary_key_lists
]
else:
return []
def postprocess_row(
self,
raw_dict: dict[str, Any],
*,
columns_dict: dict[str, Any],
similarity_pseudocolumn: str | None,
) -> ROW:
"""
The columns schema is not coerced here, just parsed from its json
"""
return self._get_row_postprocessor(
columns_dict=columns_dict, similarity_pseudocolumn=similarity_pseudocolumn
)(raw_dict) # type: ignore[return-value]
def postprocess_rows(
self,
raw_dicts: list[dict[str, Any]],
*,
columns_dict: dict[str, Any],
similarity_pseudocolumn: str | None,
) -> list[ROW]:
"""
The columns schema is not coerced here, just parsed from its json
"""
if raw_dicts:
_r_postprocessor = self._get_row_postprocessor(
columns_dict=columns_dict,
similarity_pseudocolumn=similarity_pseudocolumn,
)
return [cast(ROW, _r_postprocessor(raw_dict)) for raw_dict in raw_dicts]
else:
return []
Functions
def create_key_ktpostprocessor(primary_key_schema: dict[str, TableColumnTypeDescriptor], options: FullSerdesOptions) ‑> Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]
-
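Build a converter that turns the positional primary-key value list returned by the Data API into a (tuple, dict) pair, converting each value according to the given primary-key schema; a length mismatch between the list and the schema raises a ValueError.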
def create_row_tpostprocessor(columns: dict[str, TableColumnTypeDescriptor], options: FullSerdesOptions, similarity_pseudocolumn: str | None) ‑> Callable[[dict[str, Any]], dict[str, Any]]
-
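Build a converter that maps a raw row dict returned by the Data API onto a fully-typed row: each value is converted according to its column descriptor, unexpected fields raise a ValueError, and columns missing from the raw row are filled with a type-appropriate empty value. The sketch below is a rough illustration only (not how the library invokes this internally): the SimpleNamespace is a hypothetical stand-in for a FullSerdesOptions object, which plain text/int columns do not actually consult, and the bare {"type": ...} dicts are assumed to be acceptable input for TableColumnTypeDescriptor.coerce.

from types import SimpleNamespace

from astrapy.data.info.table_descriptor.table_columns import TableColumnTypeDescriptor
from astrapy.data.utils.table_converters import create_row_tpostprocessor

# hypothetical stand-in for FullSerdesOptions (not consulted for text/int columns):
opts = SimpleNamespace(custom_datatypes_in_reading=False)

columns = {
    "name": TableColumnTypeDescriptor.coerce({"type": "text"}),
    "age": TableColumnTypeDescriptor.coerce({"type": "int"}),
}
postprocess = create_row_tpostprocessor(
    columns=columns, options=opts, similarity_pseudocolumn=None
)
print(postprocess({"name": "Ada"}))
# {'name': 'Ada', 'age': None}   (missing columns are given their filler value)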
def preprocess_table_payload(payload: dict[str, Any] | None, options: FullSerdesOptions) ‑> dict[str, typing.Any] | None
-
Normalize a payload for API calls. This includes e.g. ensuring values for "$vector" key are made into plain lists of floats.
Args
    payload (dict[str, Any]): A dict expressing a payload for an API call.
Returns
    dict[str, Any]: a payload dict, pre-processed, ready for HTTP requests.
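A minimal usage sketch (not how the library calls this internally): the SimpleNamespace below is a hypothetical stand-in for a FullSerdesOptions object, carrying only the attributes this write path reads; in real use the options come from the client's serdes configuration.

from types import SimpleNamespace

from astrapy.data.utils.table_converters import preprocess_table_payload
from astrapy.data_types import DataAPIVector

# hypothetical stand-in for FullSerdesOptions, carrying only the attributes
# this write path consults (illustration only):
opts = SimpleNamespace(
    binary_encode_vectors=False,
    accept_naive_datetimes=False,
    unroll_iterables_to_lists=True,
)

payload = {
    "insertOne": {
        "document": {
            "id": "row-1",
            "score": float("nan"),
            "embedding": DataAPIVector([0.25, 0.5]),
        },
    },
}
print(preprocess_table_payload(payload, options=opts))
# {'insertOne': {'document': {'id': 'row-1', 'score': 'NaN', 'embedding': [0.25, 0.5]}}}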
def preprocess_table_payload_value(path: list[str], value: Any, options: FullSerdesOptions) ‑> Any
-
Walk a payload for Tables and apply the necessary and required conversions to make it into a ready-to-jsondumps object.