Source code for aitoolbox.experiment.result_package.torch_metrics_packages

from aitoolbox.experiment.result_package.abstract_result_packages import AbstractResultPackage


class TorchMetricsPackage(AbstractResultPackage):
    """Result package that delegates metric calculation to torchmetrics objects.

    https://github.com/Lightning-AI/metrics
    """

    def __init__(self, torch_metrics):
        """Torch Metrics result package wrapper

        https://github.com/Lightning-AI/metrics

        Args:
            torch_metrics (torchmetrics.Metric or torchmetrics.MetricCollection):
                single torchmetrics metric object or a collection of such
                metrics wrapped inside the MetricCollection
        """
        AbstractResultPackage.__init__(
            self,
            pkg_name='Torch Metrics',
            np_array=False,
        )
        self.metric = torch_metrics

    def prepare_results_dict(self):
        """Evaluate the wrapped metric and return its results as a dict.

        The metric is called with ``self.y_predicted`` and ``self.y_true``.
        Neither attribute is assigned in this class (presumably they are
        populated by the ``AbstractResultPackage`` machinery - confirm
        against the base class).

        Returns:
            dict: metric results where every key carries the ``_PTLMetrics``
            suffix. When the metric call does not return a dict, its single
            result is stored under the metric object's class name.
        """
        raw_output = self.metric(self.y_predicted, self.y_true)

        if isinstance(raw_output, dict):
            named_results = raw_output
        else:
            # Non-dict output: key the lone result by the metric's class name
            metric_cls = self.metric.__class__
            named_results = {metric_cls.__name__: raw_output}

        # The PTLMetrics suffix marks values produced by PyTorch Lightning
        # metrics rather than by aitoolbox's own metric implementations
        return {
            f'{name}_PTLMetrics': value
            for name, value in named_results.items()
        }

    def metric_compute(self):
        """Return the result of calling ``compute()`` on the wrapped metric."""
        wrapped_metric = self.metric
        return wrapped_metric.compute()

    def metric_reset(self):
        """Call ``reset()`` on the wrapped metric."""
        wrapped_metric = self.metric
        wrapped_metric.reset()