# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.  See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import logging

import numpy as np
import torch
from transformers import WhisperConfig

from onnxruntime import InferenceSession

logger = logging.getLogger(__name__)


# Create audio_features for encoder
# Shape is (batch_size, feature_size, sequence_length) = (batch_size, num_mel_filters, num_frames)
# where num_mel_filters is a model attribute and num_frames = (chunk_length * sample_rate) // hop_length.
#
# Hard-coded audio hyperparameters:
# SAMPLE_RATE = 16000
# N_FFT = 400
# HOP_LENGTH = 160
# CHUNK_LENGTH = 30  (i.e. 30-second chunk of audio)
# N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE = 30 * 16000 = 480000  (i.e. 480,000 samples in a 30-second chunk of audio)
# N_FRAMES = N_SAMPLES // HOP_LENGTH = 480000 // 160 = 3000  (i.e. 3000 frames in a mel spectrogram input)
#
# N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 = 160 * 2 = 320
# FRAMES_PER_TOKEN = SAMPLE_RATE // HOP_LENGTH = 16000 // 160 = 100  (i.e. 10 ms per audio frame)
# TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN = 16000 // 320 = 50  (i.e. 20 ms per audio token)
def get_sample_audio_features(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int = 3000,
    use_fp16: bool = False,
):
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    audio_features = torch.randn(batch_size, config.num_mel_bins, sequence_length, device=device, dtype=torch_dtype)
    return audio_features


# Create input_ids for decoder
# Shape is (batch_size, sequence_length) where sequence_length is the initial decoder sequence length
def get_sample_decoder_input_ids(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int,
    use_int32: bool = True,
):
    torch_dtype = torch.int32 if use_int32 else torch.int64
    decoder_input_ids = torch.randint(
        low=0, high=config.vocab_size, size=(batch_size, sequence_length), device=device, dtype=torch_dtype
    )
    return decoder_input_ids


# Create encoder_hidden_states for decoder-init
# Shape is (batch_size, num_frames // 2, hidden_size)
def get_sample_encoder_hidden_states(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    use_fp16: bool = False,
):
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    encoder_hidden_states = torch.randn(
        batch_size, config.max_source_positions, config.d_model, device=device, dtype=torch_dtype
    )
    return encoder_hidden_states


# Create past_key_values
# Self-attention KV caches are of shape (batch_size, num_heads, past_sequence_length, head_size)
# Cross-attention KV caches are of shape (batch_size, num_heads, num_frames // 2, head_size)
def get_sample_past_key_values(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    past_seq_len: int,
    use_fp16: bool = False,
):
    num_heads = config.decoder_attention_heads
    head_size = config.d_model // num_heads
    max_source_positions = (
        config.max_source_positions
    )  # equal to num_frames // 2 = encoder's sequence_length // 2 = 3000 // 2 = 1500
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    self_attention_kv_caches = [
        (
            torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
            torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
        )
        for _ in range(config.num_hidden_layers)
    ]
    cross_attention_kv_caches = [
        (
            torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype),
            torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype),
        )
        for _ in range(config.num_hidden_layers)
    ]
    return flatten_past_key_values(self_attention_kv_caches, cross_attention_kv_caches)


# Flatten KV caches into pairs-of-4 where each pair is defined as:
# (self_attn_key_cache, self_attn_value_cache, cross_attn_key_cache, cross_attn_value_cache)
def flatten_past_key_values(
    self_attn_kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
    cross_attn_kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
):
    past_key_values = []
    for (self_k_cache, self_v_cache), (cross_k_cache, cross_v_cache) in zip(
        self_attn_kv_caches, cross_attn_kv_caches, strict=False
    ):
        layer_kv_caches = (self_k_cache, self_v_cache, cross_k_cache, cross_v_cache)
        past_key_values.append(layer_kv_caches)
    return past_key_values


# Group KV caches into two 1D lists where one list contains the self attention KV caches and
# one list contains the cross attention KV caches
def group_past_key_values(
    kv_caches: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
):
    self_attn_kv_caches, cross_attn_kv_caches = [], []
    for self_k_cache, self_v_cache, cross_k_cache, cross_v_cache in kv_caches:
        self_attn_kv_caches.append(self_k_cache)
        self_attn_kv_caches.append(self_v_cache)
        cross_attn_kv_caches.append(cross_k_cache)
        cross_attn_kv_caches.append(cross_v_cache)
    return self_attn_kv_caches, cross_attn_kv_caches


