o
    U|%il                     @   s   d Z ddlZddlZddlmZ ddlmZmZmZm	Z	 ddl
ZddlZddlmZmZmZmZmZmZmZmZmZmZmZmZmZ ddlmZ G dd deZG d	d
 d
ZG dd deZ dS )zHeatWave ML model utilities for MySQL Connector/Python.

Provides classes to manage training, prediction, scoring, and explanations
via MySQL HeatWave stored procedures.
    N)Enum)AnyDictOptionalUnion)VAR_NAME_SPACEatomic_transactionconvert_to_dfexecute_sqlformat_value_sqlget_random_namesource_schemasql_response_to_dfsql_table_from_dfsql_table_to_dftable_existstemporary_sql_tablesvalidate_name)MySQLConnectionAbstractc                   @   sJ   e Zd ZdZdZdZdZdZdZdZ	dZ
ed	eed f d
efddZdS )ML_TASKz/Enumeration of supported ML tasks for HeatWave.classification
regressionforecastinganomaly_detectionlog_anomaly_detectionrecommendationtopic_modelingtaskreturnc                 C   s   t | tr| S | jS )a'  
        Return the string representation of a machine learning task.

        Args:
            task (Union[str, ML_TASK]): The task to convert.
                Accepts either a task enum member (ML_TASK) or a string.

        Returns:
            str: The string value of the ML task.
        )
isinstancestrvalue)r    r"   N/home/air/sos_test/back/venv/lib/python3.10/site-packages/mysql/ai/ml/model.pyget_task_stringG   s   
zML_TASK.get_task_stringN)__name__
__module____qualname____doc__CLASSIFICATION
REGRESSIONFORECASTINGANOMALY_DETECTIONLOG_ANOMALY_DETECTIONRECOMMENDATIONTOPIC_MODELINGstaticmethodr   r    r$   r"   r"   r"   r#   r   <   s     r   c                
   @   s*  e Zd ZdZejdfdedeeef de	e fddZ
defd	d
Zdede	e fddZde	e fddZdedefddZd$ddZdefddZdefddZdede	e de	e ddfddZdedede	e ddfddZdededede	e def
d d!Zdedede	e dejfd"d#ZdS )%_MyModelCommona8  
    Common utilities and workflow for MySQL HeatWave ML models.

    This class handles model lifecycle steps such as loading, fitting, scoring,
    making predictions, and explaining models or predictions. Not intended for
    direct instantiation, but as a superclass for heatwave model wrappers.

    Attributes:
        db_connection: MySQL connector database connection.
        task: ML task, e.g., "classification" or "regression".
        model_name: Identifier of model in MySQL.
        schema_name: Database schema used for operations and temp tables.
    Ndb_connectionr   
model_namec                 C   s   || _ t|| _t|| _t| j }t|d W d   n1 s#w   Y  |du r1t| j	}t
 d| | _| j d| _|| _t| t| j }t|d| j d|fd W d   dS 1 sew   Y  dS )a  
        Instantiate _MyMLModelCommon.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
                A full list of supported tasks can be found under "Common ML_TRAIN Options"

        Args:
            db_connection: MySQL database connection.
            task: ML task type (default: "classification").
            model_name: Name to register the model within MySQL (default: None).

        Raises:
            ValueError: If the schema name is not valid
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Returns:
            None
        z(CALL sys.ML_CREATE_OR_UPGRADE_CATALOG();N.z.scorezSET @z = %s;params)r2   r   r$   r   r   schema_namer   r
   r   _is_model_name_availabler   	model_varmodel_var_scorer3   r   )selfr2   r   r3   cursorr"   r"   r#   __init__i   s   

"z_MyModelCommon.__init__r   c                 C   sh   |   }d| d}d| d| j }t| j}t|| |jdkW  d   S 1 s-w   Y  dS )a%  
        Deletes the model from the model catalog if present

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Returns:
            Whether the model was deleted
        
ML_SCHEMA_.MODEL_CATALOGzDELETE FROM  WHERE model_handle = @r   N)	_get_userr9   r   r2   r
   rowcount)r;   current_userqualified_model_catalogdelete_modelr<   r"   r"   r#   _delete_model   s   
