Coverage for mlos_bench/mlos_bench/environments/script_env.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Base scriptable benchmark environment. 

7 

8TODO: Document how variable propagation works in the script environments using

9shell_env_params, required_args, const_args, etc. 

10""" 

11 

12import abc 

13import logging 

14import re 

15from typing import Dict, Iterable, Optional 

16 

17from mlos_bench.environments.base_environment import Environment 

18from mlos_bench.services.base_service import Service 

19from mlos_bench.tunables.tunable import TunableValue 

20from mlos_bench.tunables.tunable_groups import TunableGroups 

21from mlos_bench.util import try_parse_val 

22 

# Module-level logger, named after this module.
_LOG: logging.Logger = logging.getLogger(name=__name__)

24 

25 

class ScriptEnv(Environment, metaclass=abc.ABCMeta):
    """Base Environment that runs scripts for the different phases (e.g.,
    :py:meth:`.Environment.setup`, :py:meth:`.Environment.run`,
    :py:meth:`.Environment.teardown`, etc.)
    """

    # Matches every character outside [a-zA-Z0-9_]; used to derive default
    # shell environment variable names from parameter names.
    _RE_INVALID = re.compile(r"[^a-zA-Z0-9_]")

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new environment for script execution.

        Parameters
        ----------
        name : str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the `tunable_params`
            and the `const_args` sections. It must also have at least one of
            the following parameters: {`setup`, `run`, `teardown`}.
            Additional parameters:

            - `shell_env_params` - an array of parameters to pass to the script
              as shell environment variables;
            - `shell_env_params_rename` - a dictionary of {to: from} mappings
              of the script parameters. If not specified, replace all
              non-alphanumeric characters with underscores;
            - `results_stdout_pattern` - an optional regular expression (compiled
              with ``re.MULTILINE``) with exactly two capture groups, (name, value),
              used to extract results from the script's stdout.

            If neither `shell_env_params` nor `shell_env_params_rename` is specified,
            *no* additional shell parameters will be passed to the script.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service : Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.).
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

        # Scripts for each phase; None when the phase is not configured.
        self._script_setup = self.config.get("setup")
        self._script_run = self.config.get("run")
        self._script_teardown = self.config.get("teardown")

        self._shell_env_params: Iterable[str] = self.config.get("shell_env_params", [])
        self._shell_env_params_rename: Dict[str, str] = self.config.get(
            "shell_env_params_rename", {}
        )

        results_stdout_pattern = self.config.get("results_stdout_pattern")
        # MULTILINE so that `^` / `$` anchor at each line of the script output.
        self._results_stdout_pattern: Optional[re.Pattern[str]] = (
            re.compile(results_stdout_pattern, flags=re.MULTILINE)
            if results_stdout_pattern
            else None
        )

    def _get_env_params(self, restrict: bool = True) -> Dict[str, str]:
        """
        Get the *shell* environment parameters to be passed to the script.

        Parameters
        ----------
        restrict : bool
            If True, only return the parameters that are in the `_shell_env_params`
            list (plus any explicit `_shell_env_params_rename` entries). If False,
            return all parameters in `_params` with some possible conversions.

        Returns
        -------
        env_params : Dict[str, str]
            Parameters to pass as *shell* environment variables into the script,
            mapping the shell variable name to the stringified parameter value.
            By default each variable name is the parameter name with every
            character outside ``[a-zA-Z0-9_]`` replaced by an underscore.
            Explicit ``{to: from}`` entries from `_shell_env_params_rename` are
            applied on top and win on name collisions.

        Raises
        ------
        KeyError
            If a selected or renamed parameter is not present in `_params`
            (assuming `_params` is a dict).
        """
        input_params = self._shell_env_params if restrict else self._params.keys()
        # Default {shell_name: param_name} mapping built from sanitized names.
        rename = {self._RE_INVALID.sub("_", key): key for key in input_params}
        # Explicit renames are applied last so they override the defaults.
        # A renamed source parameter that is also in `input_params` is still
        # exported under its sanitized name as well.
        rename.update(self._shell_env_params_rename)
        return {key_sub: str(self._params[key]) for (key_sub, key) in rename.items()}

    def _extract_stdout_results(self, stdout: str) -> Dict[str, TunableValue]:
        """
        Extract the results from the stdout of the script.

        Parameters
        ----------
        stdout : str
            The stdout of the script.

        Returns
        -------
        results : Dict[str, TunableValue]
            A dictionary of results extracted from the stdout. Empty if no
            `results_stdout_pattern` is configured. If the same key matches
            more than once, the last match wins.

        Raises
        ------
        ValueError
            If `results_stdout_pattern` does not have exactly two capture groups.
        """
        pattern = self._results_stdout_pattern
        if not pattern:
            return {}
        _LOG.debug("Extract regex: '%s' from: '%s'", pattern, stdout)
        # `findall` yields (key, value) tuples only when the regex has exactly
        # two groups. Otherwise it returns plain strings or longer tuples, and
        # unpacking would fail obscurely or, for 2-character matches, silently
        # produce garbage key/value pairs.
        if pattern.groups != 2:
            raise ValueError(
                "results_stdout_pattern must have exactly 2 capture groups "
                f"(key, value), but {pattern.pattern!r} has {pattern.groups}"
            )
        return {key: try_parse_val(val) for (key, val) in pattern.findall(stdout)}