Coverage for mlos_viz/mlos_viz/dabl.py: 100%

22 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Small wrapper functions for plotting :py:mod:`mlos_bench` data via 

7:external:py:func:`dabl.plot`. 

8 

9Notes 

10----- 

11See `dabl <https://dabl.github.io/stable/>`_ for more information on the dabl library. 

12""" 

13import warnings 

14from typing import Dict, Literal, Optional 

15 

16import dabl 

17import pandas 

18 

19from mlos_bench.storage.base_experiment_data import ExperimentData 

20from mlos_viz.util import expand_results_data_args 

21 

22 

def plot(
    exp_data: Optional[ExperimentData] = None,
    *,
    results_df: Optional[pandas.DataFrame] = None,
    objectives: Optional[Dict[str, Literal["min", "max"]]] = None,
) -> None:
    """
    Plot :py:mod:`mlos_bench` experiment results with :external:py:func:`dabl.plot`.

    :external:py:func:`dabl.plot` is invoked once per objective column, each
    time with the full results DataFrame and that column as the target.

    Parameters
    ----------
    exp_data : Optional[ExperimentData]
        The ExperimentData (e.g., obtained from the storage layer) to plot.
    results_df : Optional[pandas.DataFrame]
        Optional results DataFrame to plot.
        If not provided, defaults to the exp_data.results_df property.
    objectives : Optional[Dict[str, Literal["min", "max"]]]
        Optional objectives to plot.
        If not provided, defaults to the exp_data.objectives property.
    """
    # Work out which DataFrame and which objective columns to use from whatever
    # combination of arguments the caller supplied.
    expanded_df, target_columns = expand_results_data_args(
        exp_data,
        results_df,
        objectives,
    )
    for target in target_columns:
        dabl.plot(X=expanded_df, target_col=target)

47 

48 

def ignore_plotter_warnings() -> None:
    """
    Add some filters to ignore noisy warnings emitted while plotting with dabl.

    Notes
    -----
    This mutates the process-wide :py:mod:`warnings` filter list, so the
    filters apply to all subsequent code running in the interpreter, not only
    to calls made through this module. In particular, *every*
    :py:class:`FutureWarning` is ignored, regardless of which module issues it.
    """
    # pylint: disable=import-outside-toplevel
    # NOTE: ``module=`` is a regex that must match the *start* of the issuing
    # module's name, so e.g. "dabl" also covers dabl's submodules.
    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.filterwarnings(
        "ignore",
        module="dabl",
        category=UserWarning,
        message="Could not infer format",
    )
    warnings.filterwarnings(
        "ignore",
        module="dabl",
        category=UserWarning,
        message="(Dropped|Discarding) .* outliers",
    )
    warnings.filterwarnings(
        "ignore",
        module="dabl",
        category=UserWarning,
        message="Not plotting highly correlated",
    )
    warnings.filterwarnings(
        "ignore",
        module="dabl",
        category=UserWarning,
        message="Missing values in target_col have been removed for regression",
    )
    # Imported lazily: the package is only needed here to name its warning class.
    from sklearn.exceptions import UndefinedMetricWarning

    warnings.filterwarnings(
        "ignore",
        module="sklearn",
        category=UndefinedMetricWarning,
        message="Recall is ill-defined",
    )
    warnings.filterwarnings(
        "ignore",
        category=DeprecationWarning,
        message="is_categorical_dtype is deprecated and will be removed in a future version.",
    )
    warnings.filterwarnings(
        "ignore",
        category=DeprecationWarning,
        module="sklearn",
        message="is_sparse is deprecated and will be removed in a future version.",
    )
    # Use the public top-level re-export rather than the private
    # ``matplotlib._api.deprecation`` module, whose layout is an implementation
    # detail that may change between Matplotlib releases.
    from matplotlib import MatplotlibDeprecationWarning

    warnings.filterwarnings(
        "ignore",
        category=MatplotlibDeprecationWarning,
        module="dabl",
        message="The legendHandles attribute was deprecated in Matplotlib 3.7 and will be removed",
    )