$z_MyModelCommon._delete_modelc           
         s   dt dt fdd |  }d| d}d| d}t| j4}t|||fd	 t|}|jr1d
}n|jdd}t	|d }	 fdd|	
 D }|W  d
   S 1 sUw   Y  d
S )a  
        Retrieves the model info from the model_catalog

        Args:
            model_var: The model alias to retrieve

        Returns:
            The model info from the model_catalog (None if the model is not present in the catalog)

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        elemr   c                 S   s6   t | trzt| } W | S  tjy   Y | S w | S N)r   r    jsonloadsJSONDecodeError)rG   r"   r"   r#   process_col   s   
z3_MyModelCommon._get_model_info.<locals>.process_colr>   r?   SELECT * FROM z WHERE model_handle = %sr5   Nrecords)orientr   c                    s   i | ]	\}}| |qS r"   r"   ).0keyrG   rL   r"   r#   
<dictcomp>   s    z2_MyModelCommon._get_model_info.<locals>.<dictcomp>)r   rA   r   r2   r
   r   emptyto_jsonrI   rJ   items)
r;   r3   rC   rD   model_existsr<   model_info_dfresultunprocessed_resultunprocessed_result_jsonr"   rR   r#   _get_model_info   s"   

$z_MyModelCommon._get_model_infoc                 C   s   |  | jS )a  
        Checks if the model name is available.
        Model info is present in the catalog only if the model was previously fitted.

        Returns:
            True if the model name is not part of the model catalog

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        )r\   r3   )r;   r"   r"   r#   get_model_info   s   z_MyModelCommon.get_model_infoc                 C   s   |  |du S )a1  
        Checks if the model name is available

        Returns:
            True if the model name is not part of the model catalog

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        N)r\   )r;   r3   r"   r"   r#   r8      s   z'_MyModelCommon._is_model_name_availablec                 C   sH   t | j}d| j d}t|| W d   dS 1 sw   Y  dS )a  
        Loads the model specified by `self.model_name` into MySQL.
        After loading, the model is ready to handle ML operations.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-model-load.html

        Raises:
            DatabaseError:
                If the model is not initialized, i.e., fit or import has not been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Returns:
            None
        zCALL sys.ML_MODEL_LOAD(@z, NULL);N)r   r2   r9   r
   )r;   r<   load_model_queryr"   r"   r#   _load_model   s   "z_MyModelCommon._load_modelc                 C   sV   t | j}|d | d dd }t|W  d   S 1 s$w   Y  dS )a  
        Fetch the current database user (without host).

        Returns:
            The username string associated with the connection.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError: If the user name includes unsupported characters
        zSELECT CURRENT_USER()r   @N)r   r2   executefetchonesplitr   )r;   r<   rC   r"   r"   r#   rA     s
   
$z_MyModelCommon._get_userc                 C   sx   |    t| j)}|  }d| d}d| d| j }t|| t|}|jd W  d   S 1 s5w   Y  dS )a  
        Get model explanations, such as detailed feature importances.

        Returns:
            dict: Feature importances and model explainability data.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-model-explanations.html

        Raises:
            DatabaseError:
                If the model is not initialized, i.e., fit or import has not been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError:
                If the model does not exist in the model catalog.
                Should only occur if model was not fitted or was deleted.
        r>   r?   zSELECT model_explanation FROM r@   r   r   N)r_   r   r2   rA   r9   r
   r   iloc)r;   r<   rC   rD   explain_querydfr"   r"   r#   explain_model#  s   
$z_MyModelCommon.explain_model
table_nametarget_column_nameoptionsc                 C   s   t | |durt | d| d}nd}|du ri }t|}| j|d< |   t| j(}t|\}}t|d| j	 d| d| d| d	| j
 d
|d W d   dS 1 sYw   Y  dS )a2  
        Fit an ML model using a referenced SQL table and target column.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
                A full list of supported options can be found under "Common ML_TRAIN Options"

        Args:
            table_name: Name of the training data table.
            target_column_name: Name of the target/label column.
            options: Additional fit/config options (may override defaults).

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError: If the table or target_column name is not valid

        Returns:
            None
        N'NULLr   zCALL sys.ML_TRAIN('r4   ', , z, @)r5   )r   copydeepcopyr   rF   r   r2   r   r
   r7   r9   )r;   ri   rj   rk   target_col_stringr<   placeholders
