Coverage for mlos_bench/mlos_bench/environments/base_environment.py: 93%

148 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""A hierarchy of benchmark environments.""" 

6 

7import abc 

8import json 

9import logging 

10from collections.abc import Iterable, Sequence 

11from contextlib import AbstractContextManager as ContextManager 

12from datetime import datetime 

13from types import TracebackType 

14from typing import TYPE_CHECKING, Any, Literal 

15 

16from pytz import UTC 

17 

18from mlos_bench.config.schemas import ConfigSchema 

19from mlos_bench.dict_templater import DictTemplater 

20from mlos_bench.environments.status import Status 

21from mlos_bench.services.base_service import Service 

22from mlos_bench.tunables.tunable import TunableValue 

23from mlos_bench.tunables.tunable_groups import TunableGroups 

24from mlos_bench.util import instantiate_from_config, merge_parameters 

25 

26if TYPE_CHECKING: 

27 from mlos_bench.services.types.config_loader_type import SupportsConfigLoading 

28 

29_LOG = logging.getLogger(__name__) 

30 

31 

class Environment(ContextManager, metaclass=abc.ABCMeta):
    # pylint: disable=too-many-instance-attributes
    """
    An abstract base of all benchmark environments.

    The environment is a context manager: `setup`, `run`, `status`, and `teardown`
    all assert that they are invoked inside the environment's `with` block.
    """

    # Should be provided by the runtime.
    # Both sets are added to the required keys when the constructor merges the
    # global config into `const_args`.
    _COMMON_CONST_ARGS = {
        "trial_runner_id",
    }
    _COMMON_REQ_ARGS = {
        "experiment_id",
        "trial_id",
    }

    @classmethod
    def new(  # pylint: disable=too-many-arguments
        cls,
        *,
        env_name: str,
        class_name: str,
        config: dict,
        global_config: dict | None = None,
        tunables: TunableGroups | None = None,
        service: Service | None = None,
    ) -> "Environment":
        """
        Factory method for a new environment with a given config.

        Parameters
        ----------
        env_name: str
            Human-readable name of the environment.
        class_name: str
            FQN of a Python class to instantiate, e.g.,
            "mlos_bench.environments.remote.HostEnv".
            Must be derived from the `Environment` class.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. It will be passed as a constructor parameter of
            the class specified by `class_name`.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for all environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM/Host, etc.).

        Returns
        -------
        env : Environment
            An instance of the `Environment` class initialized with `config`.
        """
        assert issubclass(cls, Environment)
        return instantiate_from_config(
            cls,
            class_name,
            name=env_name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: dict | None = None,
        tunables: TunableGroups | None = None,
        service: Service | None = None,
    ):
        """
        Create a new environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the "tunable_params"
            and the "const_args" sections.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for all environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM/Host, etc.).
        """
        global_config = global_config or {}
        self._validate_json_config(config, name)
        self.name = name
        self.config = config
        self._service = service
        self._service_context: Service | None = None
        self._is_ready = False
        self._in_context = False
        # NOTE(review): when "const_args" is present this aliases the dict stored
        # in `config`; if `merge_parameters` updates `dest` in place, the caller's
        # config is modified as well - confirm this is intended.
        self._const_args: dict[str, TunableValue] = config.get("const_args", {})

        # Make some usual runtime arguments available for tests.
        # Note: this writes into the caller-supplied `global_config` dict
        # (when it is non-empty), not into a private copy.
        for arg in self._COMMON_CONST_ARGS | self._COMMON_REQ_ARGS:
            global_config.setdefault(arg, self._const_args.get(arg, None))

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Environment: '%s' Service: %s",
                name,
                self._service.pprint() if self._service else None,
            )

        if tunables is None:
            _LOG.warning(
                (
                    "No tunables provided for %s. "
                    "Tunable inheritance across composite environments may be broken."
                ),
                name,
            )
            tunables = TunableGroups()

        # TODO: add user docstrings for these in the module
        # `global_config` is already normalized to a dict above.
        groups = self._expand_groups(
            config.get("tunable_params", []),
            global_config.get("tunable_params_map", {}),
        )
        _LOG.debug("Tunable groups for: '%s' :: %s", name, groups)

        self._tunable_params = tunables.subgroup(groups)

        # If a parameter comes from the tunables, do not require it in the const_args or globals
        req_args = (
            set(config.get("required_args", [])) - self._tunable_params.get_param_values().keys()
        )
        req_args.update(self._COMMON_REQ_ARGS | self._COMMON_CONST_ARGS)
        merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args)
        self._const_args = self._expand_vars(self._const_args, global_config)

        self._params = self._combine_tunables(self._tunable_params)
        _LOG.debug("Parameters for '%s' :: %s", name, self._params)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2))

    def _validate_json_config(self, config: dict, name: str) -> None:
        """
        Reconstructs a basic json config that this class might have been instantiated
        from in order to validate configs provided outside the file loading
        mechanism.

        Parameters
        ----------
        config : dict
            The environment configuration; included under the "config" key
            only when non-empty.
        name : str
            The environment name; included under the "name" key only when non-empty.
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if name:
            json_config["name"] = name
        if config:
            json_config["config"] = config
        ConfigSchema.ENVIRONMENT.validate(json_config)

    @staticmethod
    def _expand_groups(
        groups: Iterable[str],
        groups_exp: dict[str, str | Sequence[str]],
    ) -> list[str]:
        """
        Expand `$tunable_group` into actual names of the tunable groups.

        Parameters
        ----------
        groups : list[str]
            Names of the groups of tunables, maybe with `$` prefix (subject to expansion).
        groups_exp : dict
            A dictionary that maps dollar variables for tunable groups to the lists
            of actual tunable groups IDs.

        Returns
        -------
        groups : list[str]
            A flat list of tunable groups IDs for the environment.

        Raises
        ------
        KeyError
            If a `$`-prefixed group name has no entry in `groups_exp`.
        """
        res: list[str] = []
        for grp in groups:
            if grp[:1] == "$":
                tunable_group_name = grp[1:]
                if tunable_group_name not in groups_exp:
                    # Both fragments must be f-strings so that the available
                    # mapping is actually rendered in the error message.
                    raise KeyError(
                        f"Expected tunable group name ${tunable_group_name} "
                        f"undefined in {groups_exp}"
                    )
                add_groups = groups_exp[tunable_group_name]
                # A single string maps to one group; any other value is
                # treated as a sequence of group names.
                res += [add_groups] if isinstance(add_groups, str) else add_groups
            else:
                res.append(grp)
        return res

    @staticmethod
    def _expand_vars(
        params: dict[str, TunableValue],
        global_config: dict[str, TunableValue],
    ) -> dict:
        """Expand `$var` into actual values of the variables."""
        return DictTemplater(params).expand_vars(extra_source_dict=global_config)

    @property
    def _config_loader_service(self) -> "SupportsConfigLoading":
        """Get the config loader of the parent service; a service must be present."""
        assert self._service is not None
        return self._service.config_loader_service

    def __enter__(self) -> "Environment":
        """Enter the environment's benchmarking context."""
        _LOG.debug("Environment START :: %s", self)
        assert not self._in_context
        if self._service:
            self._service_context = self._service.__enter__()
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """
        Exit the context of the benchmarking environment.

        Exits the service context (if one was entered) and resets the context
        state. An exception raised while exiting the service context is logged
        and re-raised only after that state has been reset.

        Returns
        -------
        suppress : Literal[False]
            Always False, i.e., exceptions from the `with` body are not suppressed.
        """
        ex_throw = None
        if ex_val is None:
            _LOG.debug("Environment END :: %s", self)
        else:
            assert ex_type and ex_val
            _LOG.warning("Environment END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
        assert self._in_context
        if self._service_context:
            try:
                self._service_context.__exit__(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting Service context '%s': %s", self._service, ex)
                ex_throw = ex
            finally:
                self._service_context = None
        self._in_context = False
        # Defer the re-raise until the context flags above have been reset.
        if ex_throw:
            raise ex_throw
        return False  # Do not suppress exceptions

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"{self.__class__.__name__} :: '{self.name}'"

    def pprint(self, indent: int = 4, level: int = 0) -> str:
        """
        Pretty-print the environment configuration. For composite environments, print
        all children environments as well.

        Parameters
        ----------
        indent : int
            Number of spaces to indent the output. Default is 4.
        level : int
            Current level of indentation. Default is 0.

        Returns
        -------
        pretty : str
            Pretty-printed environment configuration.
            Default output is the same as `__repr__`.
        """
        return f'{" " * indent * level}{repr(self)}'

    def _combine_tunables(self, tunables: TunableGroups) -> dict[str, TunableValue]:
        """
        Plug tunable values into the base config. If the tunable group is unknown,
        ignore it (it might belong to another environment). This method should never
        mutate the original config or the tunables.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of groups of tunable parameters
            along with the parameters' values.

        Returns
        -------
        params : dict[str, TunableValue]
            Free-format dictionary that contains the new environment configuration.
        """
        return tunables.get_param_values(
            group_names=list(self._tunable_params.get_covariant_group_names()),
            # Pass a copy so that `self._const_args` itself is not the target dict.
            into_params=self._const_args.copy(),
        )

    @property
    def tunable_params(self) -> TunableGroups:
        """
        Get the configuration space of the given environment.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant groups of tunable parameters.
        """
        return self._tunable_params

    @property
    def const_args(self) -> dict[str, TunableValue]:
        """
        Get the constant arguments for this Environment.

        Returns
        -------
        parameters : dict[str, TunableValue]
            Key/value pairs of all environment const_args parameters
            (a shallow copy of the internal dictionary).
        """
        return self._const_args.copy()

    @property
    def parameters(self) -> dict[str, TunableValue]:
        """
        Key/value pairs of all environment parameters (i.e., `const_args` and
        `tunable_params`). Note that before `.setup()` is called, all tunables will be
        set to None.

        Returns
        -------
        parameters : dict[str, TunableValue]
            Key/value pairs of all environment parameters
            (i.e., `const_args` and `tunable_params`).
        """
        return self._params.copy()

    def setup(self, tunables: TunableGroups, global_config: dict | None = None) -> bool:
        """
        Set up a new benchmark environment, if necessary. This method must be
        idempotent, i.e., calling it several times in a row should be equivalent to a
        single call.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable parameters along with their values.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if operation is successful, false otherwise.
        """
        _LOG.info("Setup %s :: %s", self, tunables)
        assert isinstance(tunables, TunableGroups)

        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context

        # Assign new values to the environment's tunable parameters:
        groups = list(self._tunable_params.get_covariant_group_names())
        self._tunable_params.assign(tunables.get_param_values(groups))

        # Write to the log whether the environment needs to be reset.
        # (Derived classes still have to check `self._tunable_params.is_updated()`).
        is_updated = self._tunable_params.is_updated()
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Env '%s': Tunable groups reset = %s :: %s",
                self,
                is_updated,
                {
                    name: self._tunable_params.is_updated([name])
                    for name in self._tunable_params.get_covariant_group_names()
                },
            )
        else:
            _LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated)

        # Combine tunables, const_args, and global config into `self._params`:
        self._params = self._combine_tunables(tunables)
        merge_parameters(dest=self._params, source=global_config)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Combined parameters:\n%s", json.dumps(self._params, indent=2))

        return True

    def teardown(self) -> None:
        """
        Tear down the benchmark environment.

        This method must be idempotent, i.e., calling it several times in a row should
        be equivalent to a single call.
        """
        _LOG.info("Teardown %s", self)
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        self._is_ready = False

    def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]:
        """
        Execute the run script for this environment.

        For instance, this may start a new experiment, download results, reconfigure
        the environment, etc. Details are configurable via the environment config.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
            This base implementation reports the current `status()` and no output.
        """
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        (status, timestamp, _) = self.status()
        return (status, timestamp, None)

    def status(self) -> tuple[Status, datetime, list[tuple[datetime, str, Any]]]:
        """
        Check the status of the benchmark environment.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime.datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `timestamp` is UTC time stamp of the status; it's current time by default.
            `telemetry` is a list (maybe empty) of (timestamp, metric, value) triplets.
        """
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        timestamp = datetime.now(UTC)
        if self._is_ready:
            return (Status.READY, timestamp, [])
        _LOG.warning("Environment not ready: %s", self)
        return (Status.PENDING, timestamp, [])