Module astrapy.data.utils.collection_converters

Expand source code
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import datetime
from typing import Any, cast

from astrapy.constants import DefaultDocumentType
from astrapy.data.utils.extended_json_converters import (
    convert_ejson_binary_object_to_bytes,
    convert_ejson_date_object_to_apitimestamp,
    convert_ejson_date_object_to_datetime,
    convert_ejson_objectid_object_to_objectid,
    convert_ejson_uuid_object_to_uuid,
    convert_to_ejson_apitimestamp_object,
    convert_to_ejson_bytes,
    convert_to_ejson_date_object,
    convert_to_ejson_objectid_object,
    convert_to_ejson_uuid_object,
)
from astrapy.data.utils.vector_coercion import (
    convert_vector_to_floats,
    ensure_unrolled_if_iterable,
    is_list_of_floats,
)
from astrapy.data_types import DataAPITimestamp, DataAPIVector
from astrapy.ids import UUID, ObjectId
from astrapy.settings.error_messages import CANNOT_ENCODE_NAIVE_DATETIME_ERROR_MESSAGE
from astrapy.utils.api_options import FullSerdesOptions


def preprocess_collection_payload_value(
    path: list[str], value: Any, options: FullSerdesOptions
) -> Any:
    """
    Recursively turn one payload value into the form sent to the API.

    Dicts and lists are walked item by item. Datetimes, dates, bytes, UUIDs,
    ObjectIds and DataAPITimestamps are handed to their respective EJSON
    converters. A value sitting under a "$vector" key (unless that key is
    directly within "projection") gets the vector-specific treatment instead.
    Any other value is returned as it is (after the optional unrolling).

    Args:
        path: the keys leading from the payload root down to `value`
            (items of a list contribute an empty string). It is used
            only to recognize the "$vector" positions.
        value: the (sub)value to process.
        options: the serdes options, steering the unrolling of iterables,
            the acceptance of naive datetimes and the encoding of vectors.

    Returns:
        the processed value.

    Raises:
        ValueError: when a naive datetime is met and
            `options.accept_naive_datetimes` is false.
    """

    in_vector_slot = path[-1:] == ["$vector"] and path[-2:] != [
        "projection",
        "$vector",
    ]
    if in_vector_slot:
        vector = value
        # Broad coercion of list-likes is opt-in, and is skipped for values
        # that are already a list of floats or a DataAPIVector.
        if options.unroll_iterables_to_lists:
            if not is_list_of_floats(vector) and not isinstance(
                vector, DataAPIVector
            ):
                vector = convert_vector_to_floats(vector)
        # TODO: reinstate the following condition once the Data API
        # correctly excludes $binary from indexing for collections:
        # binary_allowed = path[0] in {"insertOne", "insertMany"}
        binary_allowed = False
        # binary encoding needs both the permission above and the serdes option
        as_binary = binary_allowed and options.binary_encode_vectors
        if isinstance(vector, DataAPIVector):
            if as_binary:
                return convert_to_ejson_bytes(vector.to_bytes())
            # back to a regular list
            return vector.data
        # not a DataAPIVector: handled as a plain list
        if as_binary:
            return convert_to_ejson_bytes(DataAPIVector(vector).to_bytes())
        return vector

    item = value
    if options.unroll_iterables_to_lists:
        item = ensure_unrolled_if_iterable(item)

    if isinstance(item, dict):
        return {
            key: preprocess_collection_payload_value(
                path + [key], sub_value, options=options
            )
            for key, sub_value in item.items()
        }
    if isinstance(item, list):
        return [
            preprocess_collection_payload_value(path + [""], entry, options=options)
            for entry in item
        ]
    if isinstance(item, datetime.datetime):
        is_naive = item.utcoffset() is None
        if is_naive and not options.accept_naive_datetimes:
            raise ValueError(CANNOT_ENCODE_NAIVE_DATETIME_ERROR_MESSAGE)
        return convert_to_ejson_date_object(item)
    # 'datetime' subclasses 'date': this check must follow the datetime one.
    if isinstance(item, datetime.date):
        return convert_to_ejson_date_object(item)
    if isinstance(item, bytes):
        return convert_to_ejson_bytes(item)
    if isinstance(item, UUID):
        return convert_to_ejson_uuid_object(item)
    if isinstance(item, ObjectId):
        return convert_to_ejson_objectid_object(item)
    if isinstance(item, DataAPITimestamp):
        return convert_to_ejson_apitimestamp_object(item)
    return item


