Module astrapy.data.utils.table_converters

# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import copy
import datetime
import decimal
import hashlib
import ipaddress
import json
import math
from typing import Any, Callable, Generic, cast

from astrapy.constants import ROW
from astrapy.data.info.table_descriptor.table_columns import (
    TableColumnTypeDescriptor,
    TableKeyValuedColumnTypeDescriptor,
    TableScalarColumnTypeDescriptor,
    TableUnsupportedColumnTypeDescriptor,
    TableValuedColumnTypeDescriptor,
    TableVectorColumnTypeDescriptor,
)
from astrapy.data.utils.extended_json_converters import (
    convert_ejson_binary_object_to_bytes,
    convert_to_ejson_bytes,
)
from astrapy.data.utils.table_types import (
    ColumnType,
    TableKeyValuedColumnType,
    TableUnsupportedColumnType,
    TableValuedColumnType,
    TableVectorColumnType,
)
from astrapy.data.utils.vector_coercion import ensure_unrolled_if_iterable
from astrapy.data_types import (
    DataAPIDate,
    DataAPIDuration,
    DataAPIMap,
    DataAPISet,
    DataAPITime,
    DataAPITimestamp,
    DataAPIVector,
)
from astrapy.ids import UUID, ObjectId
from astrapy.settings.error_messages import CANNOT_ENCODE_NAIVE_DATETIME_ERROR_MESSAGE
from astrapy.utils.api_options import FullSerdesOptions
from astrapy.utils.date_utils import _get_datetime_offset

NAN_FLOAT_STRING_REPRESENTATION = "NaN"
PLUS_INFINITY_FLOAT_STRING_REPRESENTATION = "Infinity"
MINUS_INFINITY_FLOAT_STRING_REPRESENTATION = "-Infinity"
DATETIME_TIME_FORMAT = "%H:%M:%S.%f"
DATETIME_DATE_FORMAT = "%Y-%m-%d"
DATETIME_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f"


def _create_scalar_tpostprocessor(
    column_type: ColumnType,
    options: FullSerdesOptions,
) -> Callable[[Any], Any]:
    if column_type in {
        ColumnType.TEXT,
        ColumnType.ASCII,
    }:

        def _tpostprocessor_text(raw_value: Any) -> str | None:
            return raw_value  # type: ignore[no-any-return]

        return _tpostprocessor_text

    elif column_type == ColumnType.BOOLEAN:

        def _tpostprocessor_bool(raw_value: Any) -> bool | None:
            return raw_value  # type: ignore[no-any-return]

        return _tpostprocessor_bool

    elif column_type in {
        ColumnType.INT,
        ColumnType.VARINT,
        ColumnType.BIGINT,
        ColumnType.SMALLINT,
        ColumnType.TINYINT,
    }:

        def _tpostprocessor_int(raw_value: Any) -> int | None:
            if raw_value is None:
                return None
            # the 'int(...)' call also handles Decimal values
            return int(raw_value)

        return _tpostprocessor_int

    elif column_type in {
        ColumnType.FLOAT,
        ColumnType.DOUBLE,
    }:

        def _tpostprocessor_float(raw_value: Any) -> float | None:
            if raw_value is None:
                return None
            elif isinstance(raw_value, (str, decimal.Decimal)):
                return float(raw_value)
            # just a float already
            return cast(float, raw_value)

        return _tpostprocessor_float

    elif column_type == ColumnType.BLOB:

        def _tpostprocessor_bytes(raw_value: Any) -> bytes | None:
            if raw_value is None:
                return None
            if isinstance(raw_value, dict):
                # {"$binary": ...}
                return convert_ejson_binary_object_to_bytes(raw_value)
            elif isinstance(raw_value, str):
                # within PKSchema, a bare string (e.g. "q83vASNFZ4k=") is encountered
                return convert_ejson_binary_object_to_bytes({"$binary": raw_value})
            else:
                raise ValueError(
                    f"Unexpected value type encountered for a blob column: {column_type}"
                )

        return _tpostprocessor_bytes

    elif column_type == ColumnType.UUID:

        def _tpostprocessor_uuid(raw_value: Any) -> UUID | None:
            if raw_value is None:
                return None
            return UUID(raw_value)

        return _tpostprocessor_uuid

    elif column_type == ColumnType.DATE:
        if options.custom_datatypes_in_reading:

            def _tpostprocessor_date(raw_value: Any) -> DataAPIDate | None:
                if raw_value is None:
                    return None
                return DataAPIDate.from_string(raw_value)

            return _tpostprocessor_date

        else:

            def _tpostprocessor_date_stdlib(raw_value: Any) -> datetime.date | None:
                if raw_value is None:
                    return None
                return DataAPIDate.from_string(raw_value).to_date()

            return _tpostprocessor_date_stdlib

    elif column_type == ColumnType.TIME:
        if options.custom_datatypes_in_reading:

            def _tpostprocessor_time(raw_value: Any) -> DataAPITime | None:
                if raw_value is None:
                    return None
                return DataAPITime.from_string(raw_value)

            return _tpostprocessor_time

        else:

            def _tpostprocessor_time_stdlib(raw_value: Any) -> datetime.time | None:
                if raw_value is None:
                    return None
                return DataAPITime.from_string(raw_value).to_time()

            return _tpostprocessor_time_stdlib

    elif column_type == ColumnType.TIMESTAMP:
        if options.custom_datatypes_in_reading:

            def _tpostprocessor_timestamp(raw_value: Any) -> DataAPITimestamp | None:
                if raw_value is None:
                    return None
                return DataAPITimestamp.from_string(raw_value)

            return _tpostprocessor_timestamp

        else:

            def _tpostprocessor_timestamp_stdlib(
                raw_value: Any,
            ) -> datetime.datetime | None:
                if raw_value is None:
                    return None
                da_timestamp = DataAPITimestamp.from_string(raw_value)
                return da_timestamp.to_datetime(tz=options.datetime_tzinfo)

            return _tpostprocessor_timestamp_stdlib

    elif column_type == ColumnType.DURATION:
        if options.custom_datatypes_in_reading:

            def _tpostprocessor_duration(raw_value: Any) -> DataAPIDuration | None:
                if raw_value is None:
                    return None
                return DataAPIDuration.from_string(raw_value)

            return _tpostprocessor_duration

        else:

            def _tpostprocessor_duration_stdlib(
                raw_value: Any,
            ) -> datetime.timedelta | None:
                if raw_value is None:
                    return None
                return DataAPIDuration.from_string(raw_value).to_timedelta()

            return _tpostprocessor_duration_stdlib

    elif column_type == ColumnType.INET:

        def _tpostprocessor_inet(
            raw_value: Any,
        ) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
            if raw_value is None:
                return None
            return ipaddress.ip_address(raw_value)

        return _tpostprocessor_inet

    elif column_type == ColumnType.DECIMAL:

        def _tpostprocessor_decimal(raw_value: Any) -> decimal.Decimal | None:
            if raw_value is None:
                return None
            elif isinstance(raw_value, decimal.Decimal):
                return raw_value
            # else: it is "NaN", "-Infinity" or "Infinity"
            return decimal.Decimal(f"{raw_value}")

        return _tpostprocessor_decimal
    else:
        raise ValueError(f"Unrecognized scalar type for reads: {column_type}")


def _create_unsupported_tpostprocessor(
    cql_definition: str,
    options: FullSerdesOptions,
) -> Callable[[Any], Any]:
    if cql_definition == "counter":
        return _create_scalar_tpostprocessor(
            column_type=ColumnType.INT, options=options
        )
    elif cql_definition == "varchar":
        return _create_scalar_tpostprocessor(
            column_type=ColumnType.TEXT, options=options
        )
    elif cql_definition == "timeuuid":
        return _create_scalar_tpostprocessor(
            column_type=ColumnType.UUID, options=options
        )
    else:
        raise ValueError(
            f"Unrecognized table unsupported-column cqlDefinition for reads: {cql_definition}"
        )


def _column_filler_value(
    col_def: TableColumnTypeDescriptor,
    options: FullSerdesOptions,
) -> Any:
    if isinstance(col_def, TableScalarColumnTypeDescriptor):
        return None
    elif isinstance(col_def, TableVectorColumnTypeDescriptor):
        if col_def.column_type == TableVectorColumnType.VECTOR:
            return None
        else:
            raise ValueError(
                f"Unrecognized table vector-column descriptor for reads: {col_def.as_dict()}"
            )
    elif isinstance(col_def, TableValuedColumnTypeDescriptor):
        if col_def.column_type == TableValuedColumnType.LIST:
            return []
        elif col_def.column_type == TableValuedColumnType.SET:
            if options.custom_datatypes_in_reading:
                return DataAPISet()
            else:
                return set()
        else:
            raise ValueError(
                f"Unrecognized table valued-column descriptor for reads: {col_def.as_dict()}"
            )
    elif isinstance(col_def, TableKeyValuedColumnTypeDescriptor):
        if col_def.column_type == TableKeyValuedColumnType.MAP:
            if options.custom_datatypes_in_reading:
                return DataAPIMap()
            else:
                return {}
        else:
            raise ValueError(
                f"Unrecognized table key-valued-column descriptor for reads: {col_def.as_dict()}"
            )
    elif isinstance(col_def, TableUnsupportedColumnTypeDescriptor):
        # For lack of better information, the filler used for unsupported
        # columns (when absent from a returned row) is None:
        return None
    else:
        raise ValueError(
            f"Unrecognized table column descriptor for reads: {col_def.as_dict()}"
        )


def _create_column_tpostprocessor(
    col_def: TableColumnTypeDescriptor,
    options: FullSerdesOptions,
) -> Callable[[Any], Any]:
    if isinstance(col_def, TableScalarColumnTypeDescriptor):
        return _create_scalar_tpostprocessor(col_def.column_type, options=options)
    elif isinstance(col_def, TableVectorColumnTypeDescriptor):
        if col_def.column_type == TableVectorColumnType.VECTOR:
            value_tpostprocessor = _create_scalar_tpostprocessor(
                ColumnType.FLOAT,
                options=options,
            )

            if options.custom_datatypes_in_reading:

                def _tpostprocessor_vector(
                    raw_items: list[float] | dict[str, str] | None,
                ) -> DataAPIVector | None:
                    if raw_items is None:
                        return None
                    elif isinstance(raw_items, dict):
                        # {"$binary": ...}
                        return DataAPIVector.from_bytes(
                            convert_ejson_binary_object_to_bytes(raw_items)
                        )
                    return DataAPIVector(
                        [value_tpostprocessor(item) for item in raw_items]
                    )

                return _tpostprocessor_vector

            else:

                def _tpostprocessor_vector_as_list(
                    raw_items: list[float] | dict[str, str] | None,
                ) -> list[float] | None:
                    if raw_items is None:
                        return None
                    elif isinstance(raw_items, dict):
                        # {"$binary": ...}
                        return DataAPIVector.from_bytes(
                            convert_ejson_binary_object_to_bytes(raw_items)
                        ).data
                    return [value_tpostprocessor(item) for item in raw_items]

                return _tpostprocessor_vector_as_list

        else:
            raise ValueError(
                f"Unrecognized table vector-column descriptor for reads: {col_def.as_dict()}"
            )
    elif isinstance(col_def, TableValuedColumnTypeDescriptor):
        if col_def.column_type == TableValuedColumnType.LIST:
            value_tpostprocessor = _create_scalar_tpostprocessor(
                col_def.value_type, options=options
            )

            def _tpostprocessor_list(raw_items: list[Any] | None) -> list[Any] | None:
                if raw_items is None:
                    return None
                return [value_tpostprocessor(item) for item in raw_items]

            return _tpostprocessor_list

        elif col_def.column_type == TableValuedColumnType.SET:
            value_tpostprocessor = _create_scalar_tpostprocessor(
                col_def.value_type, options=options
            )

            if options.custom_datatypes_in_reading:

                def _tpostprocessor_dataapiset(
                    raw_items: set[Any] | None,
                ) -> DataAPISet[Any] | None:
                    if raw_items is None:
                        return None
                    return DataAPISet(value_tpostprocessor(item) for item in raw_items)

                return _tpostprocessor_dataapiset

            else:

                def _tpostprocessor_dataapiset_as_set(
                    raw_items: set[Any] | None,
                ) -> set[Any] | None:
                    if raw_items is None:
                        return None
                    return {value_tpostprocessor(item) for item in raw_items}

                return _tpostprocessor_dataapiset_as_set

        else:
            raise ValueError(
                f"Unrecognized table valued-column descriptor for reads: {col_def.as_dict()}"
            )
    elif isinstance(col_def, TableKeyValuedColumnTypeDescriptor):
        if col_def.column_type == TableKeyValuedColumnType.MAP:
            key_tpostprocessor = _create_scalar_tpostprocessor(
                col_def.key_type, options=options
            )
            value_tpostprocessor = _create_scalar_tpostprocessor(
                col_def.value_type, options=options
            )

            if options.custom_datatypes_in_reading:

                def _tpostprocessor_dataapimap(
                    raw_items: dict[Any, Any] | None,
                ) -> DataAPIMap[Any, Any] | None:
                    if raw_items is None:
                        return None
                    return DataAPIMap(
                        (key_tpostprocessor(k), value_tpostprocessor(v))
                        for k, v in raw_items.items()
                    )

                return _tpostprocessor_dataapimap

            else:

                def _tpostprocessor_dataapimap_as_dict(
                    raw_items: dict[Any, Any] | None,
                ) -> dict[Any, Any] | None:
                    if raw_items is None:
                        return None
                    return {
                        key_tpostprocessor(k): value_tpostprocessor(v)
                        for k, v in raw_items.items()
                    }

                return _tpostprocessor_dataapimap_as_dict
        else:
            raise ValueError(
                f"Unrecognized table key-valued-column descriptor for reads: {col_def.as_dict()}"
            )
    elif isinstance(col_def, TableUnsupportedColumnTypeDescriptor):
        if col_def.column_type == TableUnsupportedColumnType.UNSUPPORTED:
            # for UNSUPPORTED columns, the effective type comes from api_support.cql_definition:
            return _create_unsupported_tpostprocessor(
                cql_definition=col_def.api_support.cql_definition,
                options=options,
            )
        else:
            raise ValueError(
                f"Unrecognized table unsupported-column descriptor for reads: {col_def.as_dict()}"
            )
    else:
        raise ValueError(
            f"Unrecognized table column descriptor for reads: {col_def.as_dict()}"
        )


def create_row_tpostprocessor(
    columns: dict[str, TableColumnTypeDescriptor],
    options: FullSerdesOptions,
    similarity_pseudocolumn: str | None,
) -> Callable[[dict[str, Any]], dict[str, Any]]:
    tpostprocessor_map = {
        col_name: _create_column_tpostprocessor(col_definition, options=options)
        for col_name, col_definition in columns.items()
    }
    tfiller_map = {
        col_name: _column_filler_value(col_definition, options=options)
        for col_name, col_definition in columns.items()
    }
    if similarity_pseudocolumn is not None:
        # whatever the passed schema says, requesting similarity overrides that 'column':
        tpostprocessor_map[similarity_pseudocolumn] = _create_scalar_tpostprocessor(
            column_type=ColumnType.FLOAT, options=options
        )
        tfiller_map[similarity_pseudocolumn] = None
    column_name_set = set(tpostprocessor_map.keys())

    def _tpostprocessor(raw_dict: dict[str, Any]) -> dict[str, Any]:
        extra_fields = set(raw_dict.keys()) - column_name_set
        if extra_fields:
            xf_desc = ", ".join(f'"{f}"' for f in sorted(extra_fields))
            raise ValueError(f"Returned row has unexpected fields: {xf_desc}")
        return {
            col_name: (
                # making a copy here, since the user may mutate e.g. a map:
                copy.copy(tfiller_map[col_name])
                if col_name not in raw_dict
                else tpostprocessor(raw_dict[col_name])
            )
            for col_name, tpostprocessor in tpostprocessor_map.items()
        }

    return _tpostprocessor


def create_key_ktpostprocessor(
    primary_key_schema: dict[str, TableColumnTypeDescriptor],
    options: FullSerdesOptions,
) -> Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]:
    ktpostprocessor_list: list[tuple[str, Callable[[Any], Any]]] = [
        (col_name, _create_column_tpostprocessor(col_definition, options=options))
        for col_name, col_definition in primary_key_schema.items()
    ]

    def _ktpostprocessor(
        primary_key_list: list[Any],
    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
        if len(primary_key_list) != len(ktpostprocessor_list):
            raise ValueError(
                "Primary key list length / schema mismatch "
                f"(expected {len(ktpostprocessor_list)}, "
                f"received {len(primary_key_list)} fields)"
            )
        k_tuple = tuple(
            [
                ktpostprocessor(pk_col_value)
                for pk_col_value, (_, ktpostprocessor) in zip(
                    primary_key_list,
                    ktpostprocessor_list,
                )
            ]
        )
        k_dict = {
            pk_col_name: pk_processed_value
            for pk_processed_value, (pk_col_name, _) in zip(
                k_tuple,
                ktpostprocessor_list,
            )
        }
        return k_tuple, k_dict

    return _ktpostprocessor


def preprocess_table_payload_value(
    path: list[str], value: Any, options: FullSerdesOptions
) -> Any:
    """
    Walk a payload for Tables and apply the conversions required
    to make it into an object ready for json.dumps.
    """

    # is this a nesting structure?
    if isinstance(value, (dict, DataAPIMap)):
        return {
            preprocess_table_payload_value(
                path, k, options=options
            ): preprocess_table_payload_value(path + [k], v, options=options)
            for k, v in value.items()
        }
    elif isinstance(value, (list, set, DataAPISet)):
        return [
            preprocess_table_payload_value(path + [""], v, options=options)
            for v in value
        ]

    # it's a scalar of some kind (which includes DataAPIVector)
    if isinstance(value, float):
        # Non-finite values (NaN, +/-Infinity) must be manually made into strings
        if math.isnan(value):
            return NAN_FLOAT_STRING_REPRESENTATION
        elif math.isinf(value):
            if value > 0:
                return PLUS_INFINITY_FLOAT_STRING_REPRESENTATION
            else:
                return MINUS_INFINITY_FLOAT_STRING_REPRESENTATION
        return value
    elif isinstance(value, bytes):
        return convert_to_ejson_bytes(value)
    elif isinstance(value, DataAPIVector):
        if options.binary_encode_vectors:
            return convert_to_ejson_bytes(value.to_bytes())
        else:
            # regular list of floats - which can contain non-finite values:
            return [
                preprocess_table_payload_value(path + [""], fval, options=options)
                for fval in value.data
            ]
    elif isinstance(value, DataAPITimestamp):
        return value.to_string()
    elif isinstance(value, DataAPIDate):
        return value.to_string()
    elif isinstance(value, DataAPITime):
        return value.to_string()
    elif isinstance(value, datetime.datetime):
        # encoding in two steps (that's because the '%:z' strftime directive
        # is not in all supported Python versions).
        offset_tuple = _get_datetime_offset(value)
        if offset_tuple is None:
            if options.accept_naive_datetimes:
                return DataAPITimestamp(int(value.timestamp() * 1000)).to_string()
            raise ValueError(CANNOT_ENCODE_NAIVE_DATETIME_ERROR_MESSAGE)
        date_part_str = value.strftime(DATETIME_DATETIME_FORMAT)
        offset_h, offset_m = offset_tuple
        offset_part_str = f"{offset_h:+03}:{offset_m:02}"
        return f"{date_part_str}{offset_part_str}"
    elif isinstance(value, datetime.date):
        # there's no format to specify - and this is compliant anyway:
        return value.strftime(DATETIME_DATE_FORMAT)
    elif isinstance(value, datetime.time):
        return value.strftime(DATETIME_TIME_FORMAT)
    elif isinstance(value, decimal.Decimal):
        # Non-finite Decimal values must be manually made into strings, just like floats
        if math.isnan(value):
            return NAN_FLOAT_STRING_REPRESENTATION
        elif math.isinf(value):
            if value > 0:
                return PLUS_INFINITY_FLOAT_STRING_REPRESENTATION
            else:
                return MINUS_INFINITY_FLOAT_STRING_REPRESENTATION
        # actually-numeric decimals: leave them as they are for the encoding step,
        # which will apply the nasty trick to ensure all digits get there.
        return value
    elif isinstance(value, DataAPIDuration):
        # using to_c_string over to_string until the ISO-format parsing can
        # cope with subsecond fractions:
        return value.to_c_string()
    elif isinstance(value, UUID):
        return str(value)
    elif isinstance(value, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
        return str(value)
    elif isinstance(value, datetime.timedelta):
        return DataAPIDuration.from_timedelta(value).to_c_string()
    elif isinstance(value, ObjectId):
        raise ValueError(
            "Values of type ObjectId are not supported. Consider switching to "
            "using UUID-based identifiers instead."
        )

    # At this point the value is either a generator-like iterable or a "safe" scalar;
    # iterables may have to be unrolled into a list:
    if options.unroll_iterables_to_lists:
        _value = ensure_unrolled_if_iterable(value)
        # if the unrolling produced a list, preprocess each item:
        if isinstance(_value, list):
            return [
                preprocess_table_payload_value(path + [""], v, options=options)
                for v in _value
            ]
        return _value

    # all options are exhausted save for str, int, bool, None:
    return value


def preprocess_table_payload(
    payload: dict[str, Any] | None, options: FullSerdesOptions
) -> dict[str, Any] | None:
    """
    Normalize a payload for API calls.
    This includes e.g. ensuring values for "$vector" key
    are made into plain lists of floats.

    Args:
        payload (dict[str, Any]): A dict expressing a payload for an API call

    Returns:
        dict[str, Any]: a payload dict, pre-processed, ready for HTTP requests.
    """

    if payload:
        return cast(
            dict[str, Any],
            preprocess_table_payload_value([], payload, options=options),
        )
    else:
        return payload


class _DecimalCleaner(json.JSONEncoder):
    """
    This class cleans decimal (coming from decimal-oriented parsing of responses)
    so that the schema can be made into a string, hashed, and used as key to the
    converters cache safely.
    """

    def default(self, obj: object) -> Any:
        if isinstance(obj, decimal.Decimal):
            return float(obj)
        return super().default(obj)


class _TableConverterAgent(Generic[ROW]):
    options: FullSerdesOptions
    row_postprocessors: dict[
        tuple[str, str | None], Callable[[dict[str, Any]], dict[str, Any]]
    ]
    key_postprocessors: dict[
        str, Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]
    ]

    def __init__(self, *, options: FullSerdesOptions) -> None:
        self.options = options
        self.row_postprocessors = {}
        self.key_postprocessors = {}

    @staticmethod
    def _hash_dict(input_dict: dict[str, Any]) -> str:
        return hashlib.md5(
            json.dumps(
                input_dict,
                sort_keys=True,
                separators=(",", ":"),
                cls=_DecimalCleaner,
            ).encode()
        ).hexdigest()

    def _get_key_postprocessor(
        self, primary_key_schema_dict: dict[str, Any]
    ) -> Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]:
        schema_hash = self._hash_dict(primary_key_schema_dict)
        if schema_hash not in self.key_postprocessors:
            primary_key_schema: dict[str, TableColumnTypeDescriptor] = {
                col_name: TableColumnTypeDescriptor.coerce(col_dict)
                for col_name, col_dict in primary_key_schema_dict.items()
            }
            self.key_postprocessors[schema_hash] = create_key_ktpostprocessor(
                primary_key_schema=primary_key_schema,
                options=self.options,
            )
        return self.key_postprocessors[schema_hash]

    def _get_row_postprocessor(
        self,
        columns_dict: dict[str, Any],
        similarity_pseudocolumn: str | None,
    ) -> Callable[[dict[str, Any]], dict[str, Any]]:
        schema_cache_key = (self._hash_dict(columns_dict), similarity_pseudocolumn)
        if schema_cache_key not in self.row_postprocessors:
            columns: dict[str, TableColumnTypeDescriptor] = {
                col_name: TableColumnTypeDescriptor.coerce(col_dict)
                for col_name, col_dict in columns_dict.items()
            }
            self.row_postprocessors[schema_cache_key] = create_row_tpostprocessor(
                columns=columns,
                options=self.options,
                similarity_pseudocolumn=similarity_pseudocolumn,
            )
        return self.row_postprocessors[schema_cache_key]

    def preprocess_payload(
        self, payload: dict[str, Any] | None
    ) -> dict[str, Any] | None:
        return preprocess_table_payload(payload, options=self.options)

    def postprocess_key(
        self, primary_key_list: list[Any], *, primary_key_schema_dict: dict[str, Any]
    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
        """
        The primary-key schema argument is a plain dict, as parsed from the API's
        JSON response; coercion into descriptor objects (and caching) happens internally.
        """
        return self._get_key_postprocessor(
            primary_key_schema_dict=primary_key_schema_dict
        )(primary_key_list)

    def postprocess_keys(
        self,
        primary_key_lists: list[list[Any]],
        *,
        primary_key_schema_dict: dict[str, Any],
    ) -> list[tuple[tuple[Any, ...], dict[str, Any]]]:
        """
        The primary-key schema argument is a plain dict, as parsed from the API's
        JSON response; coercion into descriptor objects (and caching) happens internally.
        """
        if primary_key_lists:
            _k_postprocessor = self._get_key_postprocessor(
                primary_key_schema_dict=primary_key_schema_dict
            )
            return [
                _k_postprocessor(primary_key_list)
                for primary_key_list in primary_key_lists
            ]
        else:
            return []

    def postprocess_row(
        self,
        raw_dict: dict[str, Any],
        *,
        columns_dict: dict[str, Any],
        similarity_pseudocolumn: str | None,
    ) -> ROW:
        """
        The columns schema argument is a plain dict, as parsed from the API's
        JSON response; coercion into descriptor objects (and caching) happens internally.
        """
        return self._get_row_postprocessor(
            columns_dict=columns_dict, similarity_pseudocolumn=similarity_pseudocolumn
        )(raw_dict)  # type: ignore[return-value]

    def postprocess_rows(
        self,
        raw_dicts: list[dict[str, Any]],
        *,
        columns_dict: dict[str, Any],
        similarity_pseudocolumn: str | None,
    ) -> list[ROW]:
        """
        The columns schema argument is a plain dict, as parsed from the API's
        JSON response; coercion into descriptor objects (and caching) happens internally.
        """
        if raw_dicts:
            _r_postprocessor = self._get_row_postprocessor(
                columns_dict=columns_dict,
                similarity_pseudocolumn=similarity_pseudocolumn,
            )
            return [cast(ROW, _r_postprocessor(raw_dict)) for raw_dict in raw_dicts]
        else:
            return []

Functions

def create_key_ktpostprocessor(primary_key_schema: dict[str, TableColumnTypeDescriptor], options: FullSerdesOptions) -> Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]
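
A minimal usage sketch follows. It assumes a FullSerdesOptions instance named serdes_options is available, and that TableColumnTypeDescriptor.coerce accepts plain-dict column definitions of the form shown (mirroring how _TableConverterAgent builds its schemas internally).

from astrapy.data.info.table_descriptor.table_columns import TableColumnTypeDescriptor
from astrapy.data.utils.table_converters import create_key_ktpostprocessor

pk_schema = {
    "city": TableColumnTypeDescriptor.coerce({"type": "text"}),
    "visit_id": TableColumnTypeDescriptor.coerce({"type": "int"}),
}
ktpostprocessor = create_key_ktpostprocessor(
    primary_key_schema=pk_schema, options=serdes_options
)

# the API returns primary keys as positional lists; the postprocessor yields
# both a tuple and a column-name-keyed dict:
k_tuple, k_dict = ktpostprocessor(["Rome", 123])
# k_tuple == ("Rome", 123)
# k_dict == {"city": "Rome", "visit_id": 123}
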
def create_row_tpostprocessor(columns: dict[str, TableColumnTypeDescriptor], options: FullSerdesOptions, similarity_pseudocolumn: str | None) -> Callable[[dict[str, Any]], dict[str, Any]]
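
A minimal usage sketch, under the same assumptions as above about serdes_options and the plain-dict column definitions accepted by TableColumnTypeDescriptor.coerce:

from astrapy.data.info.table_descriptor.table_columns import TableColumnTypeDescriptor
from astrapy.data.utils.table_converters import create_row_tpostprocessor

columns = {
    "id": TableColumnTypeDescriptor.coerce({"type": "text"}),
    "tags": TableColumnTypeDescriptor.coerce({"type": "set", "valueType": "text"}),
}
row_postprocessor = create_row_tpostprocessor(
    columns, options=serdes_options, similarity_pseudocolumn=None
)

# columns missing from the raw row are filled with an empty value of the
# appropriate kind (here an empty set, or a DataAPISet when custom datatypes
# are enabled in the options); present columns are converted per their type:
row = row_postprocessor({"id": "row-1"})
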
def preprocess_table_payload(payload: dict[str, Any] | None, options: FullSerdesOptions) -> dict[str, typing.Any] | None

Normalize a payload for API calls. This includes e.g. ensuring values for "$vector" key are made into plain lists of floats.

Args

    payload : dict[str, Any]
        A dict expressing a payload for an API call

Returns

    dict[str, Any]
        a payload dict, pre-processed, ready for HTTP requests.
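
A brief sketch of the conversions applied (serdes_options again stands for an existing FullSerdesOptions instance; the insertOne shape is illustrative only):

from astrapy.data.utils.table_converters import preprocess_table_payload
from astrapy.data_types import DataAPITimestamp, DataAPIVector

payload = {
    "insertOne": {
        "document": {
            "id": "row-1",
            # DataAPITimestamp is constructed from milliseconds since the epoch:
            "when": DataAPITimestamp(1704110400000),
            "vec": DataAPIVector([0.1, 0.2]),
        }
    }
}
ready = preprocess_table_payload(payload, options=serdes_options)
# 'when' is serialized to its string representation; 'vec' becomes either a
# {"$binary": ...} object or a plain list of floats, depending on
# options.binary_encode_vectors.
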
def preprocess_table_payload_value(path: list[str], value: Any, options: FullSerdesOptions) -> Any

Walk a payload for Tables and apply the conversions required to make it into an object ready for json.dumps.

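
A couple of minimal, illustrative calls (serdes_options as above; the exact $binary string shown assumes the standard base64 encoding used by the extended-JSON converters):

from astrapy.data.utils.table_converters import preprocess_table_payload_value

# non-finite floats are turned into their string representations:
preprocess_table_payload_value([], float("nan"), options=serdes_options)
# -> "NaN"

# nested structures are walked recursively and bytes become $binary objects:
preprocess_table_payload_value([], {"k": b"\x00\x01"}, options=serdes_options)
# -> {"k": {"$binary": "AAE="}}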