# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger

import numpy as np
from fusion_base import Fusion
from fusion_options import AttentionMaskFormat
from fusion_utils import FusionUtils, NumpyHelper
from onnx import NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class AttentionMask:
    """
    Fuse Attention subgraph into one Attention node.
    """

    def __init__(self, model: OnnxModel):
        self.model = model
        # A lookup table with mask input as key, and mask index output as value
        self.mask_indice = {}
        # A lookup table with mask input as key, and cast (to int32) output as value
        self.mask_casted = {}
        self.utils = FusionUtils(model)
        self.mask_format = AttentionMaskFormat.MaskIndexEnd
        self.opset_version = model.get_opset_version()

    def set_mask_format(self, mask_format: AttentionMaskFormat):
        self.mask_format = mask_format

    def set_mask_indice(self, mask, mask_index):
        if mask in self.mask_indice:
            assert mask_index == self.mask_indice[mask]
        self.mask_indice[mask] = mask_index

    def get_first_mask(self):
        assert len(self.mask_indice) > 0
        return next(iter(self.mask_indice))

    def process_mask(self, mask_2d: str) -> str | None:
        if self.mask_format == AttentionMaskFormat.NoMask:
            return None

        if mask_2d in self.mask_indice:
            return self.mask_indice[mask_2d]

        # Add cast to convert int64 to int32
        if self.model.find_graph_input(mask_2d):
            casted, input_name = self.utils.cast_graph_input_to_int32(mask_2d)
        else:
            input_name, _cast_node = self.utils.cast_input_to_int32(mask_2d)
            casted = True

        if casted:
            self.mask_casted[mask_2d] = input_name

        # Attention supports int32 attention mask (2D) since 1.4.0
        if self.mask_format == AttentionMaskFormat.AttentionMask:
            self.mask_indice[mask_2d] = input_name
            return input_name

        # Add a mask processing node to convert attention mask to mask index (1D)
        output_name = self.model.create_node_name("mask_index")
        if self.opset_version < 13:
            mask_index_node = helper.make_node(
                "ReduceSum",
                inputs=[input_name],
                outputs=[output_name],
                name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
            )
            mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)])
        else:
            # ReduceSum-13: axes is moved from attribute to input
            axes_name = "ort_const_1_reduce_sum_axes"
            if self.model.get_initializer(axes_name) is None:
                self.model.add_initializer(
                    helper.make_tensor(
                        name=axes_name,
                        data_type=TensorProto.INT64,
                        dims=[1],
                        vals=[1],
                        raw=False,
                    )
                )
            mask_index_node = helper.make_node(
                "ReduceSum",
                inputs=[input_name, axes_name],
                outputs=[output_name],
                name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
            )
            mask_index_node.attribute.extend([helper.make_attribute("keepdims", 0)])

        self.model.add_node(mask_index_node)

        self.mask_indice[mask_2d] = output_name
        return output_name


