import array
from uuid import UUID
from typing import Dict, Optional, Tuple, Union, cast
from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.api.types import Embedding
import chromadb.proto.chroma_pb2 as proto
from chromadb.types import (
    Collection,
    LogRecord,
    Metadata,
    Operation,
    ScalarEncoding,
    Segment,
    SegmentScope,
    SeqId,
    OperationRecord,
    UpdateMetadata,
    Vector,
    VectorEmbeddingRecord,
    VectorQueryResult,
)


# TODO: Add unit tests for this file, including handling of optional (unset) fields.
def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> proto.Vector:
    """Serialize a vector into a proto ``Vector`` message.

    The values are packed with ``array.array`` (typecode "f" for FLOAT32,
    "i" for INT32) and stored as raw bytes together with the dimension and
    the proto encoding.

    Args:
        vector: The sequence of values to serialize.
        encoding: The scalar encoding used to pack the values.

    Returns:
        A ``proto.Vector`` whose ``dimension`` is ``len(vector)``.

    Raises:
        ValueError: If ``encoding`` is neither FLOAT32 nor INT32.
    """
    if encoding == ScalarEncoding.FLOAT32:
        as_bytes = array.array("f", vector).tobytes()
        proto_encoding = proto.ScalarEncoding.FLOAT32
    elif encoding == ScalarEncoding.INT32:
        as_bytes = array.array("i", vector).tobytes()
        proto_encoding = proto.ScalarEncoding.INT32
    else:
        # Implicit string concatenation keeps the message on one logical line.
        # A backslash continuation inside the literal would embed the next
        # line's indentation into the error text.
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of "
            f"{ScalarEncoding.FLOAT32} or {ScalarEncoding.INT32}"
        )

    # NOTE(review): array.array.tobytes() emits values in native machine byte
    # order. Confirm that the consumer of these bytes expects the same order.
    return proto.Vector(dimension=len(vector), vector=as_bytes, encoding=proto_encoding)


def from_proto_vector(vector: proto.Vector) -> Tuple[Embedding, ScalarEncoding]:
    """Deserialize a proto ``Vector`` message into a list of values.

    Args:
        vector: The proto message. Its ``vector`` bytes are decoded according
            to its ``encoding`` field.

    Returns:
        A ``(values, encoding)`` tuple. ``values`` is a Python list and
        ``encoding`` is the matching internal ``ScalarEncoding``.

    Raises:
        ValueError: If the encoding is neither FLOAT32 nor INT32. Also raised
            by ``array.frombytes`` if the byte length is not a multiple of
            the item size.
    """
    encoding = vector.encoding
    if encoding == proto.ScalarEncoding.FLOAT32:
        typecode = "f"
        out_encoding = ScalarEncoding.FLOAT32
    elif encoding == proto.ScalarEncoding.INT32:
        typecode = "i"
        out_encoding = ScalarEncoding.INT32
    else:
        # Implicit concatenation avoids embedding source indentation in the
        # message, which the old backslash continuation did.
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of "
            f"{proto.ScalarEncoding.FLOAT32} or {proto.ScalarEncoding.INT32}"
        )

    # array.array interprets the raw bytes in native machine byte order.
    as_array = array.array(typecode)
    as_array.frombytes(vector.vector)
    return (as_array.tolist(), out_encoding)


def from_proto_operation(operation: proto.Operation) -> Operation:
    """Translate a proto ``Operation`` enum value into the internal ``Operation``.

    Raises:
        RuntimeError: If ``operation`` is not ADD, UPDATE, UPSERT or DELETE.
    """
    known_operations = (
        (proto.Operation.ADD, Operation.ADD),
        (proto.Operation.UPDATE, Operation.UPDATE),
        (proto.Operation.UPSERT, Operation.UPSERT),
        (proto.Operation.DELETE, Operation.DELETE),
    )
    for proto_op, internal_op in known_operations:
        if operation == proto_op:
            return internal_op
    # TODO: raise a more descriptive error.
    raise RuntimeError(f"Unknown operation {operation}")