def preprocess_collection_payload(
    payload: dict[str, Any] | None, options: FullSerdesOptions
) -> dict[str, Any] | None:
    """
    Prepare a payload for being sent to the API.

    The whole payload is run through `preprocess_collection_payload_value`,
    which e.g. takes care of coercing the values found under "$vector".

    Args:
        payload (dict[str, Any] | None): the payload for an API call, if any.
        options (FullSerdesOptions): the serdes options governing the processing.

    Returns:
        dict[str, Any] | None: the pre-processed payload, ready for HTTP requests.
            A None or empty payload is handed back as it is.
    """

    if not payload:
        # None or empty dict: there is nothing to process
        return payload
    processed = preprocess_collection_payload_value([], payload, options=options)
    return cast(dict[str, Any], processed)


def postprocess_collection_response_value(
    path: list[str], value: Any, options: FullSerdesOptions
) -> Any:
    """
    Recursively restore one value of an API response to its Python form.

    Dicts and lists are walked item by item. Single-key EJSON objects
    ("$date", "$uuid", "$objectId", "$binary") are handed to their respective
    converters, and whatever sits under a "$vector" key is rebuilt as a vector.
    Any other value is returned unchanged.

    Args:
        path: the keys leading from the response root down to `value`
            (items of a list contribute an empty string). It is used
            only to recognize the "$vector" positions.
        value: the (sub)value to process.
        options: the serdes options. `custom_datatypes_in_reading` selects
            between DataAPIVector/DataAPITimestamp and plain list/datetime
            results; `datetime_tzinfo` is passed to the datetime conversion.

    Returns:
        the restored value.

    Raises:
        ValueError: if the value under "$vector" is not a list, a dict or None.
    """

    # for reads, everywhere there's a $vector it can be treated as such and reconverted
    if path[-1:] == ["$vector"]:
        # custom faster handling for the $vector path:
        if isinstance(value, list):
            if options.custom_datatypes_in_reading:
                return DataAPIVector(value)
            else:
                return value
        elif isinstance(value, dict):
            # a dict in this position is handled as the binary-encoded vector
            _bytes = convert_ejson_binary_object_to_bytes(value)
            if options.custom_datatypes_in_reading:
                return DataAPIVector.from_bytes(_bytes)
            else:
                return DataAPIVector.from_bytes(_bytes).data
        elif value is None:
            # a null under $vector carries no vector to rebuild: hand it back
            # unchanged instead of failing the parsing of the whole response.
            return None
        else:
            raise ValueError(
                f"Response parsing failed: unexpected data type found under $vector: {type(value)}"
            )

    if isinstance(value, dict):
        # EJSON wrappers are dicts with exactly one, well-known, key
        if len(value) == 1:
            (sole_key,) = value
            if sole_key == "$date":
                # this is `{"$date": 123456}`.
                # Restore to the appropriate APIOptions-required object
                if options.custom_datatypes_in_reading:
                    return convert_ejson_date_object_to_apitimestamp(value)
                return convert_ejson_date_object_to_datetime(
                    value, tz=options.datetime_tzinfo
                )
            if sole_key == "$uuid":
                # this is `{"$uuid": "abc123..."}`, restore to UUID
                return convert_ejson_uuid_object_to_uuid(value)
            if sole_key == "$objectId":
                # this is `{"$objectId": "123abc..."}`, restore to ObjectId
                return convert_ejson_objectid_object_to_objectid(value)
            if sole_key == "$binary":
                # this is `{"$binary": "xyz=="}`, restore to `bytes`
                return convert_ejson_binary_object_to_bytes(value)
        # a regular dict: process each of its values
        return {
            k: postprocess_collection_response_value(path + [k], v, options=options)
            for k, v in value.items()
        }
    if isinstance(value, list):
        return [
            postprocess_collection_response_value(
                path + [""], list_item, options=options
            )
            for list_item in value
        ]
    return value


def postprocess_collection_response(
    response: DefaultDocumentType, options: FullSerdesOptions
) -> DefaultDocumentType:
    """
    Restore a dictionary, as just received from the API, to its Python form.

    The whole response is run through `postprocess_collection_response_value`:
    this is where e.g. `{"$date": 123}` is converted back into
    a datetime object.

    Args:
        response: the dictionary returned by the API.
        options: the serdes options governing the conversions.

    Returns:
        the processed response.
    """
    restored = postprocess_collection_response_value([], response, options=options)
    return cast(DefaultDocumentType, restored)

