"""Session management for SimpleBench."""
from __future__ import annotations
from argparse import ArgumentError, ArgumentParser, Namespace
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional, Sequence
from rich.console import Console
from rich.progress import Progress
from simplebench import defaults
from simplebench.case import Case
from simplebench.doc_utils import format_docstring
from simplebench.enums import Color, Target, Verbosity
from simplebench.exceptions import SimpleBenchArgumentError, SimpleBenchTypeError, _SessionErrorTag
from simplebench.reporters.choice import Choice
from simplebench.reporters.choices import Choices
from simplebench.reporters.log.report_log_metadata import ReportLogMetadata
from simplebench.reporters.protocols import ReporterCallback
from simplebench.reporters.reporter_manager import ReporterManager
from simplebench.runners import SimpleRunner
from simplebench.tasks import ProgressTracker, RichProgressTasks
from simplebench.utils import sanitize_filename
if TYPE_CHECKING:
from simplebench.reporters.reporter import Reporter
class Session:
    """Container for session related information while running benchmarks.

    The session is responsible for managing benchmark cases, command line
    arguments, progress display, and report generation.

    This makes it the primary orchestrator for running benchmarks and generating
    reports.
    """

    @format_docstring(DEFAULT_TIMER=defaults.DEFAULT_TIMER.__name__)
    def __init__(self,
                 *,
                 cases: Optional[Sequence[Case]] = None,
                 verbosity: Verbosity = Verbosity.NORMAL,
                 default_runner: type[SimpleRunner] | None = None,
                 args_parser: Optional[ArgumentParser] = None,
                 show_progress: bool = False,
                 output_path: Optional[Path] = None,
                 console: Optional[Console] = None,
                 timer: Callable[[], int] | None = None) -> None:
        """Container and orchestrator for session related information while running benchmarks.

        :param cases: A Sequence of benchmark cases for the session.
            If None, the session starts with no cases. Defaults to None.
        :param verbosity: The verbosity level for console output.
            Defaults to :attr:`Verbosity.NORMAL`.
        :param default_runner: The default runner class to use
            for Cases that do not specify a runner. If None, the default :class:`~.runners.SimpleRunner` is used.
            Defaults to None.
        :param args_parser: The :class:`~argparse.ArgumentParser` instance for the
            session. If None, a new :class:`~argparse.ArgumentParser` will be automatically created.
            Defaults to None.
        :param show_progress: Whether to show progress bars during execution.
            Defaults to False.
        :param output_path: The output path for reports. Defaults to None.
        :param console: A Rich Console instance for displaying output. If None,
            a new Console will be automatically created. Defaults to None.
        :param timer: A callable that returns the current time for timing benchmarks.
            It is called once immediately to verify that it returns an int.
            If None, a default timer `simplebench.defaults.DEFAULT_TIMER` ({DEFAULT_TIMER})
            will be used. Defaults to None.
        :raises SimpleBenchTypeError: If the arguments are of the wrong type.
        """  # params here are for IDEs
        # public read/write properties with private backing fields (the setters validate)
        self.default_runner = default_runner
        self.args_parser = ArgumentParser() if args_parser is None else args_parser
        self.cases = [] if cases is None else cases
        self.verbosity = verbosity
        self.show_progress = show_progress
        self.output_path = output_path
        self.console = Console() if console is None else console
        self.timer = defaults.DEFAULT_TIMER if timer is None else timer

        # private attributes
        self._args_parsed: bool = False
        """Whether the command line arguments have been parsed."""
        self._reporter_flags_added: bool = False
        """Whether the reporter flags have been added to the ArgumentParser."""
        self._progress_tasks: RichProgressTasks = RichProgressTasks(verbosity=verbosity, console=self.console)
        """ProgressTasks instance for managing progress tasks - backing field for the 'tasks' attribute."""
        self._progress: Progress = self.tasks.progress
        """Rich Progress instance for displaying progress bars - backing field for the 'progress' attribute."""
        self._reporter_manager: ReporterManager = ReporterManager()
        """The ReporterManager instance for managing reporters."""
        self._choices: Choices = self._reporter_manager.choices
        """The Choices instance for managing registered reporters."""

        # backing fields for public read-only properties
        self._args: Optional[Namespace] = None
        """The command line arguments - backing field for the 'args' attribute."""
        # NOTE(review): this replaces the console assigned above with the Progress instance's
        # console; presumably the same object handed to RichProgressTasks - confirm.
        self._console: Console = self._progress.console
        """Rich Console instance for displaying output - backing field for the 'console' attribute."""

    def parse_args(self, args: Sequence[str] | None = None) -> None:
        """Parse the command line arguments using the session's :class:`~argparse.ArgumentParser`.

        This method parses the command line arguments and stores them in the session's :attr:`args` property.
        By default, it parses the arguments from :data:`sys.argv`. If ``args`` is provided, it will parse
        the arguments from the provided sequence of strings instead.

        This can be used to customize the command line arguments for testing or other purposes.

        It automatically calls :meth:`add_reporter_flags` to ensure that reporter flags
        are added to the ArgumentParser if it has not already been called.

        If you wish to customize the ArgumentParser before or after adding the reporter flags,
        you can do so by calling :meth:`add_reporter_flags` before calling this method.

        Once parsing has succeeded, further calls are no-ops. A call that fails validation
        leaves the session unparsed, so it can be retried.

        :param args: A list of command line arguments to parse. If None,
            the arguments will be taken from :data:`sys.argv`. Defaults to None.
        :type args: Sequence[str], optional
        :raises SimpleBenchTypeError: If ``args`` is not None and is not a non-string
            Sequence whose items are all str.
        :raises SimpleBenchArgumentError: If adding the reporter flags to the ArgumentParser fails.
        """
        if self._args_parsed:
            return
        if args is not None:
            # A bare str is itself a Sequence of str; accepting it would parse every
            # character as a separate command line argument.
            if isinstance(args, str) or not isinstance(args, Sequence):
                raise SimpleBenchTypeError(
                    "'args' argument must either be None or a list of str: "
                    f"type of passed 'args' was {type(args).__name__}",
                    tag=_SessionErrorTag.PARSE_ARGS_INVALID_ARGS_TYPE)
            args = tuple(args)
            if not all(isinstance(arg, str) for arg in args):
                raise SimpleBenchTypeError(
                    "'args' argument must either be None or a list of str: A non-str item was found in the passed list",
                    tag=_SessionErrorTag.PARSE_ARGS_INVALID_ARGS_TYPE)
        self.add_reporter_flags()
        self._args = self._args_parser.parse_args(args=args)
        # Only mark as parsed after parsing succeeded, so a rejected call does not
        # turn every later call (including the implicit one in run()) into a no-op.
        self._args_parsed = True

    @property
    def reporter_manager(self) -> ReporterManager:
        """Returns the :class:`~.reporters.reporter_manager.ReporterManager` instance for managing reporters.

        The reporter manager handles the registration and discovery of all reporters
        available to the session. This allows you to customize reporting behavior
        by adding your own reporters or removing default ones.

        :return: The :class:`~.reporters.reporter_manager.ReporterManager` instance for managing reporters.
        :rtype: ReporterManager
        """
        return self._reporter_manager

    def add_reporter_flags(self) -> None:
        """Add the command line flags for all registered reporters to the session's ArgumentParser.

        Any conflicts in flag names with already declared :class:`~argparse.ArgumentParser` flags will have to be
        handled by the reporters themselves.

        It has its own method so that a user can customize the :class:`~argparse.ArgumentParser`
        before or after adding the reporter flags as needed.

        It also allows the user to unregister reporters before adding the reporter flags if they
        want to omit specific built-in reporters entirely.

        It is called internally by :meth:`parse_args` if it has not already been called and
        does not re-add the reporter flags if called again for the same ArgumentParser.

        :raises SimpleBenchArgumentError: If there is a conflict or other error in reporter flag names.
        """
        if self._reporter_flags_added:
            return
        try:
            # Add reporter flags to the ArgumentParser based on command line args defined in each registered Choice
            self._reporter_manager.add_reporters_to_argparse(self._args_parser)
            self._reporter_flags_added = True
        except ArgumentError as arg_err:
            raise SimpleBenchArgumentError(
                argument_name=arg_err.argument_name,
                message=f'Error adding reporter flags to ArgumentParser: {arg_err.message}',
                tag=_SessionErrorTag.ARGUMENT_ERROR_ADDING_FLAGS
            ) from arg_err

    def run(self) -> None:
        """Run all benchmark cases in the session.

        This method iterates over all :class:`~.case.Case` instances in the session's
        cases and invokes their :meth:`~.Case.run` method to execute the benchmarks.

        If the :meth:`parse_args` method has not been called prior to invoking this method,
        it will be called automatically with no arguments to parse from :data:`sys.argv`.

        If you wish to customize argument parsing or run the session entirely programmatically
        without command line args, you should call :meth:`parse_args` before calling this method.

        The progress display is always stopped and cleared on exit, even if a case raises.

        :raises SimpleBenchTimeoutError: If a benchmark case times out during execution.
        :raises SimpleBenchBenchmarkError: If an error occurs during the execution of a benchmark.
        """
        if not self._args_parsed:
            self.parse_args()
        n_cases = len(self.cases)
        if self._verbosity > Verbosity.NORMAL:
            self._console.print(f'Running {n_cases} benchmark case(s)...')
        self.tasks.clear()
        progress_tracker = ProgressTracker(
            session=self,
            task_name='Session:cases',
            progress_max=n_cases,
            description='Running benchmark cases',
            color=Color.WHITE
        )
        if self.show_progress and self.verbosity > Verbosity.QUIET and self.tasks:
            self.tasks.start()
        try:
            progress_tracker.reset()
            progress_tracker.update(
                completed=0,
                description=f'Running benchmark cases (case {1:2d}/{n_cases})')
            progress_tracker.start()
            for index, case in enumerate(self.cases):
                progress_tracker.update(
                    description=f'Running benchmark cases (case {index + 1:2d}/{n_cases})',
                    completed=index,
                    refresh=True)
                case.run(session=self)
        finally:
            # Without this, an exception from a case would leave the live progress display running.
            progress_tracker.stop()
            self.tasks.stop()
            self.tasks.clear()

    def report_keys(self) -> list[str]:
        """Get a list of report keys for all reports to be generated in this session.

        This filters the report choices based on the command line arguments
        that were set and parsed when the session was created and returns a list of
        report keys for the reports that should be generated.

        The report keys correspond to the command line flags/args defined
        in the Choices of the registered reporters.

        .. note::

            This method requires that the session's :attr:`args` property
            has been set by calling :meth:`parse_args` prior to invoking this method.
            Otherwise it will return an empty list.

        :return: A list of report keys for all reports to be generated in this session.
        """
        # Keys whose Namespace value is missing or falsy are not requested by the user.
        return [key for key in self._choices.all_choice_args() if getattr(self.args, key, None)]

    def _report_output_path(self,
                            base_path: Path | None,
                            choice: Choice,
                            key: str,
                            case: Case,
                            timestamp: str) -> Path | None:
        """Return the output path handed to a Choice's reporter for one case.

        Choices that do not target the filesystem receive ``base_path`` unchanged (possibly None).
        Filesystem-targeting Choices receive ``base_path / timestamp / <sanitized case group>``.

        :param base_path: The session output path at the time reporting started.
        :param choice: The Choice whose report is being generated.
        :param key: The report key (Namespace arg name) that selected ``choice``; used in the error message.
        :param case: The Case being reported on.
        :param timestamp: The timestamp string used as the per-run directory name.
        :return: The output path for the report, or None.
        :raises SimpleBenchTypeError: If ``choice`` targets the filesystem and ``base_path`` is None.
        """
        if Target.FILESYSTEM not in choice.targets:
            return base_path
        if base_path is None:
            flag: str = '--' + key.replace('_', '-')
            raise SimpleBenchTypeError(
                f'output_path must be set to generate Choice {choice.name} / {flag} report',
                tag=_SessionErrorTag.REPORT_OUTPUT_PATH_NOT_SET
            )
        output_path = base_path / timestamp / sanitize_filename(case.group)
        if self.verbosity >= Verbosity.DEBUG:
            self._console.print(f"[DEBUG] Output path for report: {output_path}")
        return output_path

    def report(self) -> None:
        """Generate reports for all benchmark cases in the session.

        For every report key returned by :meth:`report_keys`, the associated
        :class:`~.reporters.choice.Choice` is looked up and its reporter is invoked once per
        :class:`~.case.Case` in the session. A Choice that can be triggered by several
        arguments is only run once.

        Reports for Choices that target the filesystem get an output path of
        ``<output_path>/<timestamp>/<sanitized case group>``, where ``timestamp`` is the local
        time at the start of this call formatted as ``YYYYMMDDhhmmss``. When an output path is
        set, report log metadata points at ``<output_path>/_reports_log``.

        The progress display is always stopped and cleared on exit, even if a reporter raises.

        :raises SimpleBenchTypeError: If the Choice retrieved for a report key is not a
            :class:`~.reporters.choice.Choice` instance, or if a filesystem-targeting Choice
            is selected while :attr:`output_path` is None.
        """
        # all_choice_args returns a set of all Namespace args from all Choice instances;
        # report_keys() keeps only those set in self.args. The logic here is that if the arg
        # is set, the user wants that report. By making the lookup go from the defined Choices
        # to the args, we ensure that we only consider valid args that are associated with a Choice.
        n_cases = len(self.cases)
        if self.verbosity > Verbosity.NORMAL:
            self._console.print(f"Generating reports for {n_cases} case(s)...")
        now = datetime.now()
        epoch_timestamp = now.timestamp()
        timestamp = now.strftime('%Y%m%d%H%M%S')
        base_output_path: Path | None = self._output_path
        # The reports log location does not depend on the case or Choice, so compute it once.
        report_log_path: Path | None = (
            base_output_path / "_reports_log" if base_output_path is not None else None)
        processed_choices: set[str] = set()
        report_keys: list[str] = self.report_keys()
        n_reports = len(report_keys)
        self.tasks.clear()
        reports_progress_tracker = ProgressTracker(
            session=self,
            task_name='Session:reports',
            progress_max=n_reports,
            description='Running reports',
            color=Color.WHITE
        )
        cases_progress_tracker = ProgressTracker(
            session=self,
            task_name='Session:cases',
            progress_max=n_cases,
            description='Generating reports for cases',
            color=Color.CYAN)
        reports_progress_tracker.start()
        try:
            for report_counter, key in enumerate(report_keys, start=1):
                if self.verbosity >= Verbosity.DEBUG:
                    self._console.print(f"[DEBUG] Checking report for arg '{key}'")
                choice: Choice | None = self._choices.get_choice_for_arg(key)
                if not isinstance(choice, Choice):
                    raise SimpleBenchTypeError(
                        "choice must be a Choice instance",
                        tag=_SessionErrorTag.REPORT_INVALID_CHOICE_RETRIEVED
                    )
                # If we have already processed this Choice (there can be multiple
                # possible valid triggering args defined for a single Choice), then skip it.
                if choice.name in processed_choices:
                    continue
                processed_choices.add(choice.name)
                reports_progress_tracker.update(
                    description=f'Running report {choice.name} ({report_counter:2d}/{n_reports})',
                    completed=report_counter - 1,
                    refresh=True)
                reporter: Reporter = choice.reporter
                cases_progress_tracker.reset()
                for case_counter, case in enumerate(self.cases, start=1):
                    cases_progress_tracker.update(
                        description=(
                            f'Generating reports for case {case.title} (case {case_counter:2d}/{n_cases})'),
                        completed=case_counter - 1,
                        refresh=True)
                    callback: Optional[ReporterCallback] = case.callback
                    output_path = self._report_output_path(base_output_path, choice, key, case, timestamp)
                    log_metadata = ReportLogMetadata(
                        timestamp=epoch_timestamp,
                        case=case,
                        choice=choice,
                        reports_log_path=report_log_path)
                    if self.args is not None:  # mypy guard
                        reporter.report(
                            log_metadata=log_metadata,
                            args=self.args,
                            case=case,
                            choice=choice,
                            path=output_path,
                            session=self,
                            callback=callback)
        finally:
            # Without this, an exception from a reporter would leave the live progress display running.
            cases_progress_tracker.stop()
            reports_progress_tracker.stop()
            self.tasks.stop()
            self.tasks.clear()

    @property
    def timer(self) -> Callable[[], int]:
        """The timer function used for benchmarking."""
        return self._timer

    @timer.setter
    def timer(self, value: Callable[[], int]) -> None:
        """Set the timer function used for benchmarking.

        The callable is invoked once here to check that it returns an int.

        :param value: The timer function used for benchmarking.
        :raises SimpleBenchTypeError: If the value is not a callable that returns an int.
        """
        if not callable(value):
            raise SimpleBenchTypeError(
                f'timer must be a callable - cannot be a {type(value)}',
                tag=_SessionErrorTag.PROPERTY_INVALID_TIMER_ARG
            )
        test_result = value()
        if not isinstance(test_result, int):
            raise SimpleBenchTypeError(
                f'timer callable must return an int - cannot return a {type(test_result)}',
                tag=_SessionErrorTag.PROPERTY_INVALID_TIMER_RETURN_TYPE
            )
        self._timer = value

    @property
    def default_runner(self) -> type[SimpleRunner] | None:
        """The session scoped default runner class to use for Cases that do not specify a runner."""
        return self._default_runner

    @default_runner.setter
    def default_runner(self, value: type[SimpleRunner] | None) -> None:
        """Set the session scoped default runner class to use for Cases that do not specify a runner.

        Example:

        .. code-block:: python

            from simplebench import Session
            from mybenchmark.runners import MyCustomRunner

            session = Session(default_runner=MyCustomRunner)

        :param value: The default runner class to use for Cases that do
            not specify a runner. Default is :class:`~.runners.SimpleRunner`.
        :type value: type[SimpleRunner] or None
        :raises SimpleBenchTypeError: If the value is not a subclass of :class:`~.runners.SimpleRunner` or None.
        """
        if value is not None and not (isinstance(value, type) and issubclass(value, SimpleRunner)):
            raise SimpleBenchTypeError(
                f'default_runner must be a subclass of SimpleRunner or None - cannot be a {type(value)}',
                tag=_SessionErrorTag.PROPERTY_INVALID_DEFAULT_RUNNER_ARG
            )
        self._default_runner = value

    @property
    def args_parser(self) -> ArgumentParser:
        """The ArgumentParser instance for the session."""
        return self._args_parser

    @args_parser.setter
    def args_parser(self, value: ArgumentParser) -> None:
        """Set the :class:`~argparse.ArgumentParser` instance for the session.

        :param value: The :class:`~argparse.ArgumentParser` instance for the session.
        :type value: ArgumentParser
        :raises SimpleBenchTypeError: If the value is not an :class:`~argparse.ArgumentParser`.
        """
        if not isinstance(value, ArgumentParser):
            raise SimpleBenchTypeError(
                f'args_parser must be an ArgumentParser instance - cannot be a {type(value)}',
                tag=_SessionErrorTag.PROPERTY_INVALID_ARGSPARSER_ARG
            )
        self._args_parser = value
        # Any reporter flags added so far went to the previous parser; allow
        # add_reporter_flags() (and therefore parse_args()) to add them to this one.
        self._reporter_flags_added = False

    @property
    def args(self) -> Optional[Namespace]:
        """The command line arguments for the session. This will be None until the parse_args()
        method has been called."""
        return self._args

    @args.setter
    def args(self, value: Namespace) -> None:
        """Set the command line arguments for the session.

        :param value: The command line arguments for the session.
        :type value: Namespace
        :raises SimpleBenchTypeError: If the value is not a :class:`~argparse.Namespace`.
        """
        if not isinstance(value, Namespace):
            raise SimpleBenchTypeError(
                f'args must be a Namespace instance - cannot be a {type(value)}',
                tag=_SessionErrorTag.PROPERTY_INVALID_ARGS_ARG
            )
        self._args = value

    @property
    def progress(self) -> Progress:
        """The Rich Progress instance for displaying progress bars."""
        return self._progress

    @property
    def show_progress(self) -> bool:
        """Whether to show progress bars during execution."""
        return self._show_progress

    @show_progress.setter
    def show_progress(self, value: bool) -> None:
        """Set whether to show progress bars during execution.

        :param value: Whether to show progress bars during execution.
        :type value: bool
        :raises SimpleBenchTypeError: If the value is not a bool.
        """
        if not isinstance(value, bool):
            raise SimpleBenchTypeError(
                f'show_progress must be a bool - cannot be a {type(value)}',
                tag=_SessionErrorTag.PROPERTY_INVALID_PROGRESS_ARG
            )
        self._show_progress = value

    @property
    def tasks(self) -> RichProgressTasks:
        """The RichProgressTasks instance for managing progress tasks."""
        return self._progress_tasks

    @property
    def verbosity(self) -> Verbosity:
        """The Verbosity level for this session."""
        return self._verbosity

    @verbosity.setter
    def verbosity(self, value: Verbosity) -> None:
        """Set the Verbosity level for this session.

        :param value: The new verbosity level for the session.
        :type value: Verbosity
        :raises SimpleBenchTypeError: If the value is not a :class:`~.enums.Verbosity` instance.
        """
        if not isinstance(value, Verbosity):
            raise SimpleBenchTypeError(
                f'verbosity must be a Verbosity instance - cannot be a {type(value)}',
                tag=_SessionErrorTag.PROPERTY_INVALID_VERBOSITY_ARG
            )
        self._verbosity = value

    @staticmethod
    def _validated_cases(value: Sequence[Case], arg_name: str) -> tuple[Case, ...]:
        """Check that ``value`` is a Sequence of :class:`~simplebench.Case` instances and return it as a tuple.

        :param value: The candidate Sequence of Cases.
        :param arg_name: The argument name used at the start of the non-Sequence error message.
        :return: ``value`` converted to a tuple.
        :raises SimpleBenchTypeError: If ``value`` is not a Sequence or contains a non-Case item.
        """
        if not isinstance(value, Sequence):
            raise SimpleBenchTypeError(
                f'{arg_name} must be a Sequence of Case - cannot be a {type(value)}',
                tag=_SessionErrorTag.PROPERTY_INVALID_CASES_ARG
            )
        validated = tuple(value)
        for case in validated:
            if not isinstance(case, Case):
                raise SimpleBenchTypeError(
                    f'items in Sequence must be Case instances - cannot be a {type(case)}',
                    tag=_SessionErrorTag.PROPERTY_INVALID_CASE_ARG_IN_SEQUENCE
                )
        return validated

    @property
    def cases(self) -> tuple[Case, ...]:
        """Tuple of Cases for this session."""
        return self._cases

    @cases.setter
    def cases(self, value: Sequence[Case]) -> None:
        """Set the tuple of :class:`~simplebench.Case` instances for this session.

        This replaces all existing Cases in the session.

        :param value: Sequence of Cases for the Session
        :type value: Sequence[Case]
        :raises SimpleBenchTypeError: If the value is not a :class:`Sequence` of :class:`~simplebench.Case` instances.
        """
        self._cases = self._validated_cases(value, 'value')

    def add_case(self, case: Case) -> None:
        """Add a :class:`~.case.Case` to the Cases for this session.

        :param case: benchmark case to add to the Session
        :raises SimpleBenchTypeError: If the value is not a :class:`~.case.Case` instance.
        """
        if not isinstance(case, Case):
            raise SimpleBenchTypeError(
                f'case must be a Case instance - cannot be a {type(case)}',
                tag=_SessionErrorTag.PROPERTY_INVALID_CASE_ARG
            )
        self._cases = (*self._cases, case)

    def extend_cases(self, cases: Sequence[Case]) -> None:
        """Extend the Cases for this session.

        :param cases: Sequence of Cases to add to the Session
        :type cases: Sequence[Case]
        :raises SimpleBenchTypeError: If the value is not a Sequence of Cases.
        """
        self._cases = self._cases + self._validated_cases(cases, 'cases')

    @property
    def output_path(self) -> Path | None:
        """The output path for reports."""
        return self._output_path

    @output_path.setter
    def output_path(self, value: Path | None) -> None:
        """Set the output path for reports.

        :param value: The output path for reports.
        :type value: Path or None
        :raises SimpleBenchTypeError: If the value is neither a :class:`~pathlib.Path` nor None.
        """
        if value is not None and not isinstance(value, Path):
            raise SimpleBenchTypeError(
                f'output_path must be a Path instance - cannot be a {type(value)}',
                tag=_SessionErrorTag.PROPERTY_INVALID_OUTPUT_PATH_ARG
            )
        self._output_path = value

    @property
    def console(self) -> Console:
        """The Rich Console instance for displaying output."""
        return self._console

    @console.setter
    def console(self, value: Console) -> None:
        """Set the Rich Console instance for displaying output.

        :param value: The Rich Console instance for displaying output.
        :type value: Console
        :raises SimpleBenchTypeError: If the value is not a Rich Console instance.
        """
        if not isinstance(value, Console):
            raise SimpleBenchTypeError(
                f'console must be a Console instance - cannot be a {type(value)}',
                tag=_SessionErrorTag.PROPERTY_INVALID_CONSOLE_ARG
            )
        self._console = value