def from_proto_metadata(metadata: proto.UpdateMetadata) -> Optional[Metadata]:
    """Convert proto metadata into a plain ``Metadata`` dict.

    Returns ``None`` when the proto map is empty. Raises ``ValueError`` if any
    key has no value set, because ``None`` values are not allowed here.
    """
    converted = _from_proto_metadata_handle_none(metadata, is_update=False)
    return cast(Optional[Metadata], converted)


def from_proto_update_metadata(
    metadata: proto.UpdateMetadata,
) -> Optional[UpdateMetadata]:
    """Convert proto metadata into an ``UpdateMetadata`` dict.

    Returns ``None`` when the proto map is empty. Keys with no value set are
    mapped to ``None``.
    """
    converted = _from_proto_metadata_handle_none(metadata, is_update=True)
    return cast(Optional[UpdateMetadata], converted)


def _from_proto_metadata_handle_none(
    metadata: proto.UpdateMetadata, is_update: bool
) -> Optional[Union[UpdateMetadata, Metadata]]:
    """Shared conversion of a proto metadata map into a Python dict.

    The value fields are probed in a fixed order (bool, string, int, float)
    and the first one that is set wins. A value with none of these set
    becomes ``None`` when ``is_update`` is true. Otherwise it raises
    ``ValueError``. An empty map yields ``None`` instead of ``{}``.
    """
    entries = metadata.metadata
    if not entries:
        return None

    probe_order = ("bool_value", "string_value", "int_value", "float_value")
    converted: Dict[str, Union[str, int, float, bool, None]] = {}
    for name, proto_value in entries.items():
        for field_name in probe_order:
            if proto_value.HasField(field_name):
                converted[name] = getattr(proto_value, field_name)
                break
        else:
            # No value field set: only update metadata may carry None.
            if not is_update:
                raise ValueError(f"Metadata key {name} value cannot be None")
            converted[name] = None
    return converted


def to_proto_update_metadata(metadata: UpdateMetadata) -> proto.UpdateMetadata:
    """Convert a metadata dict into a proto ``UpdateMetadata`` message.

    Each value is converted with ``to_proto_metadata_update_value``.
    """
    proto_entries = {}
    for key, value in metadata.items():
        proto_entries[key] = to_proto_metadata_update_value(value)
    return proto.UpdateMetadata(metadata=proto_entries)


def from_proto_submit(
    operation_record: proto.OperationRecord, seq_id: SeqId
) -> LogRecord:
    """Build a ``LogRecord`` from a proto ``OperationRecord``.

    ``seq_id`` becomes the record's ``log_offset``. The vector, metadata and
    operation are each converted with the corresponding ``from_proto_*``
    helper.
    """
    values, value_encoding = from_proto_vector(operation_record.vector)
    update_metadata = from_proto_update_metadata(operation_record.metadata)
    operation = from_proto_operation(operation_record.operation)
    inner_record = OperationRecord(
        id=operation_record.id,
        embedding=values,
        encoding=value_encoding,
        metadata=update_metadata,
        operation=operation,
    )
    return LogRecord(log_offset=seq_id, record=inner_record)


def from_proto_segment(segment: proto.Segment) -> Segment:
    """Convert a proto ``Segment`` message into the internal ``Segment``.

    ``collection`` and ``metadata`` are ``None`` when the corresponding
    proto field is not set.
    """
    segment_id = UUID(hex=segment.id)
    segment_type = segment.type
    scope = from_proto_segment_scope(segment.scope)

    collection_id = None
    if segment.HasField("collection"):
        collection_id = UUID(hex=segment.collection)

    segment_metadata = None
    if segment.HasField("metadata"):
        segment_metadata = from_proto_metadata(segment.metadata)

    return Segment(
        id=segment_id,
        type=segment_type,
        scope=scope,
        collection=collection_id,
        metadata=segment_metadata,
    )