# Create alignment heads for timestamps
# Shape is (num_alignment_heads, 2)
def get_sample_alignment_heads(
    config: WhisperConfig,
    device: torch.device,
    num_alignment_heads: int = 6,
    use_int32: bool = True,
):
    torch_dtype = torch.int32 if use_int32 else torch.int64
    alignment_heads = torch.ones((num_alignment_heads, 2), device=device, dtype=torch_dtype)
    return alignment_heads


# Create length of start-of-transcription sequence for timestamps
# Shape is (1)
def get_sample_sot_sequence_length(
    device: torch.device,
    sot_sequence_length: int,
    use_int32: bool = False,
):
    torch_dtype = torch.int32 if use_int32 else torch.int64
    sot_length = torch.tensor([sot_sequence_length], device=device, dtype=torch_dtype)
    return sot_length


# Create segment length for timestamps
# Shape is (1)
def get_sample_segment_length(
    device: torch.device,
    segment_length: int,
    use_int32: bool = False,
):
    torch_dtype = torch.int32 if use_int32 else torch.int64
    segment_size = torch.tensor([segment_length], device=device, dtype=torch_dtype)
    return segment_size


# Create QKs for timestamps
# Shape is (batch_size, num_heads, sequence_length, num_frames // 2)
def get_sample_QKs(  # noqa: N802
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int,
    use_fp16: bool = False,
):
    num_heads = config.decoder_attention_heads
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    QKs = [  # noqa: N806
        torch.rand(
            batch_size, num_heads, sequence_length, config.max_source_positions, device=device, dtype=torch_dtype
        )
        for _ in range(config.num_hidden_layers)
    ]
    return QKs


# Create inputs for encoder component of Whisper
def get_sample_encoder_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int = 3000,
    use_fp16: bool = False,
):
    audio_features = get_sample_audio_features(config, device, batch_size, sequence_length, use_fp16)
    return {"audio_features": audio_features}


# Create inputs for encoder component + first pass through decoder component of Whisper
def get_sample_encoder_decoder_init_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    decoder_sequence_length: int,
    encoder_sequence_length: int = 3000,
    use_fp16: bool = False,
    use_int32: bool = True,
):
    audio_features = get_sample_audio_features(config, device, batch_size, encoder_sequence_length, use_fp16)
    decoder_input_ids = get_sample_decoder_input_ids(config, device, batch_size, decoder_sequence_length, use_int32)
    return {"audio_features": audio_features, "decoder_input_ids": decoder_input_ids}


# Create inputs for decoder component of Whisper
# Inputs for first pass through the decoder (i.e. decoder-init): decoder_input_ids, encoder_hidden_states
# Inputs for subsequent passes through the decoder (i.e. decoder-with-past): decoder_input_ids, past_key_values
def get_sample_decoder_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    past_sequence_length: int,
    sequence_length: int,
    use_fp16: bool = False,
    use_int32: bool = True,
):
    decoder_input_ids = get_sample_decoder_input_ids(config, device, batch_size, sequence_length, use_int32)
    encoder_hidden_states = get_sample_encoder_hidden_states(config, device, batch_size, use_fp16)
    past_key_values = get_sample_past_key_values(config, device, batch_size, past_sequence_length, use_fp16)
    return {
        "decoder_input_ids": decoder_input_ids,
        "encoder_hidden_states": encoder_hidden_states,
        "past_key_values": past_key_values,
    }


# Create inputs for timestamps component of Whisper
def get_sample_jump_times_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int,
    num_alignment_heads: int,
    sot_sequence_length: int,
    segment_length: int,
    use_fp16: bool = False,
    use_int32: bool = True,
):
    alignment_heads = get_sample_alignment_heads(config, device, num_alignment_heads, use_int32)
    # lengths need to be int64 because subsequent 'Slice' ops only take int64 inputs
    sot_sequence_length = get_sample_sot_sequence_length(device, sot_sequence_length)
    segment_length = get_sample_segment_length(device, segment_length)
    QKs = get_sample_QKs(config, device, batch_size, sequence_length, use_fp16)  # noqa: N806
    return {
        "alignment_heads": alignment_heads,
        "sot_sequence_length": sot_sequence_length,
        "segment_length": segment_length,
        "QKs": QKs,
    }


