Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@ Convenience Tools and Functions
.. automodule:: plot2n
:members: Plot2N
:show-inheritance:

.. automodule:: libensemble.tools.live_data.rich_progress
:members: RichProgress
:show-inheritance:
94 changes: 94 additions & 0 deletions libensemble/tests/unit_tests/test_rich_progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Unit tests for RichProgress live data class."""

import pytest


class MockHist:
"""Minimal mock for libensemble History used in tests."""

def __init__(self, sim_ended_count=0, gen_informed_count=0):
self.sim_ended_count = sim_ended_count
self.gen_informed_count = gen_informed_count


def test_rich_progress_sim_max():
"""RichProgress tracks sim_ended_count when sim_max is set."""
pytest.importorskip("rich")
from libensemble.tools.live_data.rich_progress import RichProgress

exit_criteria = {"sim_max": 10}
rp = RichProgress(exit_criteria)
assert rp._sim_max == 10
assert rp._total == 10
assert rp._description == "Sims completed"

hist = MockHist(sim_ended_count=5)
rp.live_update(hist)
task = rp._progress.tasks[rp._task_id]
assert task.completed == 5

rp.finalize(hist)


def test_rich_progress_gen_max():
"""RichProgress tracks gen_informed_count when gen_max is set."""
pytest.importorskip("rich")
from libensemble.tools.live_data.rich_progress import RichProgress

exit_criteria = {"gen_max": 20}
rp = RichProgress(exit_criteria)
assert rp._gen_max == 20
assert rp._total == 20
assert rp._description == "Gen points informed"

hist = MockHist(gen_informed_count=7)
rp.live_update(hist)
task = rp._progress.tasks[rp._task_id]
assert task.completed == 7

rp.finalize(hist)


def test_rich_progress_no_criteria():
"""RichProgress works without exit_criteria (unbounded spinner)."""
pytest.importorskip("rich")
from libensemble.tools.live_data.rich_progress import RichProgress

rp = RichProgress()
assert rp._total is None
assert rp._description == "Running"

hist = MockHist(sim_ended_count=3)
rp.live_update(hist)
task = rp._progress.tasks[rp._task_id]
assert task.completed == 3

rp.finalize(hist)


def test_rich_progress_sim_max_priority():
"""sim_max takes priority over gen_max when both are set."""
pytest.importorskip("rich")
from libensemble.tools.live_data.rich_progress import RichProgress

exit_criteria = {"sim_max": 10, "gen_max": 20}
rp = RichProgress(exit_criteria)
assert rp._sim_max == 10
assert rp._total == 10
assert rp._description == "Sims completed"

rp.finalize(MockHist())


def test_rich_progress_pydantic_exit_criteria():
"""RichProgress works with pydantic ExitCriteria object."""
pytest.importorskip("rich")
from libensemble.specs import ExitCriteria
from libensemble.tools.live_data.rich_progress import RichProgress

exit_criteria = ExitCriteria(sim_max=50)
rp = RichProgress(exit_criteria)
assert rp._sim_max == 50
assert rp._total == 50

rp.finalize(MockHist())
95 changes: 95 additions & 0 deletions libensemble/tools/live_data/rich_progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Optional rich.Progress progress bar for libEnsemble CLI runs."""

try:
from rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
except ImportError:
raise ImportError("The 'rich' package is required for RichProgress. Install it with: pip install rich")

from libensemble.tools.live_data.live_data import LiveData


class RichProgress(LiveData):
"""Display a rich progress bar in the terminal during an ensemble run.

Shows progress toward ``sim_max`` (simulation completions) or
``gen_max`` (generator points informed), whichever is set in
``exit_criteria``. If both are set, ``sim_max`` takes priority.

Parameters
----------

exit_criteria : dict or :class:`libensemble.specs.ExitCriteria`
The exit criteria used for the run. Must contain either
``sim_max`` or ``gen_max`` to show a bounded progress bar.
If neither is set, an unbounded spinner is displayed instead.

Examples
--------

.. code-block:: python

from libensemble.tools.live_data.rich_progress import RichProgress

libE_specs["live_data"] = RichProgress(exit_criteria)
"""

def __init__(self, exit_criteria=None):
"""Initialize a RichProgress bar.

Parameters
----------
exit_criteria : dict or ExitCriteria, optional
Used to determine the total (sim_max or gen_max).
"""
self._sim_max = None
self._gen_max = None

if exit_criteria is not None:
if hasattr(exit_criteria, "sim_max"):
# Pydantic model (ExitCriteria)
self._sim_max = exit_criteria.sim_max
self._gen_max = exit_criteria.gen_max
else:
# Plain dict
self._sim_max = exit_criteria.get("sim_max")
self._gen_max = exit_criteria.get("gen_max")

if self._sim_max is not None:
total = self._sim_max
description = "Sims completed"
elif self._gen_max is not None:
total = self._gen_max
description = "Gen points informed"
else:
total = None
description = "Running"

self._description = description
self._total = total
self._progress = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
MofNCompleteColumn(),
TimeElapsedColumn(),
)
self._task_id = self._progress.add_task(description, total=total)
self._progress.start()

def live_update(self, hist) -> None:
"""Update the progress bar based on the latest history counts."""
if self._sim_max is not None:
completed = hist.sim_ended_count
elif self._gen_max is not None:
completed = hist.gen_informed_count
else:
completed = hist.sim_ended_count

self._progress.update(self._task_id, completed=completed)

def finalize(self, hist) -> None:
"""Stop the progress bar display."""
# Ensure bar reaches 100% if total is known
if self._total is not None:
self._progress.update(self._task_id, completed=self._total)
self._progress.stop()