# -*- coding: utf-8 -*-
"""Reporter for benchmark results using graphs."""
from __future__ import annotations
from io import BytesIO
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from simplebench.enums import Section
from simplebench.exceptions import SimpleBenchTypeError
from simplebench.reporters.reporter import ReporterOptions
from simplebench.results import Results
from simplebench.si_units import si_scale_for_largest
from simplebench.type_proxies import is_case
from simplebench.validators import validate_type
from ...matplotlib import MatPlotLibReporter
from .config import ScatterPlotConfig
from .exceptions import _ScatterPlotReporterErrorTag
from .options import ScatterPlotOptions
# Short alias for the options class this reporter accepts; ``render`` validates
# its ``options`` argument against it.
Options: TypeAlias = ScatterPlotOptions

if TYPE_CHECKING:
    # Imported for type annotations only; not imported at runtime.
    from simplebench.case import Case
class ScatterPlotReporter(MatPlotLibReporter):
    """Class for outputting benchmark results as scatter plot graphs.

    This reporter generates scatter plot visualizations for various result sections,
    saving them to the filesystem or passing them to a callback function. It provides
    a visual way to compare the performance of different benchmark variations.

    **Defined command-line flags:**

    * ``--scatter-plot: {filesystem, callback}`` (default=filesystem)
    * ``--scatter-plot.ops: ...``
    * ``--scatter-plot.timings: ...``
    * ``--scatter-plot.memory: ...``

    **Example usage:**

    .. code-block:: none

        program.py --scatter-plot  # Outputs graphs to the filesystem.
        program.py --scatter-plot.ops filesystem  # Outputs only ops graphs to the filesystem.
    """

    _OPTIONS_TYPE: ClassVar[type[ScatterPlotOptions]] = ScatterPlotOptions  # pylint: disable=line-too-long # type: ignore[reportIncompatibleVariableOverride] # noqa: E501
    """:ivar: The specific :class:`~.ReporterOptions` subclass associated with this reporter.

    :vartype: ~typing.ClassVar[type[~.ScatterPlotOptions]]
    """

    _OPTIONS_KWARGS: ClassVar[dict[str, Any]] = ScatterPlotOptions.DEFAULT_KWARGS
    """:ivar: The default keyword arguments for the :class:`~.ScatterPlotOptions` subclass.

    :vartype: ~typing.ClassVar[dict[str, ~typing.Any]]
    """

    def __init__(self, config: ScatterPlotConfig | None = None) -> None:
        """Initialize the ScatterPlotReporter.

        .. note::

            The exception documentation below refers to validation of subclass configuration
            class variables ``_OPTIONS_TYPE`` and ``_OPTIONS_KWARGS``. These must be correctly
            defined in any subclass of :class:`ScatterPlotReporter` to ensure proper
            functionality.

        :param config: An optional configuration object to override default reporter settings.
            If not provided, a default :class:`~.ScatterPlotConfig` is used.
        :type config: ScatterPlotConfig | None
        :raises ~simplebench.exceptions.SimpleBenchTypeError: If the subclass configuration
            types are invalid.
        :raises ~simplebench.exceptions.SimpleBenchValueError: If the subclass configuration
            values are invalid.
        """
        if config is None:
            config = ScatterPlotConfig()
        super().__init__(config)

    def _validate_render_args(
            self, *, case: Case, section: Section,
            options: ReporterOptions) -> tuple[Section, ScatterPlotOptions]:
        """Validate the arguments shared by :meth:`render` and :meth:`_old_render`.

        Checks are performed in the order ``case``, ``section``, ``options``.

        :param case: Must be a :class:`~simplebench.case.Case` instance.
        :param section: Must be a :class:`~simplebench.enums.Section`.
        :param options: Must be a :class:`~.ScatterPlotOptions` instance.
        :return: The validated ``(section, options)`` pair, as returned by ``validate_type``.
        :raises ~simplebench.exceptions.SimpleBenchTypeError: If ``case`` is not a Case
            instance, or if ``validate_type`` rejects ``section`` or ``options``.
        """
        # is_* checks provide deferred import validation to avoid circular imports
        if not is_case(case):
            raise SimpleBenchTypeError(
                f"'case' argument must be a Case instance, got {type(case)}",
                tag=_ScatterPlotReporterErrorTag.RENDER_INVALID_CASE)
        validated_section = validate_type(
            section, Section, 'section',
            _ScatterPlotReporterErrorTag.RENDER_INVALID_SECTION)
        validated_options = validate_type(
            options, Options, 'options',
            _ScatterPlotReporterErrorTag.RENDER_INVALID_OPTIONS)
        return validated_section, validated_options

    def render(self, *, case: Case, section: Section, options: ReporterOptions) -> bytes:
        """Render the scatter plot graph and return it as bytes.

        Each result in ``case.results`` contributes one point: its ``n`` value on the
        x-axis (labelled ``N``) against the mean of the requested section on the y-axis.
        All y-values are rescaled to the single unit chosen by
        :func:`~simplebench.si_units.si_scale_for_largest`, and that unit is shown in
        the y-axis label.

        The graph is drawn on a freshly created figure (never on pyplot's current
        figure). That figure is always closed before returning, even if rendering fails.

        :param case: The :class:`~simplebench.case.Case` instance representing the
            benchmarked code.
        :param section: The section of the results to plot.
        :param options: The options for rendering the scatter plot.
        :return: The rendered graph as bytes. The format is determined by the options.
            The defaults are defined in :class:`~.ScatterPlotOptions`.
        :raises ~simplebench.exceptions.SimpleBenchTypeError: If the provided arguments are not
            of the expected types or values.
        :raises ~simplebench.exceptions.SimpleBenchValueError: If the provided values are not
            valid.
        """
        section, options = self._validate_render_args(
            case=case, section=section, options=options)

        base_unit = self.get_base_unit_for_section(section=section)
        results: list[Results] = case.results
        all_numbers = self.get_all_stats_values(results=results, section=section)
        common_unit, common_scale = si_scale_for_largest(numbers=all_numbers, base_unit=base_unit)

        x_axis_legend = 'N'
        target_name = f'{section.value} ({common_unit})'
        plot_data = [
            (result.n, result.results_section(section).mean * common_scale)
            for result in results
        ]
        df = pd.DataFrame(plot_data, columns=[x_axis_legend, target_name])

        figure_width = options.width / options.dpi  # inches
        figure_height = options.height / options.dpi  # inches

        with BytesIO() as graphfile:
            # rc_context restores the global rcParams on exit, so the theme applied
            # here does not leak into other matplotlib users.
            # See https://matplotlib.org/stable/users/explain/customizing.html#the-matplotlibrc-file
            with mpl.rc_context():
                mpl.rcParams.update(options.theme)
                with plt.style.context(options.style):
                    # Use a dedicated figure instead of letting seaborn draw on pyplot's
                    # current axes, which could belong to an unrelated open figure.
                    figure, axes = plt.subplots(
                        figsize=(figure_width, figure_height), dpi=options.dpi)
                    try:
                        sns.scatterplot(data=df, x=x_axis_legend, y=target_name, ax=axes)
                        figure.suptitle(case.title, fontsize='large', weight='bold')
                        figure.subplots_adjust(top=.9)
                        axes.tick_params("x", rotation=options.x_labels_rotation)
                        if options.y_starts_at_zero:
                            _, top = axes.get_ylim()
                            axes.set_ylim(bottom=0, top=top * 1.10)  # Add 10% headroom
                        figure.savefig(graphfile, format=options.image_type)
                    finally:
                        # Release the figure even on failure so pyplot does not keep it alive.
                        plt.close(figure)
            return graphfile.getvalue()

    def _old_render(self, *, case: Case, section: Section, options: ReporterOptions) -> bytes:
        """Render a legacy scatter plot of section means per benchmark variation.

        This is a legacy rendering path kept alongside :meth:`render`, which does not
        call it. Unlike :meth:`render`, each point sits at a categorical x position. That
        position is labelled with the result's variation marks (newline-joined) rather
        than its ``n`` value. The figure is built with :func:`seaborn.relplot`.

        :param case: The :class:`~simplebench.case.Case` instance representing the
            benchmarked code.
        :param section: The section of the results to plot.
        :param options: The options for rendering the scatter plot.
        :return: The rendered graph as bytes. The format is determined by the options.
            The defaults are defined in :class:`~.ScatterPlotOptions`.
        :raises ~simplebench.exceptions.SimpleBenchTypeError: If the provided arguments are not
            of the expected types or values.
        :raises ~simplebench.exceptions.SimpleBenchValueError: If the provided values are not
            valid.
        """
        section, options = self._validate_render_args(
            case=case, section=section, options=options)

        base_unit = self.get_base_unit_for_section(section=section)
        results: list[Results] = case.results
        all_numbers = self.get_all_stats_values(results=results, section=section)
        common_unit, common_scale = si_scale_for_largest(numbers=all_numbers, base_unit=base_unit)
        # Values are scaled to common_unit, so the axis label names common_unit too;
        # this keeps it in agreement with the tick-label suffix set below.
        target_name = f'{section.value} ({common_unit})'
        x_axis_legend = '\n'.join(f"{label}" for label in case.variation_cols.values())
        plot_data = [
            {
                x_axis_legend: '\n'.join(f"{mark}" for mark in result.variation_marks.values()),
                target_name: result.results_section(section).mean * common_scale,
            }
            for result in results
        ]
        df = pd.DataFrame(plot_data)

        with BytesIO() as graphfile:
            # See https://matplotlib.org/stable/users/explain/customizing.html#the-matplotlibrc-file
            with mpl.rc_context():
                mpl.rcParams.update(options.theme)
                with plt.style.context(options.style):
                    # relplot creates its own figure, wrapped in the returned grid object.
                    g = sns.relplot(data=df, y=target_name, x=x_axis_legend)
                    try:
                        g.figure.suptitle(case.title, fontsize='large', weight='bold')
                        g.figure.subplots_adjust(top=.9)
                        g.figure.set_dpi(options.dpi)  # dots per inch
                        g.figure.set_figheight(options.height / options.dpi)  # inches
                        g.figure.set_figwidth(options.width / options.dpi)  # inches
                        g.tick_params("x", rotation=options.x_labels_rotation)
                        for ax in g.axes.flat:
                            # Tick labels render as "<value> <unit>" via a str.format template.
                            ax.yaxis.set_major_formatter('{x}' + f' {common_unit}')
                            if options.y_starts_at_zero:
                                _, top = ax.get_ylim()
                                ax.set_ylim(bottom=0, top=top * 1.10)  # Add 10% headroom
                        g.figure.savefig(graphfile, format=options.image_type)
                    finally:
                        # Release the figure even on failure so pyplot does not keep it alive.
                        plt.close(g.figure)
            return graphfile.getvalue()