# Convert PyTorch inputs to ONNX Runtime inputs
def convert_inputs_for_ort(
    inputs: dict,
    model: InferenceSession,
):
    self_attn_kv_caches, cross_attn_kv_caches = None, None
    batch_size, num_heads, past_seq_len, head_size = 0, 0, 0, 0
    num_beams, max_seq_len = 1, 448
    if "past_key_values" in inputs:
        (self_attn_kv_caches, cross_attn_kv_caches) = group_past_key_values(inputs["past_key_values"])
        batch_size, num_heads, past_seq_len, head_size = self_attn_kv_caches[0].shape

    ort_inputs = {}
    model_inputs = list(map(lambda i: i.name, model.get_inputs()))  # noqa: C417
    use_buffer_sharing = "cache_indirection" in model_inputs
    for name in model_inputs:
        if name in {"audio_features", "encoder_input_ids"}:
            # Encoder input
            ort_inputs[name] = inputs["audio_features"].detach().cpu().numpy()
        elif name == "encoder_hidden_states":
            # Encoder output
            ort_inputs[name] = inputs["encoder_hidden_states"].detach().cpu().numpy()
        elif name in {"decoder_input_ids", "input_ids"}:
            # Decoder input
            ort_inputs[name] = inputs["decoder_input_ids"].detach().cpu().numpy()
        elif "past_key_self" in name or "past_value_self" in name:
            # Decoder input
            orig_kv_cache = self_attn_kv_caches.pop(0).detach().cpu().numpy()
            if use_buffer_sharing:
                new_kv_cache = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=orig_kv_cache.dtype)
                new_kv_cache[:batch_size, :num_heads, :past_seq_len, :head_size] = orig_kv_cache
                ort_inputs[name] = new_kv_cache
            else:
                ort_inputs[name] = orig_kv_cache
        elif "past_key_cross" in name or "past_value_cross" in name:
            # Decoder input
            orig_kv_cache = cross_attn_kv_caches.pop(0).detach().cpu().numpy()
            ort_inputs[name] = orig_kv_cache
        elif name == "past_sequence_length":
            # Decoder input
            ort_inputs[name] = np.array([past_seq_len], dtype=np.int32)
        elif name == "cache_indirection":
            # Decoder input
            ort_inputs[name] = np.zeros((batch_size, num_beams, max_seq_len), dtype=np.int32)
        elif name == "alignment_heads":
            # Jump times input
            ort_inputs[name] = inputs["alignment_heads"].detach().cpu().numpy()
        elif name == "sot_sequence_length":
            # Jump times input
            ort_inputs[name] = inputs["sot_sequence_length"].detach().cpu().numpy()
        elif name == "segment_length":
            # Jump times input
            ort_inputs[name] = inputs["segment_length"].detach().cpu().numpy()
        elif "cross_qk" in name:
            # Jump times input
            ort_inputs[name] = inputs["QKs"].pop(0).detach().cpu().numpy()
        else:
            raise ValueError(f"Unknown name not recognized: {name}")

    return ort_inputs


# Get dynamic axes for all inputs and outputs to the model
def get_model_dynamic_axes(
    config: WhisperConfig,
    input_names: list[str],
    output_names: list[str],
):
    dynamic_axes = {}
    for name in input_names + output_names:
        if name in {"audio_features", "encoder_input_ids"}:
            # shape is (batch_size, num_mels, num_frames)
            dynamic_axes[name] = {0: "batch_size"}
        elif name in {"input_ids", "decoder_input_ids"}:
            # shape is (batch_size, sequence_length)
            dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
        elif name == "alignment_heads":
            # shape is (num_alignment_heads, 2)
            dynamic_axes[name] = {0: "num_alignment_heads"}
        elif name in {"sot_sequence_length", "segment_length"}:
            # shape is (1)
            pass
        elif name == "logits":
            # shape is (batch_size, sequence_length, vocab_size)
            dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
        elif name == "encoder_hidden_states":
            # shape is (batch_size, num_frames // 2, hidden_size)
            dynamic_axes[name] = {0: "batch_size"}
        elif "past_key_self" in name or "past_value_self" in name:
            # shape is (batch_size, num_heads, past_sequence_length, head_size)
            dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"}
        elif "present_key_self" in name or "present_value_self" in name:
            # shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size),
            # which is equal to (batch_size, num_heads, total_sequence_length, head_size)
            dynamic_axes[name] = {0: "batch_size", 2: "total_sequence_length"}
        elif (
            "past_key_cross" in name
            or "past_value_cross" in name
            or "present_key_cross" in name
            or "present_value_cross" in name
        ):
            # shape is (batch_size, num_heads, num_frames // 2, head_size)
            dynamic_axes[name] = {0: "batch_size"}
        elif "cross_qk" in name:
            # shape is (batch_size, num_heads, source_sequence_length, target_sequence_length)
            dynamic_axes[name] = {0: "batch_size", 2: "sequence_length"}
        elif "jump_times" in name:
            # shape is (batch_size, max_length)
            dynamic_axes[name] = {0: "batch_size", 1: "max_length"}
        else:
            raise Exception(f"Unknown input or output name found: {name}")
    return dynamic_axes
