Coverage for mlos_bench/mlos_bench/environments/composite_env.py: 89%

96 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Composite benchmark environment.""" 

6 

7import logging 

8from datetime import datetime 

9from types import TracebackType 

10from typing import Any, Dict, List, Optional, Tuple, Type 

11 

12from typing_extensions import Literal 

13 

14from mlos_bench.environments.base_environment import Environment 

15from mlos_bench.environments.status import Status 

16from mlos_bench.services.base_service import Service 

17from mlos_bench.tunables.tunable import TunableValue 

18from mlos_bench.tunables.tunable_groups import TunableGroups 

19 

20_LOG = logging.getLogger(__name__) 

21 

22 

class CompositeEnv(Environment):
    """
    Composite benchmark environment.

    Aggregates a list of child environments and drives them as a single one:
    `setup()`, `run()` and `status()` visit the children in declaration order,
    while `teardown()` and `__exit__()` visit them in reverse order.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the environment
            configuration. Child environments are loaded from the
            "include_children" section (list of config files) and built from
            the "children" section (list of inline configs); together they
            must produce at least one child environment.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for *all* environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.).

        Raises
        ------
        ValueError
            If no child environment was produced by the config.
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

        # By default, the Environment includes only the tunables explicitly specified
        # in the "tunable_params" section of the config. `CompositeEnv`, however, must
        # retain all tunables from its children environments plus the ones that come
        # from the "include_tunables".
        tunables = tunables.copy() if tunables else TunableGroups()

        _LOG.debug("Build composite environment '%s' START: %s", self, tunables)
        self._children: List[Environment] = []
        self._child_contexts: List[Environment] = []

        # To support trees of composite environments (e.g. for multiple VM experiments),
        # each CompositeEnv gets a copy of the original global config and adjusts it with
        # the `const_args` specific to it.
        global_config = (global_config or {}).copy()
        for key, val in self._const_args.items():
            global_config.setdefault(key, val)

        for child_config_file in config.get("include_children", []):
            for env in self._config_loader_service.load_environment_list(
                child_config_file,
                tunables,
                global_config,
                self._const_args,
                self._service,
            ):
                self._add_child(env, tunables)

        for child_config in config.get("children", []):
            env = self._config_loader_service.build_environment(
                child_config,
                tunables,
                global_config,
                self._const_args,
                self._service,
            )
            self._add_child(env, tunables)

        _LOG.debug("Build composite environment '%s' END: %s", self, self._tunable_params)

        if not self._children:
            raise ValueError("At least one child environment must be present")

    def __enter__(self) -> Environment:
        """
        Enter the context of every child environment (in order), then of this one.

        If entering any child, or the base environment, raises, the children
        that were already entered are exited in reverse order before the
        exception propagates. This keeps child contexts from leaking.
        """
        entered: List[Environment] = []
        contexts: List[Environment] = []
        try:
            for env in self._children:
                contexts.append(env.__enter__())
                entered.append(env)
            self._child_contexts = contexts
            return super().__enter__()
        except BaseException as ex:
            self._child_contexts = []
            for env in reversed(entered):
                try:
                    env.__exit__(type(ex), ex, ex.__traceback__)
                except Exception as exit_ex:  # pylint: disable=broad-exception-caught
                    _LOG.error(
                        "Exception while exiting child environment '%s': %s", env, exit_ex
                    )
            # Re-raise the original failure, not any cleanup error.
            raise

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exit the context of every child environment (in reverse order), then of this one.

        All children are exited even if some of them raise. Each such error is
        logged. If several children fail, the exception from the child exited
        last (i.e., the earliest child) is re-raised after the base environment
        has exited. Never suppresses the exception passed in.
        """
        ex_throw: Optional[Exception] = None
        for env in reversed(self._children):
            try:
                env.__exit__(ex_type, ex_val, ex_tb)
            except Exception as ex:  # pylint: disable=broad-exception-caught
                _LOG.error("Exception while exiting child environment '%s': %s", env, ex)
                ex_throw = ex
        self._child_contexts = []
        super().__exit__(ex_type, ex_val, ex_tb)
        if ex_throw is not None:
            raise ex_throw
        return False

    @property
    def children(self) -> List[Environment]:
        """Return the list of child environments."""
        return self._children

    def pprint(self, indent: int = 4, level: int = 0) -> str:
        """
        Pretty-print the environment and its children.

        Parameters
        ----------
        indent : int
            Number of spaces to indent the output at each level. Default is 4.
        level : int
            Current level of indentation. Default is 0.

        Returns
        -------
        pretty : str
            Pretty-printed environment configuration.
        """
        return (
            super().pprint(indent, level)
            + "\n"
            + "\n".join(child.pprint(indent, level + 1) for child in self._children)
        )

    def _add_child(self, env: Environment, tunables: TunableGroups) -> None:
        """
        Add a new child environment to the composite environment.

        Merges the child's tunables into both this environment's tunable
        parameters and the shared `tunables` collection. This method is called
        from the constructor only.
        """
        _LOG.debug("Merge tunables: '%s' <- '%s' :: %s", self, env, env.tunable_params)
        self._children.append(env)
        self._tunable_params.merge(env.tunable_params)
        tunables.merge(env.tunable_params)

    def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
        """
        Set up the children environments.

        The base environment is set up first, then the children in order.
        The process stops at the first failure, so later children are not
        set up.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable parameters along with their values.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if all children setup() operations are successful,
            false otherwise.
        """
        assert self._in_context
        self._is_ready = super().setup(tunables, global_config) and all(
            env_context.setup(tunables, global_config) for env_context in self._child_contexts
        )
        return self._is_ready

    def teardown(self) -> None:
        """
        Tear down the children environments.

        This method is idempotent, i.e., calling it several times is equivalent to a
        single call. The environments are being torn down in the reverse order.
        Every child is torn down even if another one raises, and the base
        environment is always torn down last. Any child error is logged. The
        exception from the child torn down last (i.e., the earliest child) is
        re-raised afterwards.
        """
        assert self._in_context
        ex_throw: Optional[Exception] = None
        for env_context in reversed(self._child_contexts):
            try:
                env_context.teardown()
            except Exception as ex:  # pylint: disable=broad-exception-caught
                _LOG.error(
                    "Exception while tearing down child environment '%s': %s", env_context, ex
                )
                ex_throw = ex
        super().teardown()
        if ex_throw is not None:
            raise ex_throw

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Submit a new experiment to the environment.

        Runs the children in order. If all of them succeed, the result merges
        the metrics of every child. On a key collision, a later child
        overwrites an earlier one. The status and timestamp come from the
        *last* child. If a child's status is not good, execution stops there,
        and that (first) failing child's status and timestamp are returned
        with no output.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
        """
        _LOG.info("Run: %s", self._children)
        (status, timestamp, metrics) = super().run()
        if not status.is_ready():
            return (status, timestamp, metrics)

        joint_metrics: Dict[str, TunableValue] = {}
        for env_context in self._child_contexts:
            _LOG.debug("Child env. run: %s", env_context)
            (status, timestamp, metrics) = env_context.run()
            _LOG.debug("Child env. run results: %s :: %s %s", env_context, status, metrics)
            if not status.is_good():
                _LOG.info("Run failed: %s :: %s", self, status)
                return (status, timestamp, None)
            joint_metrics.update(metrics or {})

        _LOG.info("Run completed: %s :: %s %s", self, status, joint_metrics)
        # Return the status and the timestamp of the last child environment.
        return (status, timestamp, joint_metrics)

    def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
        """
        Check the status of the benchmark environment.

        Queries every child and concatenates their telemetry in child order.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `timestamp` is UTC time stamp of the status; it's current time by default.
            `telemetry` is a list (maybe empty) of (timestamp, metric, value) triplets.
        """
        (status, timestamp, telemetry) = super().status()
        if not status.is_ready():
            return (status, timestamp, telemetry)

        joint_telemetry: List[Tuple[datetime, str, Any]] = []
        final_status: Optional[Status] = None
        for env_context in self._child_contexts:
            (status, timestamp, telemetry) = env_context.status()
            _LOG.debug("Child env. status: %s :: %s", env_context, status)
            joint_telemetry.extend(telemetry)
            if not status.is_good() and final_status is None:
                final_status = status

        if final_status is None:
            final_status = status
        _LOG.info("Final status: %s :: %s", self, final_status)
        # The status is that of the first failed child environment, or of the last
        # child if none failed. The timestamp is always that of the last child.
        return (final_status, timestamp, joint_telemetry)