# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from collections.abc import Sequence
from logging import getLogger
from typing import Any

import numpy as np
import onnx
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class DynamoOnnxHelper:
    """
    Helper class for processing ONNX models exported by Torch Dynamo.
    """

    def __init__(self, model: onnx.ModelProto):
        self.model = OnnxModel(model)

    def update_edges(self, edge_mapping: dict) -> None:
        """
        Updates the edges in the model according to the given mapping.
        """
        for node in self.model.model.graph.node:
            for i in range(len(node.input)):
                if node.input[i] in edge_mapping:
                    node.input[i] = edge_mapping[node.input[i]]
            for i in range(len(node.output)):
                if node.output[i] in edge_mapping:
                    node.output[i] = edge_mapping[node.output[i]]

        for graph_input in self.model.model.graph.input:
            if graph_input.name in edge_mapping:
                graph_input.name = edge_mapping[graph_input.name]
        for graph_output in self.model.model.graph.output:
            if graph_output.name in edge_mapping:
                graph_output.name = edge_mapping[graph_output.name]

    def unroll_function(self, func_name: str) -> None:
        """
        Unrolls the function with the given name in the model.
        """
        logger.debug(f"Unrolling function {func_name}...")
        nodes_to_remove = []
        nodes_to_add = []
        edges_to_remove = []
        edges_to_add = []
        for node in self.model.model.graph.node:
            if node.op_type == func_name:
                nodes_to_remove.append(node)
                edges_to_remove.extend(list(node.input) + list(node.output))

        func_to_remove = None
        for f in self.model.model.functions:
            if f.name == func_name:
                nodes_to_add.extend(list(f.node))
                edges_to_add.extend(list(f.input) + list(f.output))
                func_to_remove = f

        assert len(edges_to_remove) == len(edges_to_add)

        for node in nodes_to_remove:
            self.model.model.graph.node.remove(node)
        for node in nodes_to_add:
            self.model.model.graph.node.append(node)
        if func_to_remove is not None:
            self.model.model.functions.remove(func_to_remove)

        edge_mapping = {}
        for i in range(len(edges_to_remove)):
            k = edges_to_remove[i]
            v = edges_to_add[i]
            if k != v:
                edge_mapping[k] = v

        return self.update_edges(edge_mapping)

    def remove_function(self, func_name: str, input_id: int, output_id: int) -> None:
        """
        Removes the function in the model.
        """
        edge_mapping = {}
        nodes_to_remove = []
        for node in self.model.model.graph.node:
            if node.op_type.find(func_name) != -1:
                edge_mapping[node.input[input_id]] = node.output[output_id]
                nodes_to_remove.append(node)
        for node in nodes_to_remove:
            self.model.model.graph.node.remove(node)

        self.update_edges(edge_mapping)

    def remove_dropout_layer(self) -> None:
        """
        Removes the dropout layer in the model.
        """
        logger.debug("Removing dropout layer...")
        self.remove_function("Dropout", 0, 0)

    def remove_lm_head_layer(self) -> None:
        """
        Removes the LM head layer in the model.
        """
        logger.debug("Removing LM head layer...")
        # bugbug: need to copy the right vi over
        self.remove_function("Linear_lm_head", 2, 0)

    def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = True):
        if raw:
            np_type = helper.tensor_dtype_to_np_dtype(data_type)
            if not isinstance(vals, np.ndarray):
                bytes = np.array(vals, dtype=np_type).tobytes()
            else:
                bytes = vals.astype(np_type).tobytes()
            tensor = helper.make_tensor(
                name=name,
                data_type=data_type,
                dims=dims,
                vals=bytes,
                raw=True,
            )
        else:
            tensor = helper.make_tensor(
                name=name,
                data_type=data_type,
                dims=dims,
                vals=vals,
                raw=False,
            )

        self.model.add_initializer(tensor)
        return tensor

    def convert_constants_to_initializers(self, min_size: int = 1) -> None:
        """
        Converts Constant ops of size [min_size] or higher to initializers
        """
        logger.debug(f"Converting constants greater than size {min_size} to initializers")

        constant_nodes = self.model.get_nodes_by_op_type("Constant")
        nodes_to_remove = []

        for node in constant_nodes:
            # Get info from Constant op
            np_data = self.model.get_constant_value(node.output[0])

            # Skip if there are less than [min_size] elements
            if np_data is None or np_data.size < min_size:
                continue

            # Add new initializer with same name as Constant op's output
            for att in node.attribute:
                if att.name == "value":
                    self.add_initializer(
                        name=node.output[0],
                        data_type=att.t.data_type,
                        dims=list(np_data.shape),
                        vals=np_data,
                    )
                    break

            nodes_to_remove.append(node)

        # Remove Constant ops from graph
        self.model.remove_nodes(nodes_to_remove)

    def clear_metadata(self) -> None:
        """
        Clear metadata fields in all nodes
        """
        for graph in self.model.graphs():
            graph.ClearField("metadata_props")
        for node in self.model.nodes():
            node.ClearField("metadata_props")

    @staticmethod
    def fold_transpose_initializers(model) -> None:
        """
        Constant fold Transpose initializers without changing the initializer names
        """
        from onnxscript import ir

        for name, initializer in model.graph.initializers.items():
            user_nodes = initializer.consumers()
            if len(user_nodes) == 1 and user_nodes[0].op_type == "Transpose":
                transpose_node = user_nodes[0]
                perm = transpose_node.attributes.get("perm")
                if perm is None:
                    transposed_tensor = ir.tensor(initializer.const_value.numpy().transpose())
                else:
                    transposed_tensor = ir.tensor(initializer.const_value.numpy().transpose(perm.as_ints()))
                new_initializer = ir.Value(
                    name=initializer.name,
                    shape=transposed_tensor.shape,
                    type=ir.TensorType(transposed_tensor.dtype),
                    const_value=transposed_tensor,
                )
                ir.convenience.replace_all_uses_with(transpose_node.outputs[0], new_initializer)
                model.graph.initializers[name] = new_initializer
                transpose_node.graph.remove(transpose_node, safe=True)