def to_proto_segment(segment: Segment) -> proto.Segment:
    """Convert an internal ``Segment`` mapping into a proto ``Segment`` message.

    UUIDs are sent as hex strings. ``collection`` and ``metadata`` are passed
    as ``None`` when absent.
    """
    segment_id_hex = segment["id"].hex
    segment_type = segment["type"]
    scope = to_proto_segment_scope(segment["scope"])

    collection = segment["collection"]
    collection_hex = collection.hex if collection is not None else None

    raw_metadata = segment["metadata"]
    proto_metadata = (
        to_proto_update_metadata(raw_metadata) if raw_metadata is not None else None
    )

    return proto.Segment(
        id=segment_id_hex,
        type=segment_type,
        scope=scope,
        collection=collection_hex,
        metadata=proto_metadata,
    )


def from_proto_segment_scope(segment_scope: proto.SegmentScope) -> SegmentScope:
    """Translate a proto ``SegmentScope`` value into the internal ``SegmentScope``.

    Raises:
        RuntimeError: If the scope is not VECTOR, METADATA or RECORD.
    """
    scope_pairs = (
        (proto.SegmentScope.VECTOR, SegmentScope.VECTOR),
        (proto.SegmentScope.METADATA, SegmentScope.METADATA),
        (proto.SegmentScope.RECORD, SegmentScope.RECORD),
    )
    for proto_scope, internal_scope in scope_pairs:
        if segment_scope == proto_scope:
            return internal_scope
    raise RuntimeError(f"Unknown segment scope {segment_scope}")


def to_proto_segment_scope(segment_scope: SegmentScope) -> proto.SegmentScope:
    """Translate an internal ``SegmentScope`` into the proto ``SegmentScope`` value.

    Raises:
        RuntimeError: If the scope is not VECTOR, METADATA or RECORD.
    """
    scope_pairs = (
        (SegmentScope.VECTOR, proto.SegmentScope.VECTOR),
        (SegmentScope.METADATA, proto.SegmentScope.METADATA),
        (SegmentScope.RECORD, proto.SegmentScope.RECORD),
    )
    for internal_scope, proto_scope in scope_pairs:
        if segment_scope == internal_scope:
            return proto_scope
    raise RuntimeError(f"Unknown segment scope {segment_scope}")


def to_proto_metadata_update_value(
    value: Union[str, int, float, bool, None]
) -> proto.UpdateMetadataValue:
    """Convert a single metadata value into a proto ``UpdateMetadataValue``.

    ``None`` produces an empty value with no field set. This is how a
    metadata key deletion is expressed in an update.

    Raises:
        ValueError: If ``value`` is not a str, int, float, bool or None.
    """
    # Be careful with the order here. bool is a subclass of int in Python, so
    # isinstance(value, int) is also true for bools. Check bool first so that
    # booleans are not encoded as ints.
    if isinstance(value, bool):
        return proto.UpdateMetadataValue(bool_value=value)
    elif isinstance(value, str):
        return proto.UpdateMetadataValue(string_value=value)
    elif isinstance(value, int):
        return proto.UpdateMetadataValue(int_value=value)
    elif isinstance(value, float):
        return proto.UpdateMetadataValue(float_value=value)
    # None is used to delete the metadata key.
    elif value is None:
        return proto.UpdateMetadataValue()
    else:
        # List every accepted type (bool was missing before). Use implicit
        # concatenation so no source indentation leaks into the message.
        raise ValueError(
            f"Unknown metadata value type {type(value)}, expected one of str, "
            "int, float, bool, or None"
        )


def from_proto_collection(collection: proto.Collection) -> Collection:
    """Convert a proto ``Collection`` message into the internal ``Collection``.

    The configuration is parsed from its JSON string. ``metadata`` is ``None``
    when the field is unset. ``dimension`` is ``None`` when it is unset or 0.
    """
    collection_id = UUID(hex=collection.id)
    collection_name = collection.name
    configuration = CollectionConfigurationInternal.from_json_str(
        collection.configuration_json_str
    )

    collection_metadata = None
    if collection.HasField("metadata"):
        collection_metadata = from_proto_metadata(collection.metadata)

    # Unset and zero dimensions are both reported as None.
    dimension = None
    if collection.HasField("dimension") and collection.dimension:
        dimension = collection.dimension

    return Collection(
        id=collection_id,
        name=collection_name,
        configuration=configuration,
        metadata=collection_metadata,
        dimension=dimension,
        database=collection.database,
        tenant=collection.tenant,
        version=collection.version,
    )


