Module astrapy.data.utils.table_converters
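
Serialization/deserialization helpers for Data API Tables: functions that preprocess write payloads into JSON-transportable form, and schema-driven postprocessors that turn returned rows and primary keys back into typed Python values, as configured by a FullSerdesOptions.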

# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import copy
import datetime
import decimal
import hashlib
import ipaddress
import json
import math
from typing import Any, Callable, Dict, Generic, cast

from astrapy.constants import ROW
from astrapy.data.info.table_descriptor.table_columns import (
    TableColumnTypeDescriptor,
    TableKeyValuedColumnTypeDescriptor,
    TableScalarColumnTypeDescriptor,
    TableUnsupportedColumnTypeDescriptor,
    TableValuedColumnTypeDescriptor,
    TableVectorColumnTypeDescriptor,
)
from astrapy.data.utils.extended_json_converters import (
    convert_ejson_binary_object_to_bytes,
    convert_to_ejson_bytes,
)
from astrapy.data.utils.table_types import (
    ColumnType,
    TableKeyValuedColumnType,
    TableUnsupportedColumnType,
    TableValuedColumnType,
    TableVectorColumnType,
)
from astrapy.data.utils.vector_coercion import ensure_unrolled_if_iterable
from astrapy.data_types import (
    DataAPIDate,
    DataAPIDuration,
    DataAPIMap,
    DataAPISet,
    DataAPITime,
    DataAPITimestamp,
    DataAPIVector,
)
from astrapy.ids import UUID, ObjectId
from astrapy.settings.error_messages import CANNOT_ENCODE_NAIVE_DATETIME_ERROR_MESSAGE
from astrapy.utils.api_options import FullSerdesOptions
from astrapy.utils.date_utils import _get_datetime_offset

NAN_FLOAT_STRING_REPRESENTATION = "NaN"
PLUS_INFINITY_FLOAT_STRING_REPRESENTATION = "Infinity"
MINUS_INFINITY_FLOAT_STRING_REPRESENTATION = "-Infinity"
DATETIME_TIME_FORMAT = "%H:%M:%S.%f"
DATETIME_DATE_FORMAT = "%Y-%m-%d"
DATETIME_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f"


def _create_scalar_tpostprocessor(
    column_type: ColumnType,
    options: FullSerdesOptions,
) -> Callable[[Any], Any]:
    if column_type in {
        ColumnType.TEXT,
        ColumnType.ASCII,
    }:

        def _tpostprocessor_text(raw_value: Any) -> str | None:
            return raw_value  # type: ignore[no-any-return]

        return _tpostprocessor_text

    elif column_type == ColumnType.BOOLEAN:

        def _tpostprocessor_bool(raw_value: Any) -> bool | None:
            return raw_value  # type: ignore[no-any-return]

        return _tpostprocessor_bool

    elif column_type in {
        ColumnType.INT,
        ColumnType.VARINT,
        ColumnType.BIGINT,
        ColumnType.SMALLINT,
        ColumnType.TINYINT,
        ColumnType.COUNTER,
    }:

        def _tpostprocessor_int(raw_value: Any) -> int | None:
            if raw_value is None:
                return None
            # the 'int(...)' call also takes care of Decimal values
            return int(raw_value)

        return _tpostprocessor_int

    elif column_type in {
        ColumnType.FLOAT,
        ColumnType.DOUBLE,
    }:

        def _tpostprocessor_float(raw_value: Any) -> float | None:
            if raw_value is None:
                return None
            elif isinstance(raw_value, (str, decimal.Decimal)):
                return float(raw_value)
            # just a float already
            return cast(float, raw_value)

        return _tpostprocessor_float

    elif column_type == ColumnType.BLOB:

        def _tpostprocessor_bytes(raw_value: Any) -> bytes | None:
            if raw_value is None:
                return None
            if isinstance(raw_value, dict):
                # {"$binary": ...}
                return convert_ejson_binary_object_to_bytes(raw_value)
            elif isinstance(raw_value, str):
                # within the primary-key schema, blobs arrive as bare base64 strings (e.g. "q83vASNFZ4k="):
                return convert_ejson_binary_object_to_bytes({"$binary": raw_value})
            else:
                raise ValueError(
                    f"Unexpected value type encountered for a blob column: {column_type}"
                )

        return _tpostprocessor_bytes

    elif column_type in {ColumnType.UUID, ColumnType.TIMEUUID}:

        def _tpostprocessor_uuid(raw_value: Any) -> UUID | None:
            if raw_value is None:
                return None
            return UUID(raw_value)

        return _tpostprocessor_uuid

    elif column_type == ColumnType.DATE:
        if options.custom_datatypes_in_reading:

            def _tpostprocessor_date(raw_value: Any) -> DataAPIDate | None:
                if raw_value is None:
                    return None
                return DataAPIDate.from_string(raw_value)

            return _tpostprocessor_date

        else:

            def _tpostprocessor_date_stdlib(raw_value: Any) -> datetime.date | None:
                if raw_value is None:
                    return None
                return DataAPIDate.from_string(raw_value).to_date()

            return _tpostprocessor_date_stdlib

    elif column_type == ColumnType.TIME:
        if options.custom_datatypes_in_reading:

            def _tpostprocessor_time(raw_value: Any) -> DataAPITime | None:
                if raw_value is None:
                    return None
                return DataAPITime.from_string(raw_value)

            return _tpostprocessor_time

        else:

            def _tpostprocessor_time_stdlib(raw_value: Any) -> datetime.time | None:
                if raw_value is None:
                    return None
                return DataAPITime.from_string(raw_value).to_time()

            return _tpostprocessor_time_stdlib

    elif column_type == ColumnType.TIMESTAMP:
        if options.custom_datatypes_in_reading:

            def _tpostprocessor_timestamp(raw_value: Any) -> DataAPITimestamp | None:
                if raw_value is None:
                    return None
                return DataAPITimestamp.from_string(raw_value)

            return _tpostprocessor_timestamp

        else:

            def _tpostprocessor_timestamp_stdlib(
                raw_value: Any,
            ) -> datetime.datetime | None:
                if raw_value is None:
                    return None
                da_timestamp = DataAPITimestamp.from_string(raw_value)
                return da_timestamp.to_datetime(tz=options.datetime_tzinfo)

            return _tpostprocessor_timestamp_stdlib

    elif column_type == ColumnType.DURATION:
        if options.custom_datatypes_in_reading:

            def _tpostprocessor_duration(raw_value: Any) -> DataAPIDuration | None:
                if raw_value is None:
                    return None
                return DataAPIDuration.from_string(raw_value)

            return _tpostprocessor_duration

        else:

            def _tpostprocessor_duration_stdlib(
                raw_value: Any,
            ) -> datetime.timedelta | None:
                if raw_value is None:
                    return None
                return DataAPIDuration.from_string(raw_value).to_timedelta()

            return _tpostprocessor_duration_stdlib

    elif column_type == ColumnType.INET:

        def _tpostprocessor_inet(
            raw_value: Any,
        ) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
            if raw_value is None:
                return None
            return ipaddress.ip_address(raw_value)

        return _tpostprocessor_inet

    elif column_type == ColumnType.DECIMAL:

        def _tpostprocessor_decimal(raw_value: Any) -> decimal.Decimal | None:
            if raw_value is None:
                return None
            elif isinstance(raw_value, decimal.Decimal):
                return raw_value
            # else: it is "NaN", "-Infinity" or "Infinity"
            return decimal.Decimal(f"{raw_value}")

        return _tpostprocessor_decimal
    else:
        raise ValueError(f"Unrecognized scalar type for reads: {column_type}")


def _create_unsupported_tpostprocessor(
    cql_definition: str,
    options: FullSerdesOptions,
) -> Callable[[Any], Any]:
    if cql_definition == "counter":
        return _create_scalar_tpostprocessor(
            column_type=ColumnType.INT, options=options
        )
    elif cql_definition == "varchar":
        return _create_scalar_tpostprocessor(
            column_type=ColumnType.TEXT, options=options
        )
    elif cql_definition == "timeuuid":
        return _create_scalar_tpostprocessor(
            column_type=ColumnType.UUID, options=options
        )
    else:
        raise ValueError(
            f"Unrecognized table unsupported-column cqlDefinition for reads: {cql_definition}"
        )


def _column_filler_value(
    col_def: TableColumnTypeDescriptor,
    options: FullSerdesOptions,
) -> Any:
    if isinstance(col_def, TableScalarColumnTypeDescriptor):
        return None
    elif isinstance(col_def, TableVectorColumnTypeDescriptor):
        if col_def.column_type == TableVectorColumnType.VECTOR:
            return None
        else:
            raise ValueError(
                f"Unrecognized table vector-column descriptor for reads: {col_def.as_dict()}"
            )
    elif isinstance(col_def, TableValuedColumnTypeDescriptor):
        if col_def.column_type == TableValuedColumnType.LIST:
            return []
        elif col_def.column_type == TableValuedColumnType.SET:
            if options.custom_datatypes_in_reading:
                return DataAPISet()
            else:
                return set()
        else:
            raise ValueError(
                f"Unrecognized table valued-column descriptor for reads: {col_def.as_dict()}"
            )
    elif isinstance(col_def, TableKeyValuedColumnTypeDescriptor):
        if col_def.column_type == TableKeyValuedColumnType.MAP:
            if options.custom_datatypes_in_reading:
                return DataAPIMap()
            else:
                return {}
        else:
            raise ValueError(
                f"Unrecognized table key-valued-column descriptor for reads: {col_def.as_dict()}"
            )
    elif isinstance(col_def, TableUnsupportedColumnTypeDescriptor):
        # For lack of better information, the filler value used when an
        # unsupported column is absent from a returned row is None:
        return None
    else:
        raise ValueError(
            f"Unrecognized table column descriptor for reads: {col_def.as_dict()}"
        )


def _create_column_tpostprocessor(
    col_def: TableColumnTypeDescriptor,
    options: FullSerdesOptions,
) -> Callable[[Any], Any]:
    if isinstance(col_def, TableScalarColumnTypeDescriptor):
        return _create_scalar_tpostprocessor(col_def.column_type, options=options)
    elif isinstance(col_def, TableVectorColumnTypeDescriptor):
        if col_def.column_type == TableVectorColumnType.VECTOR:
            value_tpostprocessor = _create_scalar_tpostprocessor(
                ColumnType.FLOAT,
                options=options,
            )

            if options.custom_datatypes_in_reading:

                def _tpostprocessor_vector(
                    raw_items: list[float] | dict[str, str] | None,
                ) -> DataAPIVector | None:
                    if raw_items is None:
                        return None
                    elif isinstance(raw_items, dict):
                        # {"$binary": ...}
                        return DataAPIVector.from_bytes(
                            convert_ejson_binary_object_to_bytes(raw_items)
                        )
                    return DataAPIVector(
                        [value_tpostprocessor(item) for item in raw_items]
                    )

                return _tpostprocessor_vector

            else:

                def _tpostprocessor_vector_as_list(
                    raw_items: list[float] | dict[str, str] | None,
                ) -> list[float] | None:
                    if raw_items is None:
                        return None
                    elif isinstance(raw_items, dict):
                        # {"$binary": ...}
                        return DataAPIVector.from_bytes(
                            convert_ejson_binary_object_to_bytes(raw_items)
                        ).data
                    return [value_tpostprocessor(item) for item in raw_items]

                return _tpostprocessor_vector_as_list

        else:
            raise ValueError(
                f"Unrecognized table vector-column descriptor for reads: {col_def.as_dict()}"
            )
    elif isinstance(col_def, TableValuedColumnTypeDescriptor):
        if col_def.column_type == TableValuedColumnType.LIST:
            value_tpostprocessor = _create_scalar_tpostprocessor(
                col_def.value_type, options=options
            )

            def _tpostprocessor_list(raw_items: list[Any] | None) -> list[Any] | None:
                if raw_items is None:
                    return None
                return [value_tpostprocessor(item) for item in raw_items]

            return _tpostprocessor_list

        elif col_def.column_type == TableValuedColumnType.SET:
            value_tpostprocessor = _create_scalar_tpostprocessor(
                col_def.value_type, options=options
            )

            if options.custom_datatypes_in_reading:

                def _tpostprocessor_dataapiset(
                    raw_items: set[Any] | None,
                ) -> DataAPISet[Any] | None:
                    if raw_items is None:
                        return None
                    return DataAPISet(value_tpostprocessor(item) for item in raw_items)

                return _tpostprocessor_dataapiset

            else:

                def _tpostprocessor_dataapiset_as_set(
                    raw_items: set[Any] | None,
                ) -> set[Any] | None:
                    if raw_items is None:
                        return None
                    return {value_tpostprocessor(item) for item in raw_items}

                return _tpostprocessor_dataapiset_as_set

        else:
            raise ValueError(
                f"Unrecognized table valued-column descriptor for reads: {col_def.as_dict()}"
            )
    elif isinstance(col_def, TableKeyValuedColumnTypeDescriptor):
        if col_def.column_type == TableKeyValuedColumnType.MAP:
            key_tpostprocessor = _create_scalar_tpostprocessor(
                col_def.key_type, options=options
            )
            value_tpostprocessor = _create_scalar_tpostprocessor(
                col_def.value_type, options=options
            )

            if options.custom_datatypes_in_reading:

                def _tpostprocessor_dataapimap(
                    raw_items: dict[Any, Any] | None,
                ) -> DataAPIMap[Any, Any] | None:
                    if raw_items is None:
                        return None
                    return DataAPIMap(
                        (key_tpostprocessor(k), value_tpostprocessor(v))
                        for k, v in raw_items.items()
                    )

                return _tpostprocessor_dataapimap

            else:

                def _tpostprocessor_dataapimap_as_dict(
                    raw_items: dict[Any, Any] | None,
                ) -> dict[Any, Any] | None:
                    if raw_items is None:
                        return None
                    return {
                        key_tpostprocessor(k): value_tpostprocessor(v)
                        for k, v in raw_items.items()
                    }

                return _tpostprocessor_dataapimap_as_dict
        else:
            raise ValueError(
                f"Unrecognized table key-valued-column descriptor for reads: {col_def.as_dict()}"
            )
    elif isinstance(col_def, TableUnsupportedColumnTypeDescriptor):
        if col_def.column_type == TableUnsupportedColumnType.UNSUPPORTED:
            # for UNSUPPORTED columns, the effective type is read off the cqlDefinition:
            return _create_unsupported_tpostprocessor(
                cql_definition=col_def.api_support.cql_definition,
                options=options,
            )
        else:
            raise ValueError(
                f"Unrecognized table unsupported-column descriptor for reads: {col_def.as_dict()}"
            )
    else:
        raise ValueError(
            f"Unrecognized table column descriptor for reads: {col_def.as_dict()}"
        )


def create_row_tpostprocessor(
    columns: dict[str, TableColumnTypeDescriptor],
    options: FullSerdesOptions,
    similarity_pseudocolumn: str | None,
) -> Callable[[dict[str, Any]], dict[str, Any]]:
    tpostprocessor_map = {
        col_name: _create_column_tpostprocessor(col_definition, options=options)
        for col_name, col_definition in columns.items()
    }
    tfiller_map = {
        col_name: _column_filler_value(col_definition, options=options)
        for col_name, col_definition in columns.items()
    }
    if similarity_pseudocolumn is not None:
        # whatever the passed schema says, requesting the similarity overrides that 'column':
        tpostprocessor_map[similarity_pseudocolumn] = _create_scalar_tpostprocessor(
            column_type=ColumnType.FLOAT, options=options
        )
        tfiller_map[similarity_pseudocolumn] = None
    column_name_set = set(tpostprocessor_map.keys())

    def _tpostprocessor(raw_dict: dict[str, Any]) -> dict[str, Any]:
        extra_fields = set(raw_dict.keys()) - column_name_set
        if extra_fields:
            xf_desc = ", ".join(f'"{f}"' for f in sorted(extra_fields))
            raise ValueError(f"Returned row has unexpected fields: {xf_desc}")
        return {
            col_name: (
                # making a copy here, since the user may mutate e.g. a map:
                copy.copy(tfiller_map[col_name])
                if col_name not in raw_dict
                else tpostprocessor(raw_dict[col_name])
            )
            for col_name, tpostprocessor in tpostprocessor_map.items()
        }

    return _tpostprocessor


def create_key_ktpostprocessor(
    primary_key_schema: dict[str, TableColumnTypeDescriptor],
    options: FullSerdesOptions,
) -> Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]:
    ktpostprocessor_list: list[tuple[str, Callable[[Any], Any]]] = [
        (col_name, _create_column_tpostprocessor(col_definition, options=options))
        for col_name, col_definition in primary_key_schema.items()
    ]

    def _ktpostprocessor(
        primary_key_list: list[Any],
    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
        if len(primary_key_list) != len(ktpostprocessor_list):
            raise ValueError(
                "Primary key list length / schema mismatch "
                f"(expected {len(ktpostprocessor_list)}, "
                f"received {len(primary_key_list)} fields)"
            )
        k_tuple = tuple(
            [
                ktpostprocessor(pk_col_value)
                for pk_col_value, (_, ktpostprocessor) in zip(
                    primary_key_list,
                    ktpostprocessor_list,
                )
            ]
        )
        k_dict = {
            pk_col_name: pk_processed_value
            for pk_processed_value, (pk_col_name, _) in zip(
                k_tuple,
                ktpostprocessor_list,
            )
        }
        return k_tuple, k_dict

    return _ktpostprocessor


def preprocess_table_payload_value(
    path: list[str], value: Any, options: FullSerdesOptions
) -> Any:
    """
    Walk a payload for Tables and apply the conversions required
    to make it into an object ready for json.dumps.
    """

    # is this a nesting structure?
    if isinstance(value, (dict, DataAPIMap)):
        return {
            preprocess_table_payload_value(
                path, k, options=options
            ): preprocess_table_payload_value(path + [k], v, options=options)
            for k, v in value.items()
        }
    elif isinstance(value, (list, set, DataAPISet)):
        return [
            preprocess_table_payload_value(path + [""], v, options=options)
            for v in value
        ]

    # it's a scalar of some kind (which includes DataAPIVector)
    if isinstance(value, float):
        # Non-numbers must be manually made into a string
        if math.isnan(value):
            return NAN_FLOAT_STRING_REPRESENTATION
        elif math.isinf(value):
            if value > 0:
                return PLUS_INFINITY_FLOAT_STRING_REPRESENTATION
            else:
                return MINUS_INFINITY_FLOAT_STRING_REPRESENTATION
        return value
    elif isinstance(value, bytes):
        return convert_to_ejson_bytes(value)
    elif isinstance(value, DataAPIVector):
        if options.binary_encode_vectors:
            return convert_to_ejson_bytes(value.to_bytes())
        else:
            # a regular list of floats - which may contain non-numbers (NaN, +/-Infinity):
            return [
                preprocess_table_payload_value(path + [""], fval, options=options)
                for fval in value.data
            ]
    elif isinstance(value, DataAPITimestamp):
        return value.to_string()
    elif isinstance(value, DataAPIDate):
        return value.to_string()
    elif isinstance(value, DataAPITime):
        return value.to_string()
    elif isinstance(value, datetime.datetime):
        # encoding in two steps, since the '%:z' strftime directive
        # is not available in all supported Python versions.
        offset_tuple = _get_datetime_offset(value)
        if offset_tuple is None:
            if options.accept_naive_datetimes:
                return DataAPITimestamp(int(value.timestamp() * 1000)).to_string()
            raise ValueError(CANNOT_ENCODE_NAIVE_DATETIME_ERROR_MESSAGE)
        date_part_str = value.strftime(DATETIME_DATETIME_FORMAT)
        offset_h, offset_m = offset_tuple
        offset_part_str = f"{offset_h:+03}:{offset_m:02}"
        return f"{date_part_str}{offset_part_str}"
    elif isinstance(value, datetime.date):
        # no offset to encode here - the "%Y-%m-%d" rendering is compliant as-is:
        return value.strftime(DATETIME_DATE_FORMAT)
    elif isinstance(value, datetime.time):
        return value.strftime(DATETIME_TIME_FORMAT)
    elif isinstance(value, decimal.Decimal):
        # Non-numbers must be manually made into a string, just like floats
        if math.isnan(value):
            return NAN_FLOAT_STRING_REPRESENTATION
        elif math.isinf(value):
            if value > 0:
                return PLUS_INFINITY_FLOAT_STRING_REPRESENTATION
            else:
                return MINUS_INFINITY_FLOAT_STRING_REPRESENTATION
        # actually-numeric decimals: leave them as they are for the encoding step,
        # which will apply the nasty trick to ensure all digits get there.
        return value
    elif isinstance(value, DataAPIDuration):
        # using to_c_string over to_string until the ISO-format parsing can
        # cope with subsecond fractions:
        return value.to_c_string()
    elif isinstance(value, UUID):
        return str(value)
    elif isinstance(value, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
        return str(value)
    elif isinstance(value, datetime.timedelta):
        return DataAPIDuration.from_timedelta(value).to_c_string()
    elif isinstance(value, ObjectId):
        raise ValueError(
            "Values of type ObjectId are not supported. Consider switching to "
            "using UUID-based identifiers instead."
        )

    # At this point the value is either a "safe" scalar or a generator-like
    # iterable that may have to be unrolled into a list:
    if options.unroll_iterables_to_lists:
        _value = ensure_unrolled_if_iterable(value)
        # if the unrolling produced a list, process its items recursively:
        if isinstance(_value, list):
            return [
                preprocess_table_payload_value(path + [""], v, options=options)
                for v in _value
            ]
        return _value

    # all options are exhausted save for str, int, bool, None:
    return value


def preprocess_table_payload(
    payload: dict[str, Any] | None, options: FullSerdesOptions
) -> dict[str, Any] | None:
    """
    Normalize a payload for API calls.
    This includes e.g. converting vectors, dates, byte blobs and other
    non-JSON-native values into their wire representation.

    Args:
        payload (dict[str, Any]): a dict expressing a payload for an API call.
        options (FullSerdesOptions): the serialization options driving
            the conversions.

    Returns:
        dict[str, Any]: a payload dict, pre-processed, ready for HTTP requests.
    """

    if payload:
        return cast(
            Dict[str, Any],
            preprocess_table_payload_value([], payload, options=options),
        )
    else:
        return payload


class _DecimalCleaner(json.JSONEncoder):
    """
    A JSON encoder converting Decimal values (arising from decimal-oriented
    parsing of responses) into floats, so that a schema dict can be serialized,
    hashed, and safely used as a key into the converters cache.
    """

    def default(self, obj: object) -> Any:
        if isinstance(obj, decimal.Decimal):
            return float(obj)
        return super().default(obj)


class _TableConverterAgent(Generic[ROW]):
    options: FullSerdesOptions
    row_postprocessors: dict[
        tuple[str, str | None], Callable[[dict[str, Any]], dict[str, Any]]
    ]
    key_postprocessors: dict[
        str, Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]
    ]

    def __init__(self, *, options: FullSerdesOptions) -> None:
        self.options = options
        self.row_postprocessors = {}
        self.key_postprocessors = {}

    @staticmethod
    def _hash_dict(input_dict: dict[str, Any]) -> str:
        return hashlib.md5(
            json.dumps(
                input_dict,
                sort_keys=True,
                separators=(",", ":"),
                cls=_DecimalCleaner,
            ).encode()
        ).hexdigest()

    def _get_key_postprocessor(
        self, primary_key_schema_dict: dict[str, Any]
    ) -> Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]:
        schema_hash = self._hash_dict(primary_key_schema_dict)
        if schema_hash not in self.key_postprocessors:
            primary_key_schema: dict[str, TableColumnTypeDescriptor] = {
                col_name: TableColumnTypeDescriptor.coerce(col_dict)
                for col_name, col_dict in primary_key_schema_dict.items()
            }
            self.key_postprocessors[schema_hash] = create_key_ktpostprocessor(
                primary_key_schema=primary_key_schema,
                options=self.options,
            )
        return self.key_postprocessors[schema_hash]

    def _get_row_postprocessor(
        self,
        columns_dict: dict[str, Any],
        similarity_pseudocolumn: str | None,
    ) -> Callable[[dict[str, Any]], dict[str, Any]]:
        schema_cache_key = (self._hash_dict(columns_dict), similarity_pseudocolumn)
        if schema_cache_key not in self.row_postprocessors:
            columns: dict[str, TableColumnTypeDescriptor] = {
                col_name: TableColumnTypeDescriptor.coerce(col_dict)
                for col_name, col_dict in columns_dict.items()
            }
            self.row_postprocessors[schema_cache_key] = create_row_tpostprocessor(
                columns=columns,
                options=self.options,
                similarity_pseudocolumn=similarity_pseudocolumn,
            )
        return self.row_postprocessors[schema_cache_key]

    def preprocess_payload(
        self, payload: dict[str, Any] | None
    ) -> dict[str, Any] | None:
        return preprocess_table_payload(payload, options=self.options)

    def postprocess_key(
        self, primary_key_list: list[Any], *, primary_key_schema_dict: dict[str, Any]
    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
        """
        The primary key schema is passed as its raw parsed-JSON dict;
        coercion into column descriptors happens, with caching, internally.
        """
        return self._get_key_postprocessor(
            primary_key_schema_dict=primary_key_schema_dict
        )(primary_key_list)

    def postprocess_keys(
        self,
        primary_key_lists: list[list[Any]],
        *,
        primary_key_schema_dict: dict[str, Any],
    ) -> list[tuple[tuple[Any, ...], dict[str, Any]]]:
        """
        The primary key schema is passed as its raw parsed-JSON dict;
        coercion into column descriptors happens, with caching, internally.
        """
        if primary_key_lists:
            _k_postprocessor = self._get_key_postprocessor(
                primary_key_schema_dict=primary_key_schema_dict
            )
            return [
                _k_postprocessor(primary_key_list)
                for primary_key_list in primary_key_lists
            ]
        else:
            return []

    def postprocess_row(
        self,
        raw_dict: dict[str, Any],
        *,
        columns_dict: dict[str, Any],
        similarity_pseudocolumn: str | None,
    ) -> ROW:
        """
        The columns schema is passed as its raw parsed-JSON dict;
        coercion into column descriptors happens, with caching, internally.
        """
        return self._get_row_postprocessor(
            columns_dict=columns_dict, similarity_pseudocolumn=similarity_pseudocolumn
        )(raw_dict)  # type: ignore[return-value]

    def postprocess_rows(
        self,
        raw_dicts: list[dict[str, Any]],
        *,
        columns_dict: dict[str, Any],
        similarity_pseudocolumn: str | None,
    ) -> list[ROW]:
        """
        The columns schema is passed as its raw parsed-JSON dict;
        coercion into column descriptors happens, with caching, internally.
        """
        if raw_dicts:
            _r_postprocessor = self._get_row_postprocessor(
                columns_dict=columns_dict,
                similarity_pseudocolumn=similarity_pseudocolumn,
            )
            return [cast(ROW, _r_postprocessor(raw_dict)) for raw_dict in raw_dicts]
        else:
            return []

Functions

def create_key_ktpostprocessor(primary_key_schema: dict[str, TableColumnTypeDescriptor], options: FullSerdesOptions) ‑> Callable[[list[Any]], tuple[tuple[Any, ...], dict[str, Any]]]
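
A minimal sketch of building and applying a key postprocessor, assuming a FullSerdesOptions instance named `serdes_options`; the schema dicts mirror the API's {"type": ...} shape:

from astrapy.data.info.table_descriptor.table_columns import (
    TableColumnTypeDescriptor,
)
from astrapy.data.utils.table_converters import create_key_ktpostprocessor

pk_schema = {
    "id": TableColumnTypeDescriptor.coerce({"type": "text"}),
    "since": TableColumnTypeDescriptor.coerce({"type": "timestamp"}),
}
ktpostprocessor = create_key_ktpostprocessor(pk_schema, options=serdes_options)
# primary keys arrive as positional lists and come back as a (tuple, dict) pair;
# the timestamp becomes a DataAPITimestamp or a datetime.datetime, depending
# on the custom_datatypes_in_reading setting:
k_tuple, k_dict = ktpostprocessor(["row-1", "2024-01-02T03:04:05.678Z"])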
def create_row_tpostprocessor(columns: dict[str, TableColumnTypeDescriptor], options: FullSerdesOptions, similarity_pseudocolumn: str | None) ‑> Callable[[dict[str, Any]], dict[str, Any]]
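
A sketch along the same lines for whole rows (again with an assumed `serdes_options`), showing the filler values for absent columns and the decoding of the non-number float strings:

from astrapy.data.info.table_descriptor.table_columns import (
    TableColumnTypeDescriptor,
)
from astrapy.data.utils.table_converters import create_row_tpostprocessor

columns = {
    "id": TableColumnTypeDescriptor.coerce({"type": "text"}),
    "score": TableColumnTypeDescriptor.coerce({"type": "float"}),
}
row_tpostprocessor = create_row_tpostprocessor(
    columns, options=serdes_options, similarity_pseudocolumn=None
)
row_tpostprocessor({"id": "r1", "score": "NaN"})  # {"id": "r1", "score": nan}
row_tpostprocessor({"id": "r1"})                  # {"id": "r1", "score": None}
row_tpostprocessor({"id": "r1", "x": 0})          # ValueError: unexpected fields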
def preprocess_table_payload(payload: dict[str, Any] | None, options: FullSerdesOptions) ‑> dict[str, typing.Any] | None

Normalize a payload for API calls. This includes e.g. converting vectors, dates, byte blobs and other non-JSON-native values into their wire representation.

Args

payload : dict[str, Any]
    A dict expressing a payload for an API call.
options : FullSerdesOptions
    The serialization options driving the conversions.

Returns

dict[str, Any]
    A payload dict, pre-processed, ready for HTTP requests.
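
For instance, in this hedged sketch (`serdes_options` being an assumed FullSerdesOptions instance):

import datetime
import math

preprocess_table_payload(
    {
        "insertOne": {
            "document": {
                "when": datetime.datetime(2024, 1, 2, tzinfo=datetime.timezone.utc),
                "score": math.inf,
                "blob": b"\x01\x02",
            }
        }
    },
    options=serdes_options,
)
# -> {"insertOne": {"document": {
#        "when": "2024-01-02T00:00:00.000000+00:00",
#        "score": "Infinity",
#        "blob": {"$binary": "AQI="},
#    }}}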
def preprocess_table_payload_value(path: list[str], value: Any, options: FullSerdesOptions) ‑> Any

Walk a payload for Tables and apply the conversions required to make it into an object ready for json.dumps.
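
A few single-value conversions, sketched with an assumed `serdes_options` (the `path` argument just tracks the position inside the payload during recursion):

from astrapy.data.utils.table_converters import preprocess_table_payload_value

preprocess_table_payload_value([], float("nan"), options=serdes_options)  # "NaN"
preprocess_table_payload_value([], b"\xab", options=serdes_options)       # {"$binary": "qw=="}
preprocess_table_payload_value([], {3, 1, 2}, options=serdes_options)     # a plain list of the items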