Functions

def postprocess_collection_response(response: DefaultDocumentType, options: FullSerdesOptions) ‑> dict[str, typing.Any]

Process a dictionary just returned from the API. This is the place where e.g. {"$date": 123} is converted back into a datetime object.

Expand source code
def postprocess_collection_response(
    response: DefaultDocumentType, options: FullSerdesOptions
) -> DefaultDocumentType:
    """
    Restore a dictionary, as just received from the API, to its Python form.

    The whole response is run through `postprocess_collection_response_value`:
    this is where e.g. `{"$date": 123}` is converted back into
    a datetime object.

    Args:
        response: the dictionary returned by the API.
        options: the serdes options governing the conversions.

    Returns:
        the processed response.
    """
    restored = postprocess_collection_response_value([], response, options=options)
    return cast(DefaultDocumentType, restored)
def postprocess_collection_response_value(path: list[str], value: Any, options: FullSerdesOptions) ‑> Any

The path (the keys leading to the value being processed) helps determine special treatments, such as the handling of "$vector".

Expand source code
def postprocess_collection_response_value(
    path: list[str], value: Any, options: FullSerdesOptions
) -> Any:
    """
    Recursively restore one value of an API response to its Python form.

    Dicts and lists are walked item by item. Single-key EJSON objects
    ("$date", "$uuid", "$objectId", "$binary") are handed to their respective
    converters, and whatever sits under a "$vector" key is rebuilt as a vector.
    Any other value is returned unchanged.

    Args:
        path: the keys leading from the response root down to `value`
            (items of a list contribute an empty string). It is used
            only to recognize the "$vector" positions.
        value: the (sub)value to process.
        options: the serdes options. `custom_datatypes_in_reading` selects
            between DataAPIVector/DataAPITimestamp and plain list/datetime
            results; `datetime_tzinfo` is passed to the datetime conversion.

    Returns:
        the restored value.

    Raises:
        ValueError: if the value under "$vector" is not a list, a dict or None.
    """

    # for reads, everywhere there's a $vector it can be treated as such and reconverted
    if path[-1:] == ["$vector"]:
        # custom faster handling for the $vector path:
        if isinstance(value, list):
            if options.custom_datatypes_in_reading:
                return DataAPIVector(value)
            else:
                return value
        elif isinstance(value, dict):
            # a dict in this position is handled as the binary-encoded vector
            _bytes = convert_ejson_binary_object_to_bytes(value)
            if options.custom_datatypes_in_reading:
                return DataAPIVector.from_bytes(_bytes)
            else:
                return DataAPIVector.from_bytes(_bytes).data
        elif value is None:
            # a null under $vector carries no vector to rebuild: hand it back
            # unchanged instead of failing the parsing of the whole response.
            return None
        else:
            raise ValueError(
                f"Response parsing failed: unexpected data type found under $vector: {type(value)}"
            )

    if isinstance(value, dict):
        # EJSON wrappers are dicts with exactly one, well-known, key
        if len(value) == 1:
            (sole_key,) = value
            if sole_key == "$date":
                # this is `{"$date": 123456}`.
                # Restore to the appropriate APIOptions-required object
                if options.custom_datatypes_in_reading:
                    return convert_ejson_date_object_to_apitimestamp(value)
                return convert_ejson_date_object_to_datetime(
                    value, tz=options.datetime_tzinfo
                )
            if sole_key == "$uuid":
                # this is `{"$uuid": "abc123..."}`, restore to UUID
                return convert_ejson_uuid_object_to_uuid(value)
            if sole_key == "$objectId":
                # this is `{"$objectId": "123abc..."}`, restore to ObjectId
                return convert_ejson_objectid_object_to_objectid(value)
            if sole_key == "$binary":
                # this is `{"$binary": "xyz=="}`, restore to `bytes`
                return convert_ejson_binary_object_to_bytes(value)
        # a regular dict: process each of its values
        return {
            k: postprocess_collection_response_value(path + [k], v, options=options)
            for k, v in value.items()
        }
    if isinstance(value, list):
        return [
            postprocess_collection_response_value(
                path + [""], list_item, options=options
            )
            for list_item in value
        ]
    return value
def preprocess_collection_payload(payload: dict[str, Any] | None, options: FullSerdesOptions) ‑> dict[str, typing.Any] | None

Normalize a payload for API calls. This includes e.g. ensuring values for "$vector" key are properly coerced.

Args

payload : dict[str, Any] | None
A dict expressing a payload for an API call