parametersr"   r"   r#   _fitE  s8   

"z_MyModelCommon._fitoutput_table_namec                 C   s   t | t | |   t| j,}t|\}}t|d| j d| d| j d| j d| d| d|d W d   dS 1 s@w   Y  dS )	a  
        Predict on a given data table and write results to an output table.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
                A full list of supported options can be found under "ML_PREDICT_TABLE Options"

        Args:
            table_name: Name of the SQL table with input data.
            output_table_name: Name for the SQL output table to contain predictions.
            options: Optional prediction options.

        Returns:
            None

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError: If the table or output_table name is not valid
        zCALL sys.ML_PREDICT_TABLE('r4   ', @, 'rn   rp   r5   N)r   r_   r   r2   r   r
   r7   r9   )r;   ri   rw   rk   r<   rt   ru   r"   r"   r#   _predict~  s.   "z_MyModelCommon._predictmetricc           	      C   s   t | t | t | |   t| j@}t|\}}t|d| j d| d| d| j d| j d| d|g|d t|d	| j  t	|}|j
d
 W  d   S 1 sXw   Y  dS )aI  
        Evaluate model performance with a scoring metric.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-score.html
                A full list of supported options can be found under
                "Options for Recommendation Models" and
                "Options for Anomaly Detection Models"

        Args:
            table_name: Table with features and ground truth.
            target_column_name: Column of true target labels.
            metric: String name of the metric to compute.
            options: Optional dictionary of further scoring options.

        Returns:
            float: Computed score from the ML system.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError: If the table or target_column name or metric is not valid
        zCALL sys.ML_SCORE('r4   z', 'rx   z, %s, @ro   rp   r5   zSELECT @rd   N)r   r_   r   r2   r   r
   r7   r9   r:   r   re   )	r;   ri   rj   r{   rk   r<   rt   ru   rg   r"   r"   r#   _score  s6   "	$z_MyModelCommon._scorec                 C   s   t | t | |du rddi}|   t| j=}t|\}}t|d| j d| d| j d| j d| d| d	|d
 t|d| j d|  t|}|W  d   S 1 sYw   Y  dS )a  
        Produce explanations for model predictions on provided data.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
                A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"

        Args:
            table_name: Name of the SQL table with input data.
            output_table_name: Name for the SQL table to store explanations.
            options: Optional dictionary (default:
                {"prediction_explainer": "permutation_importance"}).

        Returns:
            DataFrame: Prediction explanations from the output SQL table.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
            ValueError: If the table or output_table name is not valid
        Nprediction_explainerpermutation_importancezCALL sys.ML_EXPLAIN_TABLE('r4   rx   ry   rn   rp   r5   rM   )	r   r_   r   r2   r   r
   r7   r9   r   )r;   ri   rw   rk   r<   rt   ru   rg   r"   r"   r#   _explain_predictions  s8   $z#_MyModelCommon._explain_predictions)r   N)r%   r&   r'   r(   r   r)   r   r   r    r   r=   boolrF   dictr\   r]   r8   r_   rA   rh   rv   rz   floatr|   pd	DataFramer   r"   r"   r"   r#   r1   Z   sp    

./
"
9
.
<r1   c                   @   s   e Zd ZdZ	ddeejejf de	eejejf  de	e
 ddfddZ	ddeejejf de	e
 dejfd	d
Z	ddeejejf deejejf dede	e
 def
ddZ	ddeejejf dedejfddZdS )MyModelz
    Convenience class for managing the ML workflow using pandas DataFrames.

    Methods convert in-memory DataFrames into temp SQL tables before delegating to the
    _MyMLModelCommon routines, and automatically clean up temp resources.
    NXyrk   r   c              	      s  t  t | }t| jr}t| jT}|durEt|tjr%|jd }nt fdd}| jv r:t	d| d 
 }|||< |}nd} }t|| j|\}	}