def to_proto_collection(collection: Collection) -> proto.Collection:
    """Convert an internal ``Collection`` into a proto ``Collection`` message.

    The id is sent as a hex string and the configuration as a JSON string.
    Metadata is passed as ``None`` when absent.
    """
    collection_id_hex = collection["id"].hex
    collection_name = collection["name"]
    configuration_json = collection.get_configuration().to_json_str()

    raw_metadata = collection["metadata"]
    proto_metadata = (
        to_proto_update_metadata(raw_metadata) if raw_metadata is not None else None
    )

    return proto.Collection(
        id=collection_id_hex,
        name=collection_name,
        configuration_json_str=configuration_json,
        metadata=proto_metadata,
        dimension=collection["dimension"],
        tenant=collection["tenant"],
        database=collection["database"],
        version=collection["version"],
    )


def to_proto_operation(operation: Operation) -> proto.Operation:
    """Translate an internal ``Operation`` into the proto ``Operation`` value.

    Raises:
        ValueError: If ``operation`` is not ADD, UPDATE, UPSERT or DELETE.
    """
    if operation == Operation.ADD:
        return proto.Operation.ADD
    elif operation == Operation.UPDATE:
        return proto.Operation.UPDATE
    elif operation == Operation.UPSERT:
        return proto.Operation.UPSERT
    elif operation == Operation.DELETE:
        return proto.Operation.DELETE
    else:
        # The message previously listed UPDATE twice and omitted UPSERT, and
        # it embedded source indentation through a backslash continuation.
        raise ValueError(
            f"Unknown operation {operation}, expected one of {Operation.ADD}, "
            f"{Operation.UPDATE}, {Operation.UPSERT}, or {Operation.DELETE}"
        )


def to_proto_submit(
    submit_record: OperationRecord,
) -> proto.OperationRecord:
    """Convert an internal ``OperationRecord`` into a proto ``OperationRecord``.

    The vector is included only when both the embedding and its encoding are
    present. Metadata is included only when it is not ``None``.
    """
    proto_vector = None
    embedding = submit_record["embedding"]
    if embedding is not None:
        # The encoding is read only once an embedding is known to be present.
        encoding = submit_record["encoding"]
        if encoding is not None:
            proto_vector = to_proto_vector(embedding, encoding)

    raw_metadata = submit_record["metadata"]
    proto_metadata = (
        to_proto_update_metadata(raw_metadata) if raw_metadata is not None else None
    )

    return proto.OperationRecord(
        id=submit_record["id"],
        vector=proto_vector,
        metadata=proto_metadata,
        operation=to_proto_operation(submit_record["operation"]),
    )


def from_proto_vector_embedding_record(
    embedding_record: proto.VectorEmbeddingRecord,
) -> VectorEmbeddingRecord:
    """Convert a proto ``VectorEmbeddingRecord`` into the internal record.

    The decoded encoding is discarded; only the values are kept.
    """
    record_id = embedding_record.id
    values, _encoding = from_proto_vector(embedding_record.vector)
    return VectorEmbeddingRecord(id=record_id, embedding=values)


def to_proto_vector_embedding_record(
    embedding_record: VectorEmbeddingRecord,
    encoding: ScalarEncoding,
) -> proto.VectorEmbeddingRecord:
    """Convert an internal ``VectorEmbeddingRecord`` into its proto message.

    The embedding is serialized with ``to_proto_vector`` using ``encoding``.
    """
    record_id = embedding_record["id"]
    proto_vector = to_proto_vector(embedding_record["embedding"], encoding)
    return proto.VectorEmbeddingRecord(id=record_id, vector=proto_vector)


def from_proto_vector_query_result(
    vector_query_result: proto.VectorQueryResult,
) -> VectorQueryResult:
    """Convert a proto ``VectorQueryResult`` into the internal query result.

    The id and distance are copied as-is. The vector is decoded and its
    encoding is discarded.
    """
    result_id = vector_query_result.id
    result_distance = vector_query_result.distance
    values, _encoding = from_proto_vector(vector_query_result.vector)
    return VectorQueryResult(
        id=result_id,
        distance=result_distance,
        embedding=values,
    )
