Coverage for mlos_bench/mlos_bench/environments/composite_env.py: 88%

95 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Composite benchmark environment.""" 

6 

7import logging 

8from datetime import datetime 

9from types import TracebackType 

10from typing import Any, Dict, List, Literal, Optional, Tuple, Type 

11 

12from mlos_bench.environments.base_environment import Environment 

13from mlos_bench.environments.status import Status 

14from mlos_bench.services.base_service import Service 

15from mlos_bench.tunables.tunable import TunableValue 

16from mlos_bench.tunables.tunable_groups import TunableGroups 

17 

# Logger for this module; its name is the module's import path.
_LOG = logging.getLogger(name=__name__)

19 

20 

class CompositeEnv(Environment):
    """Composite benchmark environment that drives a sequence of child environments."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the environment configuration.
            Must define at least one child environment, either inline in the
            "children" section or via the config files listed in "include_children".
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for *all* environments.
            The children's tunables are merged into a copy of it, not into the
            object passed in.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.).

        Raises
        ------
        ValueError
            If neither "include_children" nor "children" yields any child environment.
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

        # By default, the Environment includes only the tunables explicitly specified
        # in the "tunable_params" section of the config. `CompositeEnv`, however, must
        # retain all tunables from its children environments plus the ones that come
        # from the "include_tunables".
        tunables = tunables.copy() if tunables else TunableGroups()

        _LOG.debug("Build composite environment '%s' START: %s", self, tunables)
        self._children: List[Environment] = []
        # Objects returned by each child's __enter__(); populated only while in context.
        self._child_contexts: List[Environment] = []

        # To support trees of composite environments (e.g. for multiple VM experiments),
        # each CompositeEnv gets a copy of the original global config and adjusts it with
        # the `const_args` specific to it. Values already in the global config win.
        global_config = (global_config or {}).copy()
        for key, val in self._const_args.items():
            global_config.setdefault(key, val)

        # Children loaded from external config files come first...
        for child_config_file in config.get("include_children", []):
            for env in self._config_loader_service.load_environment_list(
                child_config_file,
                tunables,
                global_config,
                self._const_args,
                self._service,
            ):
                self._add_child(env, tunables)

        # ...followed by the children defined inline in this config.
        for child_config in config.get("children", []):
            env = self._config_loader_service.build_environment(
                child_config,
                tunables,
                global_config,
                self._const_args,
                self._service,
            )
            self._add_child(env, tunables)

        _LOG.debug("Build composite environment '%s' END: %s", self, self._tunable_params)

        if not self._children:
            raise ValueError("At least one child environment must be present")

    def __enter__(self) -> Environment:
        """
        Enter the context of every child environment (in order), then this one.

        Python does not call ``__exit__`` on an object whose ``__enter__`` raised.
        So if entering any child (or this environment itself) fails, the children
        that were already entered are exited here in reverse order before the
        original exception propagates.
        """
        entered: List[Environment] = []  # children whose __enter__ succeeded
        contexts: List[Environment] = []  # what those __enter__ calls returned
        try:
            for env in self._children:
                contexts.append(env.__enter__())
                entered.append(env)
            self._child_contexts = contexts
            return super().__enter__()
        except BaseException as ex:
            self._child_contexts = []
            for env in reversed(entered):
                try:
                    env.__exit__(type(ex), ex, ex.__traceback__)
                # pylint: disable=broad-exception-caught
                except Exception as cleanup_ex:
                    _LOG.error(
                        "Exception while exiting child environment '%s': %s",
                        env,
                        cleanup_ex,
                    )
            # Re-raise the original failure, not any cleanup error.
            raise

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exit the context of every child environment (in reverse order), then this one.

        An ``Exception`` raised by a child's ``__exit__`` is logged and does not stop
        the remaining children (or this environment) from being exited. After all of
        them are exited, the exception from the earliest child in ``children`` order
        is re-raised. The exception from the ``with`` body (if any) is never suppressed.
        """
        ex_throw: Optional[Exception] = None
        for env in reversed(self._children):
            try:
                env.__exit__(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                # Log with traceback; keep going so every child gets a chance to clean up.
                _LOG.exception("Exception while exiting child environment '%s': %s", env, ex)
                ex_throw = ex
        self._child_contexts = []
        super().__exit__(ex_type, ex_val, ex_tb)
        if ex_throw is not None:
            raise ex_throw
        return False

    @property
    def children(self) -> List[Environment]:
        """Return the list of child environments."""
        return self._children

    def pprint(self, indent: int = 4, level: int = 0) -> str:
        """
        Pretty-print the environment and its children.

        Parameters
        ----------
        indent : int
            Number of spaces to indent the output at each level. Default is 4.
        level : int
            Current level of indentation. Default is 0.

        Returns
        -------
        pretty : str
            Pretty-printed environment configuration; children are printed
            one level deeper, one per line, after this environment.
        """
        return (
            super().pprint(indent, level)
            + "\n"
            + "\n".join(child.pprint(indent, level + 1) for child in self._children)
        )

    def _add_child(self, env: Environment, tunables: TunableGroups) -> None:
        """
        Add a new child environment to the composite environment.

        The child's tunables are merged both into this environment's tunable
        params and into `tunables`, so that children built later see them.
        This method is called from the constructor only.
        """
        _LOG.debug("Merge tunables: '%s' <- '%s' :: %s", self, env, env.tunable_params)
        self._children.append(env)
        self._tunable_params.merge(env.tunable_params)
        tunables.merge(env.tunable_params)

    def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
        """
        Set up this environment, then its children in order.

        Setup stops at the first failure: the remaining children are not set up.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable parameters along with their values.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if this environment and all children setup() operations are
            successful, False otherwise. The result is also stored in `_is_ready`.
        """
        assert self._in_context
        # all() short-circuits on the first child whose setup() returns False.
        self._is_ready = super().setup(tunables, global_config) and all(
            env_context.setup(tunables, global_config) for env_context in self._child_contexts
        )
        return self._is_ready

    def teardown(self) -> None:
        """
        Tear down the children environments in reverse order, then this one.

        This is intended to be idempotent: calling it several times is equivalent
        to a single call, provided the children's teardown() is idempotent too.
        """
        assert self._in_context
        for env_context in reversed(self._child_contexts):
            env_context.teardown()
        super().teardown()

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Submit a new experiment to the environment by running each child in order.

        If every child succeeds, return the status and timestamp of the *last* child
        together with the metrics of all children merged into one dict. A later
        child's metric overwrites an earlier one with the same name. Otherwise, stop
        at the *first* failed child and return its status and timestamp with no
        output; the remaining children are not run.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
        """
        _LOG.info("Run: %s", self._children)
        (status, timestamp, metrics) = super().run()
        if not status.is_ready():
            return (status, timestamp, metrics)

        joint_metrics: Dict[str, TunableValue] = {}
        for env_context in self._child_contexts:
            _LOG.debug("Child env. run: %s", env_context)
            (status, timestamp, metrics) = env_context.run()
            _LOG.debug("Child env. run results: %s :: %s %s", env_context, status, metrics)
            if not status.is_good():
                _LOG.info("Run failed: %s :: %s", self, status)
                return (status, timestamp, None)
            joint_metrics.update(metrics or {})

        _LOG.info("Run completed: %s :: %s %s", self, status, joint_metrics)
        # Return the status and the timestamp of the last child environment.
        return (status, timestamp, joint_metrics)

    def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
        """
        Check the status of the benchmark environment and all of its children.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime.datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `benchmark_status` is the status of the first child that is not good,
            or of the last child if all are good.
            `timestamp` is the UTC time stamp reported by the last child.
            `telemetry` is the concatenation of all children's telemetry: a list
            (maybe empty) of (timestamp, metric, value) triplets.
        """
        (status, timestamp, telemetry) = super().status()
        if not status.is_ready():
            return (status, timestamp, telemetry)

        joint_telemetry: List[Tuple[datetime, str, Any]] = []
        final_status: Optional[Status] = None
        for env_context in self._child_contexts:
            (status, timestamp, telemetry) = env_context.status()
            _LOG.debug("Child env. status: %s :: %s", env_context, status)
            joint_telemetry.extend(telemetry)
            if final_status is None and not status.is_good():
                final_status = status

        if final_status is None:
            final_status = status
        _LOG.info("Final status: %s :: %s", self, final_status)
        # Status: the first failed child's, or the last child's if none failed.
        # Timestamp: always the one reported by the last child.
        return (final_status, timestamp, joint_telemetry)