Source code for aitoolbox.experiment.result_package.basic_packages

from aitoolbox.experiment.result_package.abstract_result_packages import AbstractResultPackage
from aitoolbox.experiment.core_metrics.abstract_metric import AbstractBaseMetric
from aitoolbox.experiment.core_metrics.classification import AccuracyMetric, ROCAUCMetric, \
    PrecisionRecallCurveAUCMetric, F1ScoreMetric
from aitoolbox.experiment.core_metrics.regression import MeanSquaredErrorMetric, MeanAbsoluteErrorMetric


class GeneralResultPackage(AbstractResultPackage):
    def __init__(self, metrics_list, strict_content_check=False, **kwargs):
        """Result package executing the given list of metrics

        Args:
            metrics_list (list): list of metrics inherited from
                aitoolbox.experiment.core_metrics.abstract_metric.AbstractBaseMetric
            strict_content_check (bool): should just print a warning or raise an error and crash
            **kwargs (dict): additional package_metadata for the result package
        """
        AbstractResultPackage.__init__(self, pkg_name='GeneralResultPackage',
                                       strict_content_check=strict_content_check, **kwargs)
        self.metrics_list = metrics_list

    def prepare_results_dict(self):
        self.qa_check_metrics_list()
        results_dict = {}

        for metric in self.metrics_list:
            # Instantiate the metric on the current predictions; metric objects
            # support dict-merging addition into the accumulated results
            metric_result = metric(self.y_true, self.y_predicted)
            results_dict = results_dict + metric_result

        return results_dict

    def qa_check_metrics_list(self):
        if len(self.metrics_list) == 0:
            self.warn_about_result_data_problem('Metrics list is empty')

        for metric in self.metrics_list:
            if not isinstance(metric, AbstractBaseMetric):
                self.warn_about_result_data_problem('Metric is not inherited from AbstractBaseMetric class')
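

# Illustrative usage sketch, not part of the library module. In a real run the
# training-loop side of the library fills in y_true / y_predicted; assigning
# them directly, as below, is an assumption made purely for demonstration, and
# the values are made up. metrics_list is assumed to hold metric classes,
# matching the `metric(self.y_true, self.y_predicted)` instantiation in the
# loop above (the isinstance QA check then only prints a warning when
# strict_content_check=False).
import numpy as np

general_pkg = GeneralResultPackage(metrics_list=[AccuracyMetric])
general_pkg.y_true = np.array([0, 1, 1, 0])
general_pkg.y_predicted = np.array([0, 1, 0, 0])
general_results = general_pkg.prepare_results_dict()  # e.g. {'Accuracy': 0.75}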


class BinaryClassificationResultPackage(AbstractResultPackage):
    def __init__(self, positive_class_thresh=0.5, strict_content_check=False, **kwargs):
        """Binary classification task result package

        Evaluates the following metrics: accuracy, ROC-AUC, PR-AUC and F1 score.

        Args:
            positive_class_thresh (float or None): predicted probability threshold for the positive class
            strict_content_check (bool): should just print a warning or raise an error and crash
            **kwargs (dict): additional package_metadata for the result package
        """
        AbstractResultPackage.__init__(self, pkg_name='BinaryClassificationResult',
                                       strict_content_check=strict_content_check, **kwargs)
        self.positive_class_thresh = positive_class_thresh

    def prepare_results_dict(self):
        accuracy_result = AccuracyMetric(self.y_true, self.y_predicted,
                                         positive_class_thresh=self.positive_class_thresh)
        roc_auc_result = ROCAUCMetric(self.y_true, self.y_predicted)
        pr_auc_result = PrecisionRecallCurveAUCMetric(self.y_true, self.y_predicted)
        f1_score_result = F1ScoreMetric(self.y_true, self.y_predicted,
                                        positive_class_thresh=self.positive_class_thresh)

        # Metric addition merges the individual results into a single dict
        return accuracy_result + roc_auc_result + pr_auc_result + f1_score_result
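

# Illustrative usage sketch, not part of the library module: y_predicted is
# assumed to hold positive-class probabilities, which accuracy and F1 threshold
# at positive_class_thresh while ROC-AUC and PR-AUC consume them directly.
# The attribute assignment and example values below are for demonstration only.
import numpy as np

binary_pkg = BinaryClassificationResultPackage(positive_class_thresh=0.5)
binary_pkg.y_true = np.array([0, 1, 1, 0, 1])
binary_pkg.y_predicted = np.array([0.2, 0.8, 0.6, 0.4, 0.3])
binary_results = binary_pkg.prepare_results_dict()
# -> single dict combining the accuracy, ROC-AUC, PR-AUC and F1 results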


class ClassificationResultPackage(AbstractResultPackage):
    def __init__(self, strict_content_check=False, **kwargs):
        """Multi-class classification result package

        Evaluates the accuracy of the predictions. The Precision-Recall metric is omitted, as it
        is available only for binary classification problems.

        Args:
            strict_content_check (bool): should just print a warning or raise an error and crash
            **kwargs (dict): additional package_metadata for the result package
        """
        AbstractResultPackage.__init__(self, pkg_name='ClassificationResult',
                                       strict_content_check=strict_content_check, **kwargs)

    def prepare_results_dict(self):
        # positive_class_thresh=None: predictions are treated as class labels
        # rather than probabilities to be thresholded
        accuracy_result = AccuracyMetric(self.y_true, self.y_predicted,
                                         positive_class_thresh=None).get_metric_dict()
        return accuracy_result
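

# Illustrative usage sketch, not part of the library module: with
# positive_class_thresh=None no probability thresholding is applied, so
# y_predicted is assumed here to hold predicted class indices. The attribute
# assignment and values are for demonstration only.
import numpy as np

clf_pkg = ClassificationResultPackage()
clf_pkg.y_true = np.array([0, 2, 1, 1])
clf_pkg.y_predicted = np.array([0, 2, 2, 1])
clf_results = clf_pkg.prepare_results_dict()  # e.g. {'Accuracy': 0.75}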


class RegressionResultPackage(AbstractResultPackage):
    def __init__(self, strict_content_check=False, **kwargs):
        """Regression task result package

        Evaluates MSE and MAE metrics.

        Args:
            strict_content_check (bool): should just print a warning or raise an error and crash
            **kwargs (dict): additional package_metadata for the result package
        """
        AbstractResultPackage.__init__(self, pkg_name='RegressionResult',
                                       strict_content_check=strict_content_check, **kwargs)

    def prepare_results_dict(self):
        mse_result = MeanSquaredErrorMetric(self.y_true, self.y_predicted)
        mae_result = MeanAbsoluteErrorMetric(self.y_true, self.y_predicted)
        return mse_result + mae_result
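

# Illustrative usage sketch, not part of the library module: continuous targets
# and predictions evaluated with MSE and MAE. The attribute assignment and
# values are for demonstration only.
import numpy as np

reg_pkg = RegressionResultPackage()
reg_pkg.y_true = np.array([1.0, 2.0, 3.0])
reg_pkg.y_predicted = np.array([1.1, 1.9, 3.4])
reg_results = reg_pkg.prepare_results_dict()
# -> single dict combining the MSE and MAE results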