|| j|
f | |
|| W d   n1 skw   Y  W d   dS W d   dS 1 sw   Y  dS )a  
        Fit a model using DataFrame inputs.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-train.html
                A full list of supported options can be found under "Common ML_TRAIN Options"

        Args:
            X: Features DataFrame.
            y: (Optional) Target labels DataFrame or Series. If None, only X is used.
            options: Additional options to pass to training.

        Returns:
            None

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Combines X and y as necessary. Creates a temporary table in the schema for training,
            and deletes it afterward.
        Nr   c                    
   |  j vS rH   columnsnamer   r"   r#   <lambda>U     
 zMyModel.fit.<locals>.<lambda>zTarget column y with name z' already present in feature dataframe X)r	   r   r2   r   r   r   r   r   r   
ValueErrorrq   r   r7   appendrv   )r;   r   r   rk   r<   temporary_tablesrj   df_combinedfinal_df_ri   r"   r   r#   fit(  s4   !


PzMyModel.fitc              	      s   t |}tja tjK}t j|\}}|j|f t fdd}|j|f ||| t	 j|}|d 
tj|d< |W  d   W  d   S 1 s]w   Y  W d   dS 1 smw   Y  dS )a  
        Generate model predictions using DataFrame input.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
                A full list of supported options can be found under "ML_PREDICT_TABLE Options"

        Args:
            X: DataFrame containing prediction features (no labels).
            options: Additional prediction settings.

        Returns:
            DataFrame with prediction results as returned by HeatWave.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary SQL tables are created and deleted for input/output.
        c                       t  j|  S rH   r   r7   ri   r<   r;   r"   r#   r         z!MyModel.predict.<locals>.<lambda>
ml_resultsN)r	   r   r2   r   r   r7   r   r   rz   r   maprI   rJ   )r;   r   rk   r   r   ri   rw   predictionsr"   r   r#   predictj  s$    RzMyModel.predictr{   c              	      s   t  t | }t| jR}t| j<}t fdd}  }|||< |}	t|| j|	\}
}|| j|f | 	||||}|W  d   W  d   S 1 sSw   Y  W d   dS 1 scw   Y  dS )a  
        Score the model using X/y data and a selected metric.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-score.html
                A full list of supported options can be found under
                "Options for Recommendation Models" and
                "Options for Anomaly Detection Models"

        Args:
            X: DataFrame of features.
            y: DataFrame or Series of labels.
            metric: Metric name (e.g., "balanced_accuracy").
            options: Optional ml scoring options.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported,
                or if the model is not initialized, i.e., fit or import has not
                been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Returns:
            float: Computed score.
        c                    r   rH   r   r   r   r"   r#   r     r   zMyModel.score.<locals>.<lambda>N)
r	   r   r2   r   r   rq   r   r7   r   r|   )r;   r   r   r{   rk   r<   r   rj   r   r   r   ri   scorer"   r   r#   r     s    #RzMyModel.scorec              	      s   t |}tjP tj:}t j|\}}|j|f t fdd}|j|f |||}|W  d   W  d   S 1 sLw   Y  W d   dS 1 s\w   Y  dS )a  
        Explain model predictions using provided data.

        If an 'id' column is defined in either dataframe, it will be used as the primary key.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
                A full list of supported options can be found under
                "ML_EXPLAIN_TABLE Options"

        Args:
            X: DataFrame for which predictions should be explained.
            options: Optional dictionary of explainability options.

        Returns:
            DataFrame containing explanation details (feature attributions, etc.)

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported, or if the model is not initialized,
                i.e., fit or import has not been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary input/output tables are cleaned up after explanation.
        c                    r   rH   r   r   r   r"   r#   r     r   z-MyModel.explain_predictions.<locals>.<lambda>N)	r	   r   r2   r   r   r7   r   r   r   )r;   r   rk   r   r   ri   rw   explanationsr"   r   r#   explain_predictions  s$    RzMyModel.explain_predictionsrH   )r%   r&   r'   r(   r   r   r   npndarrayr   r   r   r   r    r   r   r   r   r"   r"   r"   r#   r      sP    
E
=
8r   )!r(   rq   rI   enumr   typingr   r   r   r   numpyr   pandasr   mysql.ai.utilsr   r   r	   r
   r   r   r   r   r   r   r   r   r   mysql.connector.abstractsr   r   r1   r   r"   r"   r"   r#   <module>   s   <   I