Coverage for mlos_viz/mlos_viz/__init__.py: 100%

25 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

"""mlos_viz is a framework to help visualize, explain, and gain insights from results

6from the mlos_bench framework for benchmarking and optimization automation. 

7""" 

8 

9from enum import Enum 

10from typing import Any, Dict, Literal, Optional 

11 

12import pandas 

13 

14from mlos_bench.storage.base_experiment_data import ExperimentData 

15from mlos_viz import base 

16from mlos_viz.util import expand_results_data_args 

17from mlos_viz.version import VERSION 

18 

# Re-export the package version (imported from mlos_viz.version) under the conventional dunder name.
__version__ = VERSION

20 

21 

class MlosVizMethod(Enum):
    """
    Enumeration of the back-ends available for visualizing experiment results.

    ``AUTO`` is not a separate member: it is an alias that resolves to whichever
    back-end is currently the default (presently ``DABL``).
    """

    DABL = "dabl"

    # Alias of DABL (same value), so MlosVizMethod.AUTO is MlosVizMethod.DABL.
    AUTO = DABL

27 

28 

def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO) -> None:
    """
    Register warning filters that silence noisy messages from the third-party
    data visualization packages used by mlos_viz.

    The plotter-agnostic filters from ``mlos_viz.base`` are always installed
    first; the filters specific to ``plotter_method`` are installed afterwards.

    Parameters
    ----------
    plotter_method: MlosVizMethod
        The method to use for visualizing the experiment results.

    Raises
    ------
    NotImplementedError
        If ``plotter_method`` is not a supported visualization method
        (raised after the base filters have already been installed).
    """
    base.ignore_plotter_warnings()

    # Guard clause: DABL is currently the only supported back-end.
    if plotter_method != MlosVizMethod.DABL:
        raise NotImplementedError(f"Unhandled method: {plotter_method}")

    # Imported lazily so the dabl back-end is only loaded when it is requested.
    import mlos_viz.dabl  # pylint: disable=import-outside-toplevel

    mlos_viz.dabl.ignore_plotter_warnings()

46 

47 

def plot(
    exp_data: Optional[ExperimentData] = None,
    *,
    results_df: Optional[pandas.DataFrame] = None,
    objectives: Optional[Dict[str, Literal["min", "max"]]] = None,
    plotter_method: MlosVizMethod = MlosVizMethod.AUTO,
    filter_warnings: bool = True,
    **kwargs: Any,
) -> None:
    """
    Plots the results of the experiment.

    Intended to be used from a Jupyter notebook.

    Parameters
    ----------
    exp_data: Optional[ExperimentData]
        The experiment data to plot.
        May be None; presumably ``results_df`` and ``objectives`` must then be
        supplied explicitly (resolution is delegated to ``expand_results_data_args``).
    results_df : Optional["pandas.DataFrame"]
        Optional results_df to plot.
        If not provided, defaults to exp_data.results_df property.
    objectives : Optional[Dict[str, Literal["min", "max"]]]
        Optional objectives to plot.
        If not provided, defaults to exp_data.objectives property.
    plotter_method: MlosVizMethod
        The method to use for visualizing the experiment results.
    filter_warnings: bool
        Whether or not to filter some warnings from the plotter.
    kwargs : dict
        Remaining keyword arguments. Currently these are forwarded only to
        ``base.plot_top_n_configs``.

    Raises
    ------
    NotImplementedError
        If ``plotter_method`` is not a supported visualization method.
    """
    if filter_warnings:
        ignore_plotter_warnings(plotter_method)
    # Fill in results_df from exp_data when it was not passed explicitly.
    (results_df, _obj_cols) = expand_results_data_args(exp_data, results_df, objectives)

    # Plots common to every plotter method.
    base.plot_optimizer_trends(exp_data, results_df=results_df, objectives=objectives)
    base.plot_top_n_configs(exp_data, results_df=results_df, objectives=objectives, **kwargs)

    # Dispatch on the *requested* method. A bare ``if MlosVizMethod.DABL:`` is
    # always true (plain Enum members are truthy), which would ignore
    # plotter_method and make the NotImplementedError branch unreachable.
    if plotter_method == MlosVizMethod.DABL:
        import mlos_viz.dabl  # pylint: disable=import-outside-toplevel

        mlos_viz.dabl.plot(exp_data, results_df=results_df, objectives=objectives)
    else:
        raise NotImplementedError(f"Unhandled method: {plotter_method}")