Coverage for mlos_bench/mlos_bench/environments/script_env.py: 100%
31 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Base scriptable benchmark environment."""
7import abc
8import logging
9import re
10from typing import Dict, Iterable, Optional
12from mlos_bench.environments.base_environment import Environment
13from mlos_bench.services.base_service import Service
14from mlos_bench.tunables.tunable import TunableValue
15from mlos_bench.tunables.tunable_groups import TunableGroups
16from mlos_bench.util import try_parse_val
# Module-level logger for this file, named after the module (`__name__`).
_LOG = logging.getLogger(__name__)
class ScriptEnv(Environment, metaclass=abc.ABCMeta):
    """
    Base Environment that runs scripts for setup/run/teardown.

    This base class does not execute the scripts itself. It stores the
    `setup`/`run`/`teardown` config entries and provides helpers to build the
    shell environment passed to the scripts and to parse results from their
    stdout.
    """

    # Any character that is not an ASCII letter, digit, or underscore.
    # Such characters are replaced with "_" to derive shell-safe variable names.
    _RE_INVALID = re.compile(r"[^a-zA-Z0-9_]")

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new environment for script execution.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the `tunable_params`
            and the `const_args` sections. It must also have at least one of
            the following parameters: {`setup`, `run`, `teardown`}.
            Additional parameters:
              * `shell_env_params` - an array of parameters to pass to the script
                as shell environment variables. Each one is exported under a
                sanitized name: every character other than an ASCII letter,
                digit, or underscore is replaced with an underscore.
              * `shell_env_params_rename` - a dictionary of {to: from} mappings
                of the script parameters. Each `from` parameter is additionally
                exported under the name `to`. An explicit rename overrides a
                sanitized name equal to `to`.
              * `results_stdout_pattern` - an optional regular expression
                (compiled with `re.MULTILINE`) with exactly two capture groups,
                (name, value). It is used to extract results from the script's
                stdout.
            If neither `shell_env_params` nor `shell_env_params_rename` are specified,
            *no* additional shell parameters will be passed to the script.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.).

        Raises
        ------
        re.error
            If `results_stdout_pattern` is not a valid regular expression.
        ValueError
            If `results_stdout_pattern` does not have exactly two capture groups.
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

        # NOTE(review): `self.config` is not assigned in this class; presumably
        # it is populated by `Environment.__init__` from `config`/`global_config`.
        self._script_setup = self.config.get("setup")
        self._script_run = self.config.get("run")
        self._script_teardown = self.config.get("teardown")

        self._shell_env_params: Iterable[str] = self.config.get("shell_env_params", [])
        self._shell_env_params_rename: Dict[str, str] = self.config.get(
            "shell_env_params_rename", {}
        )

        self._results_stdout_pattern: Optional[re.Pattern[str]] = None
        results_stdout_pattern = self.config.get("results_stdout_pattern")
        if results_stdout_pattern:
            # MULTILINE so that `^` / `$` anchor at each line of the script output.
            pattern = re.compile(results_stdout_pattern, flags=re.MULTILINE)
            # `findall()` yields (name, value) pairs only when the pattern has
            # exactly two groups. Otherwise it returns plain strings (0-1 groups)
            # or longer tuples. Those would be silently mis-parsed or fail with an
            # opaque unpacking error in `_extract_stdout_results()`, so fail fast.
            if pattern.groups != 2:
                raise ValueError(
                    "results_stdout_pattern must have exactly 2 capture groups "
                    f"(name, value), got {pattern.groups}: {results_stdout_pattern!r}"
                )
            self._results_stdout_pattern = pattern

    def _get_env_params(self, restrict: bool = True) -> Dict[str, str]:
        """
        Get the *shell* environment parameters to be passed to the script.

        Parameters
        ----------
        restrict : bool
            If True, only export the parameters listed in `_shell_env_params`.
            If False, export every parameter in `_params`. In both cases the
            explicit `_shell_env_params_rename` mappings are applied as well.

        Returns
        -------
        env_params : Dict[str, str]
            Parameters to pass as *shell* environment variables into the script,
            mapping the shell variable name to the stringified parameter value.
            This is usually a subset of `_params` with some possible conversions.

        Raises
        ------
        KeyError
            If a requested parameter (from `_shell_env_params` or a `from` value
            of `_shell_env_params_rename`) is missing from `_params`.
        """
        input_params = self._shell_env_params if restrict else self._params.keys()
        # Default names: each parameter under its sanitized name.
        # NOTE: two parameters that sanitize to the same name collide;
        # the one iterated last wins.
        rename = {self._RE_INVALID.sub("_", key): key for key in input_params}
        # Explicit {to: from} renames are added on top and override any
        # sanitized name equal to `to`.
        rename.update(self._shell_env_params_rename)
        # Shell environment values must be strings.
        return {key_sub: str(self._params[key]) for (key_sub, key) in rename.items()}

    def _extract_stdout_results(self, stdout: str) -> Dict[str, TunableValue]:
        """
        Extract the results from the stdout of the script.

        Parameters
        ----------
        stdout : str
            The stdout of the script.

        Returns
        -------
        results : Dict[str, TunableValue]
            A dictionary of results extracted from the stdout. Each value is the
            second capture group, parsed with `try_parse_val`. If a name matches
            more than once, the last occurrence wins. Empty if no
            `results_stdout_pattern` was configured.
        """
        if self._results_stdout_pattern is None:
            return {}
        _LOG.debug("Extract regex: '%s' from: '%s'", self._results_stdout_pattern, stdout)
        # The pattern has exactly two groups (validated in __init__), so
        # `findall()` returns a list of (name, value) string tuples.
        return {
            key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)
        }