Returns

dict[str, Any]
a payload dict, pre-processed, ready for HTTP requests.
Expand source code
def preprocess_collection_payload(
    payload: dict[str, Any] | None, options: FullSerdesOptions
) -> dict[str, Any] | None:
    """
    Prepare a payload for being sent to the API.

    The whole payload is run through `preprocess_collection_payload_value`,
    which e.g. takes care of coercing the values found under "$vector".

    Args:
        payload (dict[str, Any] | None): the payload for an API call, if any.
        options (FullSerdesOptions): the serdes options governing the processing.

    Returns:
        dict[str, Any] | None: the pre-processed payload, ready for HTTP requests.
            A None or empty payload is handed back as it is.
    """

    if not payload:
        # None or empty dict: there is nothing to process
        return payload
    processed = preprocess_collection_payload_value([], payload, options=options)
    return cast(dict[str, Any], processed)
def preprocess_collection_payload_value(path: list[str], value: Any, options: FullSerdesOptions) ‑> Any

The path (the keys leading to the value being processed) helps determine special treatments, such as the handling of "$vector".

Expand source code
def preprocess_collection_payload_value(
    path: list[str], value: Any, options: FullSerdesOptions
) -> Any:
    """
    Recursively turn one payload value into the form sent to the API.

    Dicts and lists are walked item by item. Datetimes, dates, bytes, UUIDs,
    ObjectIds and DataAPITimestamps are handed to their respective EJSON
    converters. A value sitting under a "$vector" key (unless that key is
    directly within "projection") gets the vector-specific treatment instead.
    Any other value is returned as it is (after the optional unrolling).

    Args:
        path: the keys leading from the payload root down to `value`
            (items of a list contribute an empty string). It is used
            only to recognize the "$vector" positions.
        value: the (sub)value to process.
        options: the serdes options, steering the unrolling of iterables,
            the acceptance of naive datetimes and the encoding of vectors.

    Returns:
        the processed value.

    Raises:
        ValueError: when a naive datetime is met and
            `options.accept_naive_datetimes` is false.
    """

    in_vector_slot = path[-1:] == ["$vector"] and path[-2:] != [
        "projection",
        "$vector",
    ]
    if in_vector_slot:
        vector = value
        # Broad coercion of list-likes is opt-in, and is skipped for values
        # that are already a list of floats or a DataAPIVector.
        if options.unroll_iterables_to_lists:
            if not is_list_of_floats(vector) and not isinstance(
                vector, DataAPIVector
            ):
                vector = convert_vector_to_floats(vector)
        # TODO: reinstate the following condition once the Data API
        # correctly excludes $binary from indexing for collections:
        # binary_allowed = path[0] in {"insertOne", "insertMany"}
        binary_allowed = False
        # binary encoding needs both the permission above and the serdes option
        as_binary = binary_allowed and options.binary_encode_vectors
        if isinstance(vector, DataAPIVector):
            if as_binary:
                return convert_to_ejson_bytes(vector.to_bytes())
            # back to a regular list
            return vector.data
        # not a DataAPIVector: handled as a plain list
        if as_binary:
            return convert_to_ejson_bytes(DataAPIVector(vector).to_bytes())
        return vector

    item = value
    if options.unroll_iterables_to_lists:
        item = ensure_unrolled_if_iterable(item)

    if isinstance(item, dict):
        return {
            key: preprocess_collection_payload_value(
                path + [key], sub_value, options=options
            )
            for key, sub_value in item.items()
        }
    if isinstance(item, list):
        return [
            preprocess_collection_payload_value(path + [""], entry, options=options)
            for entry in item
        ]
    if isinstance(item, datetime.datetime):
        is_naive = item.utcoffset() is None
        if is_naive and not options.accept_naive_datetimes:
            raise ValueError(CANNOT_ENCODE_NAIVE_DATETIME_ERROR_MESSAGE)
        return convert_to_ejson_date_object(item)
    # 'datetime' subclasses 'date': this check must follow the datetime one.
    if isinstance(item, datetime.date):
        return convert_to_ejson_date_object(item)
    if isinstance(item, bytes):
        return convert_to_ejson_bytes(item)
    if isinstance(item, UUID):
        return convert_to_ejson_uuid_object(item)
    if isinstance(item, ObjectId):
        return convert_to_ejson_objectid_object(item)
    if isinstance(item, DataAPITimestamp):
        return convert_to_ejson_apitimestamp_object(item)
    return item