class FusionAttention(Fusion):
    """
    Fuse Attention subgraph into one Attention node.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask | None = None,
        use_multi_head_attention: bool = False,
        disable_multi_head_attention_bias: bool = False,
        search_op_types: list[str] = ["SkipLayerNormalization", "LayerNormalization"],  # noqa: B006
    ):
        attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention"
        super().__init__(model, attention_op_name, search_op_types)
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention_mask = attention_mask if attention_mask else AttentionMask(model)
        self.use_multi_head_attention = use_multi_head_attention
        self.disable_multi_head_attention_bias = disable_multi_head_attention_bias
        self.mask_filter_value = None

        # Flags to show warning only once
        self.num_heads_warning = True
        self.hidden_size_warning = True

        self.shape_infer = None
        self.shape_infer_done = True

    def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> tuple[int, int]:
        """
        Detect num_heads and hidden_size from Concat node in the following subgraph:

        SkipLayerNormalization or EmbedLayerNormalization
                        /        |
                     MatMul    Shape
                        |        |
                       Add     Gather(indices=0)
                        |        |
                        |      Unsqueeze
                        |        |
                        |     Concat (*, -1, 12, 64)
                        |     /
                       Reshape
                          |
                       Transpose
        """
        if len(concat.input) == 4:
            num_heads = self.model.get_constant_value(concat.input[2])
            head_size = self.model.get_constant_value(concat.input[3])
            if (
                isinstance(num_heads, np.ndarray)
                and num_heads.size == 1
                and isinstance(head_size, np.ndarray)
                and head_size.size == 1
            ):
                return num_heads[0], num_heads[0] * head_size[0]

        return self.num_heads, self.hidden_size

    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> tuple[int, int]:
        """Detect num_heads and hidden_size from a reshape node.

        Args:
            reshape_q (NodeProto): reshape node for Q

        Returns:
            Tuple[int, int]: num_heads and hidden_size
        """
        # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
        q_shape_value = self.model.get_constant_value(reshape_q.input[1])
        if q_shape_value is None:
            concat = self.model.get_parent(reshape_q, 1)
            if concat is not None and concat.op_type == "Concat":
                return self.get_num_heads_and_hidden_size_from_concat(concat)
            logger.debug("%s is not initializer.", reshape_q.input[1])
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        if (
            (not isinstance(q_shape_value, np.ndarray))
            or len(q_shape_value) != 4
            or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0)
        ):
            logger.debug("q_shape_value=%s. Expected value are like [0, 0, num_heads, head_size].", q_shape_value)
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        num_heads = q_shape_value[2]
        head_size = q_shape_value[3]
        hidden_size = num_heads * head_size

        if self.num_heads > 0 and num_heads != self.num_heads:
            if self.num_heads_warning:
                logger.warning(
                    "--num_heads is %d. Detected value is %d. Using detected value.", self.num_heads, num_heads
                )
                self.num_heads_warning = False  # Do not show the warning more than once

        if self.hidden_size > 0 and hidden_size != self.hidden_size:
            if self.hidden_size_warning:
                logger.warning(
                    "--hidden_size is %d. Detected value is %d. Using detected value.", self.hidden_size, hidden_size
                )
                self.hidden_size_warning = False  # Do not show the warning more than once

        return num_heads, hidden_size

    def get_add_qk_str(self, add_qk: NodeProto):
        if not self.shape_infer_done:
            self.shape_infer = self.model.infer_runtime_shape(update=True)
            self.shape_infer_done = True

        if self.shape_infer is None:
            return None

        input_0_shape = self.shape_infer.get_edge_shape(add_qk.input[0])
        input_1_shape = self.shape_infer.get_edge_shape(add_qk.input[1])

        if input_0_shape is None or input_1_shape is None:
            logger.debug("one of the inputs of %s is None", add_qk)
            return None

        if input_0_shape != input_1_shape:
            logger.debug("the shape of two inputs of %s is not same", add_qk)
            return None

        return add_qk.input[1]

    def reshape_add_qk(self, add_qk: str):
        # Convert 4D mask from (B,1,S,T) to (B,N,S,T)
        # B = batch size, N = num heads, S = source sequence length, T = target sequence length
        mask_output_name = add_qk + "_mask"

        # Check if concat node for (B,1,S,T) --> (B,N,S,T) already exists
        concat_node = list(filter(lambda node: node.output[0] == mask_output_name, self.nodes_to_add))
        if len(concat_node) == 1:
            return mask_output_name

        assert len(concat_node) == 0
        concat_node_name = self.model.create_node_name("Concat")
        concat_add_qk_fp32 = helper.make_node(
            "Concat",
            inputs=[add_qk for _ in range(self.num_heads)],
            outputs=[mask_output_name],
            name=concat_node_name,
            axis=1,
        )
        # Add new node to graph
        self.nodes_to_add.append(concat_add_qk_fp32)
        self.node_name_to_graph_name[concat_node_name] = self.this_graph_name

        return mask_output_name

    def concat_kv(self, past_k: str, past_v: str) -> str:
        """Concatenate past_k and past_v inputs to create past_kv input.

        Args:
            past_k (str): name of past K value
            past_v (str): name of past V value

        Returns:
            kv_output_name (str): name of past KV value
        """
        # Unsqueeze K and V nodes from (B,N,P,H) to (1,B,N,P,H)
        # B = batch size, N = num heads, P = past sequence length, H = head size
        unsqueeze_k_name = self.model.create_node_name("Unsqueeze")
        unsqueeze_v_name = self.model.create_node_name("Unsqueeze")
        k_5d_name = (past_k + "_5d").replace(".", "_")
        v_5d_name = (past_v + "_5d").replace(".", "_")

        k_5d = helper.make_node(
            "Unsqueeze",
            inputs=[past_k],
            outputs=[k_5d_name],
            name=unsqueeze_k_name,
            axes=[0],
        )
        v_5d = helper.make_node(
            "Unsqueeze",
            inputs=[past_v],
            outputs=[v_5d_name],
            name=unsqueeze_v_name,
            axes=[0],
        )

        # Add unsqueeze nodes to graph
        self.nodes_to_add.append(k_5d)
        self.nodes_to_add.append(v_5d)
        self.node_name_to_graph_name[unsqueeze_k_name] = self.this_graph_name
        self.node_name_to_graph_name[unsqueeze_v_name] = self.this_graph_name

        # Concat K and V to get one node of size (2,B,N,P,H)
        concat_node_name = self.model.create_node_name("Concat")
        kv_output_name = past_v.replace(".value", ".kv").replace(".", "_").replace("_value", "_kv")
        concat_kv = helper.make_node(
            "Concat",
            inputs=[k_5d_name, v_5d_name],
            outputs=[kv_output_name],
            name=concat_node_name,
            axis=0,
        )

        # Add concat node to graph
        self.nodes_to_add.append(concat_kv)
        self.node_name_to_graph_name[concat_node_name] = self.this_graph_name

        return kv_output_name

    def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str):
        """Split kv_node containing present KV values into separate present K and present V values.

        Args:
            present_k_name (str): name of output to store present K value in
            present_v_name (str): name of output to store present V value in
            kv_node (str): name of present KV values
        """
        # Split kv_node into present_k and present_v nodes

        # Create initializers for indexing kv_node, whose shape is (2,B,N,P,H)
        k_index, v_index = "index_0", "index_1"
        k_dim = self.model.get_initializer(k_index)
        v_dim = self.model.get_initializer(v_index)
        if k_dim is None:
            k_dim = numpy_helper.from_array(np.array(0, dtype="int64"), name=k_index)
            self.model.add_initializer(k_dim, self.this_graph_name)
        if v_dim is None:
            v_dim = numpy_helper.from_array(np.array(1, dtype="int64"), name=v_index)
            self.model.add_initializer(v_dim, self.this_graph_name)

        # Create nodes to index kv_node
        gather_k_name = self.model.create_node_name("Gather")
        gather_v_name = self.model.create_node_name("Gather")
        present_k = helper.make_node(
            "Gather",
            inputs=[kv_node, k_index],
            outputs=[present_k_name],
            name=gather_k_name,
            axis=0,
        )
        present_v = helper.make_node(
            "Gather",
            inputs=[kv_node, v_index],
            outputs=[present_v_name],
            name=gather_v_name,
            axis=0,
        )

        # Add gather nodes to graph
        self.nodes_to_add.append(present_k)
        self.nodes_to_add.append(present_v)
        self.node_name_to_graph_name[gather_k_name] = self.this_graph_name
        self.node_name_to_graph_name[gather_v_name] = self.this_graph_name

    def create_combined_qkv_bias(
        self,
        q_add: NodeProto,
        k_add: NodeProto | None,
        v_add: NodeProto | None,
        name_prefix: str,
    ) -> NodeProto | None:
        q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
        qb = NumpyHelper.to_array(q_bias)
        kb = np.zeros_like(qb)
        vb = np.zeros_like(qb)
        if k_add is not None:
            k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
            kb = NumpyHelper.to_array(k_bias)
        if v_add is not None:
            v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
            vb = NumpyHelper.to_array(v_bias)

        qkv_bias = np.stack((qb, kb, vb), axis=0)
        qkv_bias_dim = 3 * np.prod(qb.shape)

        bias_name = name_prefix + "_qkv_bias"
        self.add_initializer(
            name=bias_name,
            data_type=q_bias.data_type,
            dims=[qkv_bias_dim],
            vals=qkv_bias,
        )
        return bias_name

    def create_packed_qkv_matmul_node(
        self,
        q_matmul: NodeProto,
        k_matmul: NodeProto,
        v_matmul: NodeProto,
        q_add: NodeProto,
        k_add: NodeProto | None,
        v_add: NodeProto | None,
    ) -> tuple[NodeProto, NodeProto, NodeProto]:
        """Create packed QKV MatMul node before MultiHeadAttention node.
           This is for the scenario where an Attention node should be created but cannot be created
           because past_key and past_value are separate inputs and not one concatenated input.

        Args:
            q_matmul (NodeProto): name of MatMul from Q path - (batch_size, sequence_length, hidden_size)
            k_matmul (NodeProto): name of MatMul from K path - (batch_size, sequence_length, hidden_size)
            v_matmul (NodeProto): name of MatMul from V path - (batch_size, sequence_length, hidden_size)
            q_add (NodeProto): name of Add from Q path
            k_add (NodeProto): name of Add from K path
            v_add (NodeProto): name of Add from V path

        Returns:
             q_output (NodeProto): Slice node for Q
             k_output (NodeProto): Slice node for K
             v_output (NodeProto): Slice node for V
        """
        matmul_node_name = self.model.create_node_name("MatMul")

        # Check that input for Q, K, V is the same
        assert q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]

        # Created packed QKV weight
        q_weight = self.model.get_initializer(q_matmul.input[1])
        k_weight = self.model.get_initializer(k_matmul.input[1])
        v_weight = self.model.get_initializer(v_matmul.input[1])

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)

        assert qw.shape == kw.shape and kw.shape == vw.shape
        d = qw.shape[0]

        qkv_weight = np.stack((qw, kw, vw), axis=1).reshape((d, 3 * d))
        qkv_weight_name = matmul_node_name + "_qkv_weight"

        self.add_initializer(
            name=qkv_weight_name,
            data_type=q_weight.data_type,
            dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
            vals=qkv_weight,
        )

        # Created packed QKV MatMul with output (B, S, 3*D)
        # Output is of the form:
        #
        # [[[Q Q ... Q Q K K ... K K V V ... V V]]]
        #   [Q Q ... Q Q K K ... K K V V ... V V]
        #                     .
        #                     .
        #                     .
        #  [[Q Q ... Q Q K K ... K K V V ... V V]
        #   [Q Q ... Q Q K K ... K K V V ... V V]]]
        qkv_matmul_output = matmul_node_name + "_qkv_out"
        qkv_matmul = helper.make_node(
            "MatMul",
            inputs=[q_matmul.input[0], qkv_weight_name],
            outputs=[qkv_matmul_output],
            name=matmul_node_name,
        )
        self.node_name_to_graph_name[matmul_node_name] = self.this_graph_name

        qkv_nodes = [qkv_matmul]

        # Create Slice nodes to access Q, K, V
        q_slice_name = matmul_node_name + "_q_start_index"
        self.add_initializer(name=q_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[0], raw=False)
        k_slice_name = matmul_node_name + "_k_start_index"
        self.add_initializer(name=k_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[d], raw=False)
        v_slice_name = matmul_node_name + "_v_start_index"
        self.add_initializer(name=v_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[2 * d], raw=False)
        end_of_qkv_name = matmul_node_name + "_end_of_qkv_index"
        self.add_initializer(name=end_of_qkv_name, data_type=TensorProto.INT64, dims=[1], vals=[3 * d], raw=False)
        qkv_last_axis_name = matmul_node_name + "_qkv_last_axis"
        self.add_initializer(name=qkv_last_axis_name, data_type=TensorProto.INT64, dims=[1], vals=[-1], raw=False)

        q_slice_output = matmul_node_name + "_q_out"
        q_slice = helper.make_node(
            "Slice",
            inputs=[qkv_matmul_output, q_slice_name, k_slice_name, qkv_last_axis_name],
            outputs=[q_slice_output],
            name=self.model.create_node_name("Slice"),
        )
        self.node_name_to_graph_name[q_slice.name] = self.this_graph_name
        k_slice_output = matmul_node_name + "_k_out"
        k_slice = helper.make_node(
            "Slice",
            inputs=[qkv_matmul_output, k_slice_name, v_slice_name, qkv_last_axis_name],
            outputs=[k_slice_output],
            name=self.model.create_node_name("Slice"),
        )
        self.node_name_to_graph_name[k_slice.name] = self.this_graph_name
        v_slice_output = matmul_node_name + "_v_out"
        v_slice = helper.make_node(
            "Slice",
            inputs=[qkv_matmul_output, v_slice_name, end_of_qkv_name, qkv_last_axis_name],
            outputs=[v_slice_output],
            name=self.model.create_node_name("Slice"),
        )
        self.node_name_to_graph_name[v_slice.name] = self.this_graph_name

        q_output = q_slice
        k_output = k_slice
        v_output = v_slice
        qkv_nodes.extend([q_slice, k_slice, v_slice])

        if self.disable_multi_head_attention_bias:
            if q_add is not None:
                initializer_input = 1 if self.model.get_initializer(q_add.input[1]) else 0
                if np.any(NumpyHelper.to_array(self.model.get_initializer(q_add.input[initializer_input]))):
                    q_add.input[1 - initializer_input] = q_slice_output
                    q_output = q_add
                    qkv_nodes.append(q_add)
                    self.node_name_to_graph_name[q_add.name] = self.this_graph_name
            if k_add is not None:
                initializer_input = 1 if self.model.get_initializer(k_add.input[1]) else 0
                if np.any(NumpyHelper.to_array(self.model.get_initializer(k_add.input[initializer_input]))):
                    k_add.input[1 - initializer_input] = k_slice_output
                    k_output = k_add
                    qkv_nodes.append(k_add)
                    self.node_name_to_graph_name[k_add.name] = self.this_graph_name
            if v_add is not None:
                initializer_input = 1 if self.model.get_initializer(v_add.input[1]) else 0
                if np.any(NumpyHelper.to_array(self.model.get_initializer(v_add.input[initializer_input]))):
                    v_add.input[1 - initializer_input] = v_slice_output
                    v_output = v_add
                    qkv_nodes.append(v_add)
                    self.node_name_to_graph_name[v_add.name] = self.this_graph_name

        # Add nodes to graph
        self.nodes_to_add.extend(qkv_nodes)
        return q_output, k_output, v_output

    # This function is used in child classes for bart or conformer model.
    def create_multihead_attention_node(
        self,
        q_matmul: NodeProto,
        k_matmul: NodeProto | str | None,
        v_matmul: NodeProto | str | None,
        q_add: NodeProto,
        k_add: NodeProto | None,
        v_add: NodeProto | None,
        num_heads: int,
        hidden_size: int,
        output: str,
        key_padding_mask: str = "",
        add_qk: str = "",
        unidirectional: bool = False,
        past_k: str = "",
        past_v: str = "",
        present_k: str = "",
        present_v: str = "",
        packed_qkv: bool = False,
    ) -> NodeProto | None:
        """Create a MultiHeadAttention node.

        Args:
            q_matmul (NodeProto): name of MatMul from Q path - (batch_size, sequence_length, hidden_size)
            k_matmul (NodeProto): name of MatMul from K path - (batch_size, sequence_length, hidden_size) or (batch_size, num_heads, past_sequence_length, head_size)
            v_matmul (NodeProto): name of MatMul from V path - (batch_size, sequence_length, hidden_size) or (batch_size, num_heads, past_sequence_length, head_size)
            q_add (NodeProto): name of Add from Q path
            k_add (NodeProto): name of Add from K path
            v_add (NodeProto): name of Add from V path
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            output (str): output name of MHA
            key_padding_mask (str): name of key padding mask
            add_qk (str): name of add after Q x K'
            unidirectional (bool): whether to apply causal attention mask automatically or not
            past_k (str): name of past K value - (batch_size, num_heads, past_sequence_length, head_size)
            past_v (str): name of past V value - (batch_size, num_heads, past_sequence_length, head_size)
            present_k (str): name of present K value - (batch_size, num_heads, sequence_length, head_size)
            present_v (str): name of present V value - (batch_size, num_heads, sequence_length, head_size)
            packed_qkv (bool): whether to combine MatMuls from Q, K, V paths
                               Note: This is for the scenario where an Attention node should be created but cannot be created
                               because past_key and past_value are separate inputs and not one concatenated input.

        Returns:
            Union[NodeProto, None]: the node created or None if failed.
        """
        # B = batch size, N = num heads, P = past seq len, H = head size
        assert num_heads > 0

        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads)
            return None

        graph_input_names = {node.name for node in self.model.graph().input}
        mha_node_name = self.model.create_node_name("Attention")

        # Add initial Q/K/V inputs for MHA
        mha_inputs = []
        if packed_qkv:
            q_slice, k_slice, v_slice = self.create_packed_qkv_matmul_node(
                q_matmul,
                k_matmul,
                v_matmul,
                q_add,
                k_add,
                v_add,
            )
            mha_inputs.extend([q_slice.output[0], k_slice.output[0], v_slice.output[0]])
        elif isinstance(k_matmul, NodeProto) and isinstance(v_matmul, NodeProto):
            if self.disable_multi_head_attention_bias:
                mha_inputs.extend([q_add.output[0], k_matmul.output[0], v_add.output[0]])
            else:
                mha_inputs.extend([q_matmul.output[0], k_matmul.output[0], v_matmul.output[0]])
        elif (
            isinstance(k_matmul, str)
            and isinstance(v_matmul, str)
            and k_matmul in graph_input_names
            and v_matmul in graph_input_names
        ):
            if self.disable_multi_head_attention_bias:
                mha_inputs.extend([q_add.output[0], k_matmul, v_matmul])
            else:
                mha_inputs.extend([q_matmul.output[0], k_matmul, v_matmul])
        else:
            return None

        # Add bias to inputs for MHA
        # Bias for cross attention is not fully supported in DMMHA and cpu MHA kernels since they assume
        # bias has been added to key and value when they are in BNSH format, so only bias for query is used.
        # Need add checks if we found such assumption is not true.
        if not self.disable_multi_head_attention_bias:
            bias_name = self.create_combined_qkv_bias(q_add, k_add, v_add, mha_node_name)
            mha_inputs.append(bias_name)
        else:
            mha_inputs.append("")

        # Add optional inputs for MHA
        if past_k and past_v:
            mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v])
        elif key_padding_mask or add_qk:
            mha_inputs.extend([key_padding_mask, add_qk])

        # Add outputs for MHA
        mha_outputs = [output]
        if present_k and present_v:
            mha_outputs.extend([present_k, present_v])

        mha_node = helper.make_node(
            "MultiHeadAttention",
            inputs=mha_inputs,
            outputs=mha_outputs,
            name=mha_node_name,
        )
        mha_node.domain = "com.microsoft"
        mha_node.attribute.append(helper.make_attribute("num_heads", num_heads))
        if unidirectional:
            mha_node.attribute.append(helper.make_attribute("unidirectional", int(unidirectional)))

        self.increase_counter("MultiHeadAttention")
        return mha_node

    def create_attention_node(
        self,
        mask_index: str | None,
        q_matmul: NodeProto,
        k_matmul: NodeProto,
        v_matmul: NodeProto,
        q_add: NodeProto,
        k_add: NodeProto,
        v_add: NodeProto,
        num_heads: int,
        hidden_size: int,
        first_input: str,
        output: str,
        add_qk_str: str = "",
        past_k: str = "",
        past_v: str = "",
        present_k: str = "",
        present_v: str = "",
        scale: float | None = None,
        causal: bool = False,
    ) -> NodeProto | None:
        """Create an Attention node.

        Args:
            mask_index (str | None): mask input
            q_matmul (NodeProto): MatMul node in fully connection for Q
            k_matmul (NodeProto): MatMul node in fully connection for K
            v_matmul (NodeProto): MatMul node in fully connection for V
            q_add (NodeProto): Add bias node in fully connection for Q
            k_add (NodeProto): Add bias node in fully connection for K
            v_add (NodeProto): Add bias node in fully connection for V
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            first_input (str): first input name
            output (str): output name
            add_qk_str (str): name of Add node after Q x K'
            past_k (str): name of input for past K value
            past_v (str): name of input for past V value
            present_k (str): name of output to store present K value
            present_v (str): name of output to store present V value
            scale: scale before softmax
            causal: whether it is uni-directional mask.

        Returns:
            Union[NodeProto, None]: the node created or None if failed.
        """
        assert num_heads > 0

        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads)
            return None

        has_bias = True
        if q_add is None and k_add is None and v_add is None:
            has_bias = False

        q_weight = self.model.get_initializer(q_matmul.input[1])
        k_weight = self.model.get_initializer(k_matmul.input[1])
        v_weight = self.model.get_initializer(v_matmul.input[1])

        q_bias, k_bias, v_bias = None, None, None
        if has_bias:
            q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
            k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
            v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])

            if not (k_weight and v_weight and q_bias and k_bias):
                return None

        if q_weight is None:
            print(
                f"{q_matmul.input[1]} is not an initializer. "
                "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
            )
            return None

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)

        # assert q and k have same shape as expected
        assert qw.shape == kw.shape

        qw_in_size = qw.shape[0]
        kw_in_size = kw.shape[0]
        vw_in_size = vw.shape[0]

        assert qw_in_size == kw_in_size == vw_in_size

        if hidden_size > 0 and hidden_size != qw_in_size:
            logger.warning(
                "Input hidden size (%d) is not same as weight matrix dimension of q,k,v (%d). "
                "Please provide a correct input hidden size or pass in 0",
                hidden_size,
                qw_in_size,
            )

        is_qkv_diff_dims = False
        if qw.shape != vw.shape:
            is_qkv_diff_dims = True

        # All the matrices can have the same shape or q, k matrices can have the same shape with v being different
        # For 2d weights, the shapes would be [in_size, out_size].
        # For 3d weights, shape would be [in_size, a, b] where a*b = out_size
        qw_out_size = np.prod(qw.shape[1:])
        kw_out_size = np.prod(kw.shape[1:])
        vw_out_size = np.prod(vw.shape[1:])

        qkv_weight_dim = 0
        if is_qkv_diff_dims:
            qkv_weight = np.concatenate((qw, kw, vw), axis=1)
            qkv_weight_dim = qw_out_size + kw_out_size + vw_out_size
        else:
            qkv_weight = np.stack((qw, kw, vw), axis=1)
            qkv_weight_dim = 3 * qw_out_size

        qkv_bias_dim = 0
        qkv_bias: np.ndarray | None = None
        if has_bias:
            qb = NumpyHelper.to_array(q_bias)
            kb = NumpyHelper.to_array(k_bias)
            vb = NumpyHelper.to_array(v_bias)

            q_bias_shape = np.prod(qb.shape)
            k_bias_shape = np.prod(kb.shape)
            v_bias_shape = np.prod(vb.shape)

            assert q_bias_shape == k_bias_shape == qw_out_size
            assert v_bias_shape == vw_out_size

            if is_qkv_diff_dims:
                qkv_bias = np.concatenate((qb, kb, vb), axis=0)
                qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape
            else:
                qkv_bias = np.stack((qb, kb, vb), axis=0)
                qkv_bias_dim = 3 * q_bias_shape

        attention_node_name = self.model.create_node_name("Attention")

        if not self.use_multi_head_attention:
            self.add_initializer(
                name=attention_node_name + "_qkv_weight",
                data_type=q_weight.data_type,
                dims=[qw_in_size, int(qkv_weight_dim)],
                vals=qkv_weight,
            )

        if has_bias:
            self.add_initializer(
                name=attention_node_name + "_qkv_bias",
                data_type=q_bias.data_type,
                dims=[int(qkv_bias_dim)],
                vals=qkv_bias,
            )

        # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
        if self.use_multi_head_attention:
            if add_qk_str:
                logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.")
                return None

            attention_inputs = [
                q_matmul.output[0],
                k_matmul.output[0],
                v_matmul.output[0],
                attention_node_name + "_qkv_bias",
            ]

            if mask_index is not None:
                attention_inputs.append(mask_index)

            attention_node = helper.make_node(
                "MultiHeadAttention",
                inputs=attention_inputs,
                outputs=[output],
                name=attention_node_name,
            )
            self.increase_counter("MultiHeadAttention")

        else:
            attention_inputs = [
                first_input,
                attention_node_name + "_qkv_weight",
                attention_node_name + "_qkv_bias" if has_bias else "",
            ]
            if mask_index is not None:
                attention_inputs.append(mask_index)
            else:
                attention_inputs.append("")

            past_exists = past_k and past_v
            if past_exists:
                past_kv = self.concat_kv(past_k, past_v)
                attention_inputs.append(past_kv)

            if add_qk_str:
                # Add additional add to attention node (input name = attention_bias)
                if not past_exists:
                    attention_inputs.append("")
                attention_inputs.append(add_qk_str)

            attention_outputs = [output]
            if present_k and present_v:
                present_kv = present_k.replace(".key", "").replace("_key", "").replace(".", "_")
                attention_outputs.append(present_kv)
                self.split_kv(present_k, present_v, present_kv)

            attention_node = helper.make_node(
                "Attention",
                inputs=attention_inputs,
                outputs=attention_outputs,
                name=attention_node_name,
            )
            self.increase_counter("Attention")

        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        if causal:
            attention_node.attribute.extend([helper.make_attribute("unidirectional", 1)])

        if scale is not None:
            attention_node.attribute.extend([helper.make_attribute("scale", scale)])

        if is_qkv_diff_dims:
            attention_node.attribute.extend(
                [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
            )

        if self.mask_filter_value is not None:
            attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

        return attention_node

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        # Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm
        # Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern
        normalize_node = node
        start_node = normalize_node
        if normalize_node.op_type == "LayerNormalization":
            add_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
            if add_before_layernorm is not None:
                start_node = add_before_layernorm
            else:
                return

        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [None, None, 0, 0, 0],
        )
        einsum_node = None
        if qkv_nodes is not None:
            (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
        else:
            # Match Albert
            qkv_nodes = self.model.match_parent_path(
                start_node, ["Add", "Einsum", "Transpose", "MatMul"], [1, None, 0, 0]
            )
            if qkv_nodes is not None:
                (_, einsum_node, transpose_qkv, matmul_qkv) = qkv_nodes
            else:
                return

        other_inputs = []
        for _i, node_input in enumerate(start_node.input):
            if node_input not in output_name_to_node:
                continue

            if node_input == qkv_nodes[0].output[0]:
                continue
            other_inputs.append(node_input)
        if len(other_inputs) != 1:
            return

        root_input = other_inputs[0]

        # Match flaubert                     Mask
        #                                     |
        # Mul --> LayerNormalization -->  Attention --> MatMul --> Add
        #  |                                                        |
        #  |                                                        |
        #  +---------------------------------------------------------
        mul_before_layernorm = self.model.match_parent(start_node, "Mul", 0)
        if mul_before_layernorm is not None:
            mul_children = input_name_to_nodes[mul_before_layernorm.output[0]]
            if mul_children is not None and len(mul_children) == 2:
                layernorm_node = mul_children[1]
                if layernorm_node.op_type == "LayerNormalization":
                    root_input = layernorm_node.output[0]
                else:
                    return
            elif mul_children is not None and len(mul_children) == 5:
                root_input = mul_before_layernorm.output[0]
            else:
                return
        elif normalize_node.op_type == "LayerNormalization":
            children = input_name_to_nodes[root_input]
            for child in children:
                if child.op_type == "LayerNormalization":
                    root_input = child.output[0]

        # When Add before the LayerNormalization produces an output
        # that is consumed by some other nodes other than the LayerNormalization itself,
        # fused SkipLayerNormalization will have several outputs.
        # In this case we need to pick the one used in Attention
        # For example, this is the case for ViT
        # SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization
        #  |                                                                     |
        #  |                                                                     |
        #  +---------------------------------------------------------------------+
        parent_node = output_name_to_node[root_input]
        if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4:
            root_input = parent_node.output[0]

        children = input_name_to_nodes[root_input]
        children_types = [child.op_type for child in children]
        if children_types.count("MatMul") != 3:
            return

        v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (_, _, add_v, matmul_v) = v_nodes

        is_distill = False
        is_distill_add = False
        is_no_mask_attention = False
        is_sdpa = False
        qk_paths = {
            "path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]),
            "path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]),
            "path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]),
            "path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]),
            "path5": (["Softmax", "Div", "MatMul"], [0, 0, 0]),
            "sdpa": (["Softmax", "Add", "MatMul", "Mul", "Sqrt"], [0, 0, None, 0, 1]),
        }

        qk_nodes = None
        for k, v in qk_paths.items():
            qk_nodes = self.model.match_parent_path(matmul_qkv, v[0], v[1])
            if qk_nodes is None:
                continue
            if k == "path3":
                is_distill = True
            elif k == "path4":
                is_distill_add = True
            elif k == "path5":
                is_no_mask_attention = True
            elif k == "sdpa":
                is_sdpa = True
            break

        if qk_nodes is None:
            logger.debug("fuse_attention: failed to match qk path")
            return

        add_qk = None
        matmul_qk = None
        where_qk = None
        after_q = None
        if is_distill:
            (_, where_qk, matmul_qk, _) = qk_nodes
        elif is_distill_add:
            (_, add_qk, where_qk, matmul_qk) = qk_nodes
        elif is_no_mask_attention:
            (_, _, matmul_qk) = qk_nodes
        elif is_sdpa:
            (_, add_qk, matmul_qk, after_q, _) = qk_nodes
        else:
            (_, add_qk, _, matmul_qk) = qk_nodes

        after_q = after_q or matmul_qk
        q_nodes = self.model.match_parent_path(after_q, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None])
        if q_nodes is None:
            q_nodes = self.model.match_parent_path(
                after_q,
                ["Div", "Transpose", "Reshape", "Add", "MatMul"],
                [0, 0, 0, 0, None],
            )
            if q_nodes is None:
                logger.debug("fuse_attention: failed to match q path")
                return
        reshape_q = q_nodes[-3]
        add_q = q_nodes[-2]
        matmul_q = q_nodes[-1]

        after_k = matmul_qk
        if is_sdpa:
            mul_k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Sqrt"], [1, None])
            if mul_k_nodes is None:
                logger.debug("fuse_attention: failed to match mul sqrt q path")
                return
            (after_k, _) = mul_k_nodes

        k_nodes = self.model.match_parent_path(
            after_k, ["Transpose", "Reshape", "Add", "MatMul"], [0 if is_sdpa else 1, 0, 0, None]
        )
        if k_nodes is None:
            k_nodes = self.model.match_parent_path(
                matmul_qk,
                ["Transpose", "Transpose", "Reshape", "Add", "MatMul"],
                [1, 0, 0, 0, None],
            )
            if k_nodes is None:
                logger.debug("fuse_attention: failed to match k path")
                return
        add_k = k_nodes[-2]
        matmul_k = k_nodes[-1]

        # Note that Cast might be removed by OnnxRuntime so we match two patterns here.
        mask_nodes = None
        add_qk_str = ""
        if is_distill:
            _, mask_nodes, _ = self.model.match_parent_paths(
                where_qk,
                [
                    (["Expand", "Reshape", "Equal"], [0, 0, 0]),
                    (["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
                    (["Cast", "Expand", "Reshape", "Equal"], [0, 0, 0, 0]),
                ],
                output_name_to_node,
            )
        elif is_distill_add:
            _, mask_nodes, _ = self.model.match_parent_paths(
                where_qk,
                [
                    (["Cast", "Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0, 0]),
                    (["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
                ],
                output_name_to_node,
            )
            if add_qk is not None:
                add_qk_str = self.get_add_qk_str(add_qk)
                if add_qk_str is None:
                    logger.debug("fuse_attention: failed to verify shape inference of %s", add_qk)
                    return
        elif is_no_mask_attention:
            pass
        else:
            _, mask_nodes, _ = self.model.match_parent_paths(
                add_qk,
                [
                    (["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]),
                    (["Mul", "Sub", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0]),
                    # The following two patterns are for SDPA.
                    (["Where", "Cast", "Sub", "Expand", "Unsqueeze", "Unsqueeze"], [None, 0, 0, 1, 0, 0]),
                    (["Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [None, 0, 0, 1, 0, 0, 0]),
                ],
                output_name_to_node,
            )
        if not is_no_mask_attention and mask_nodes is None:
            logger.debug("fuse_attention: failed to match mask path")
            return

        if not is_no_mask_attention and len(mask_nodes) > 1:
            _, mul_val = self.model.get_constant_input(mask_nodes[0])
            # The mask value shall be a float scalar (usually is the lowest float value).
            if (
                (mul_val is None)
                or not (isinstance(mul_val, np.ndarray) and mul_val.size == 1)
                or (float(mul_val) >= 0)
            ):
                return
            if float(mul_val) != -10000:
                self.mask_filter_value = float(mul_val)

        if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) if not is_no_mask_attention else None

            attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv

            q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
            if q_num_heads <= 0 or q_hidden_size <= 0:
                logger.warning(
                    "Failed to detect num_heads and hidden_size for Attention fusion. "
                    "Please specify those parameters in argument."
                )
                return

            # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
            # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately
            new_node = self.create_attention_node(
                mask_index=mask_index,
                q_matmul=matmul_q,
                k_matmul=matmul_k,
                v_matmul=matmul_v,
                q_add=add_q,
                k_add=add_k,
                v_add=add_v,
                num_heads=q_num_heads,
                hidden_size=q_hidden_size,
                first_input=root_input,
                output=attention_last_node.output[0],
                add_qk_str=add_qk_str,
            )

            if new_node is None:
                return

            self.nodes_to_add.append(new_node)
            self.node_name_to_graph_name[new_node.name] = self.this_graph_name

            if einsum_node is not None:
                unique_index = einsum_node.input[0]
                new_edge = "edge_modified_" + unique_index

                shape_tensor = self.add_initializer(
                    name="shape_modified_tensor" + unique_index,
                    data_type=TensorProto.INT64,
                    dims=[4],
                    vals=[0, 0, q_num_heads, int(q_hidden_size / q_num_heads)],
                    raw=False,
                )

                self.model.add_node(
                    helper.make_node(
                        "Reshape",
                        [attention_last_node.output[0], shape_tensor.name],
                        [new_edge],
                        "reshape_modified_" + unique_index,
                    ),
                    self.this_graph_name,
                )
                einsum_node.input[0] = new_edge

            self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
            self.nodes_to_remove.extend(qk_nodes)

            # For MultiHeadAttention operator, MatMul nodes for Q/K/V projection shall not be fused.
            self.nodes_to_remove.extend(q_nodes if not self.use_multi_head_attention else q_nodes[:-1])
            self.nodes_to_remove.extend(k_nodes if not self.use_multi_head_attention else k_nodes[:-1])
            self.nodes_to_remove.extend(v_nodes if not self.use_multi_head_attention else v_nodes[:-1])

            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
            self.prune_graph = True
