Coverage for mlos_bench/mlos_bench/environments/script_env.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Base scriptable benchmark environment.""" 

6 

7import abc 

8import logging 

9import re 

10from typing import Dict, Iterable, Optional 

11 

12from mlos_bench.environments.base_environment import Environment 

13from mlos_bench.services.base_service import Service 

14from mlos_bench.tunables.tunable import TunableValue 

15from mlos_bench.tunables.tunable_groups import TunableGroups 

16from mlos_bench.util import try_parse_val 

17 

18_LOG = logging.getLogger(__name__) 

19 

20 

class ScriptEnv(Environment, metaclass=abc.ABCMeta):
    """
    Base Environment that runs scripts for setup/run/teardown.

    This class stores the `setup`, `run`, and `teardown` script configs. It also
    provides helpers to build the shell environment variables passed to those
    scripts and to parse results out of a script's stdout. It does not execute
    anything itself.
    """

    # Any character other than ASCII letters, digits, and underscore; such
    # characters are replaced with "_" when turning parameter names into
    # shell environment variable names in `_get_env_params`.
    _RE_INVALID = re.compile(r"[^a-zA-Z0-9_]")

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new environment for script execution.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the `tunable_params`
            and the `const_args` sections. It must also have at least one of
            the following parameters: {`setup`, `run`, `teardown`}.
            Additional parameters:
              * `shell_env_params` - an array of parameters to pass to the script
                as shell environment variables,
              * `shell_env_params_rename` - a dictionary of {to: from} mappings
                of the script parameters. If not specified, replace all
                non-alphanumeric characters with underscores, and
              * `results_stdout_pattern` - an optional regular expression
                (compiled with `re.MULTILINE`) with exactly two capture groups,
                (metric name, metric value), used to extract results from the
                script's stdout.
            If neither `shell_env_params` nor `shell_env_params_rename` are specified,
            *no* additional shell parameters will be passed to the script.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.).

        Raises
        ------
        re.error
            If `results_stdout_pattern` is not a valid regular expression.
        ValueError
            If `results_stdout_pattern` does not have exactly two capture groups.
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

        # Script configs; each is None when absent from the config.
        self._script_setup = self.config.get("setup")
        self._script_run = self.config.get("run")
        self._script_teardown = self.config.get("teardown")

        self._shell_env_params: Iterable[str] = self.config.get("shell_env_params", [])
        self._shell_env_params_rename: Dict[str, str] = self.config.get(
            "shell_env_params_rename", {}
        )

        results_stdout_pattern = self.config.get("results_stdout_pattern")
        self._results_stdout_pattern: Optional[re.Pattern[str]] = (
            re.compile(results_stdout_pattern, flags=re.MULTILINE)
            if results_stdout_pattern
            else None
        )
        # `_extract_stdout_results` unpacks each `findall` hit as a (key, value)
        # pair. That only works with exactly two groups: with fewer, `findall`
        # returns plain strings (a 2-char match would be silently split into
        # key/value), and with more it returns longer tuples. Fail fast here
        # rather than misparse or fail later on the first matching output.
        if (
            self._results_stdout_pattern is not None
            and self._results_stdout_pattern.groups != 2
        ):
            raise ValueError(
                "results_stdout_pattern must have exactly two capture groups "
                f"(name, value); got {self._results_stdout_pattern.groups} in "
                f"{results_stdout_pattern!r}"
            )

    def _get_env_params(self, restrict: bool = True) -> Dict[str, str]:
        """
        Get the *shell* environment parameters to be passed to the script.

        Each selected parameter name has every character outside `[a-zA-Z0-9_]`
        replaced with "_" to form the variable name. Entries from
        `shell_env_params_rename` ({to: from}) are then added on top. They
        override a sanitized name that equals `to`, and they are applied even
        when `from` is not in the selected list. A parameter that is both
        selected and a rename source is therefore exported under both names.
        If two names sanitize to the same variable name, the later one wins.
        All values are converted with `str()`.

        Parameters
        ----------
        restrict : bool
            If True, only return the parameters that are in the `_shell_env_params`
            list. If False, return all parameters in `_params` with some possible
            conversions.

        Returns
        -------
        env_params : Dict[str, str]
            Parameters to pass as *shell* environment variables into the script.
            This is usually a subset of `_params` with some possible conversions.

        Raises
        ------
        KeyError
            If a listed or rename-source parameter is missing from `_params`
            (assuming `_params` is a dict).
        """
        input_params = self._shell_env_params if restrict else self._params.keys()
        # Map {shell_var_name: param_name}; explicit renames take precedence.
        rename = {self._RE_INVALID.sub("_", key): key for key in input_params}
        rename.update(self._shell_env_params_rename)
        return {key_sub: str(self._params[key]) for (key_sub, key) in rename.items()}

    def _extract_stdout_results(self, stdout: str) -> Dict[str, TunableValue]:
        """
        Extract the results from the stdout of the script.

        Every match of `results_stdout_pattern` contributes one entry: the first
        group is the result name and the second is the value, converted with
        `try_parse_val`. If a name matches more than once, the last value wins.

        Parameters
        ----------
        stdout : str
            The stdout of the script.

        Returns
        -------
        results : Dict[str, TunableValue]
            A dictionary of results extracted from the stdout. The dictionary is
            empty if no `results_stdout_pattern` was configured or nothing matched.
        """
        if self._results_stdout_pattern is None:
            return {}
        _LOG.debug(
            "Extract regex: '%s' from: '%s'", self._results_stdout_pattern.pattern, stdout
        )
        # `findall` yields (name, value) tuples here because `__init__` rejects
        # patterns that do not have exactly two capture groups.
        return {
            key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)
        }