Coverage for mlos_bench/mlos_bench/util.py: 89%

121 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Various helper functions for mlos_bench.""" 

6 

7# NOTE: This has to be placed in the top-level mlos_bench package to avoid circular imports. 

8 

9import importlib 

10import json 

11import logging 

12import os 

13import subprocess 

14from collections.abc import Callable, Iterable, Mapping 

15from datetime import datetime 

16from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union 

17 

18import pandas 

19import pytz 

20 

21_LOG = logging.getLogger(__name__) 

22 

23if TYPE_CHECKING: 

24 from mlos_bench.environments.base_environment import Environment 

25 from mlos_bench.optimizers.base_optimizer import Optimizer 

26 from mlos_bench.schedulers.base_scheduler import Scheduler 

27 from mlos_bench.services.base_service import Service 

28 from mlos_bench.storage.base_storage import Storage 

29 

# The class names are given as strings (forward references) because the
# classes themselves are only imported under ``TYPE_CHECKING``.
BaseTypeVar = TypeVar("BaseTypeVar", "Environment", "Optimizer", "Scheduler", "Service", "Storage")
"""BaseTypeVar is a generic with a constraint of the main base classes (e.g.,
:py:class:`~mlos_bench.environments.base_environment.Environment`,
:py:class:`~mlos_bench.optimizers.base_optimizer.Optimizer`,
:py:class:`~mlos_bench.schedulers.base_scheduler.Scheduler`,
:py:class:`~mlos_bench.services.base_service.Service`,
:py:class:`~mlos_bench.storage.base_storage.Storage`, etc.).
"""

# Same set of forward-referenced class names as BaseTypeVar, but as a plain Union.
BaseTypes = Union[  # pylint: disable=consider-alternative-union-syntax
    "Environment", "Optimizer", "Scheduler", "Service", "Storage"
]
"""Similar to :py:data:`.BaseTypeVar`, BaseTypes is a Union of the main base classes."""

43 

44 

45# Adjusted from https://github.com/python/cpython/blob/v3.11.10/Lib/distutils/util.py#L308 

46# See Also: https://github.com/microsoft/MLOS/issues/865 

def strtobool(val: str) -> bool:
    """
    Interpret a human-friendly truth string as a boolean.

    Parameters
    ----------
    val : str
        Text to interpret; the comparison is case-insensitive.
        'y', 'yes', 't', 'true', 'on', and '1' map to True;
        'n', 'no', 'f', 'false', 'off', and '0' map to False.

    Returns
    -------
    bool
        The boolean value the string stands for.

    Raises
    ------
    ValueError
        If 'val' is not one of the recognized spellings.
    """
    truthy = ("y", "yes", "t", "true", "on", "1")
    falsy = ("n", "no", "f", "false", "off", "0")

    lowered = val.lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    # The message reports the lower-cased form of the input.
    raise ValueError(f"Invalid Boolean value: '{lowered}'")

69 

70 

71def preprocess_dynamic_configs(*, dest: dict, source: dict | None = None) -> dict: 

72 """ 

73 Replaces all ``$name`` values in the destination config with the corresponding value 

74 from the source config. 

75 

76 Parameters 

77 ---------- 

78 dest : dict 

79 Destination config. 

80 source : dict | None 

81 Source config. 

82 

83 Returns 

84 ------- 

85 dest : dict 

86 A reference to the destination config after the preprocessing. 

87 """ 

88 if source is None: 

89 source = {} 

90 for key, val in dest.items(): 

91 if isinstance(val, str) and val.startswith("$") and val[1:] in source: 

92 dest[key] = source[val[1:]] 

93 return dest 

94 

95 

96def merge_parameters( 

97 *, 

98 dest: dict, 

99 source: dict | None = None, 

100 required_keys: Iterable[str] | None = None, 

101) -> dict: 

102 """ 

103 Merge the source config dict into the destination config. Pick from the source 

104 configs *ONLY* the keys that are already present in the destination config. 

105 

106 Parameters 

107 ---------- 

108 dest : dict 

109 Destination config. 

110 source : dict | None 

111 Source config. 

112 required_keys : Optional[Iterable[str]] 

113 An optional list of keys that must be present in the destination config. 

114 

115 Returns 

116 ------- 

117 dest : dict 

118 A reference to the destination config after the merge. 

119 """ 

120 if source is None: 

121 source = {} 

122 

123 for key in set(dest).intersection(source): 

124 dest[key] = source[key] 

125 

126 for key in required_keys or []: 

127 if key in dest: 

128 continue 

129 if key in source: 

130 dest[key] = source[key] 

131 else: 

132 raise ValueError("Missing required parameter: " + key) 

133 

134 return dest 

135 

136 

def path_join(*args: str, abs_path: bool = False) -> str:
    """
    Join path components into one normalized path with forward slashes.

    Parameters
    ----------
    args : str
        Path components.

    abs_path : bool
        If True, the joined path is made absolute before normalization.

    Returns
    -------
    str
        Joined and normalized path, with backslashes replaced by "/".
    """
    joined = os.path.join(*args)
    resolved = os.path.abspath(joined) if abs_path else joined
    normalized = os.path.normpath(resolved)
    return normalized.replace("\\", "/")

158 

159 

def prepare_class_load(
    config: dict,
    global_config: dict[str, Any] | None = None,
) -> tuple[str, dict[str, Any]]:
    """
    Pull the class name and its constructor config out of a configuration.

    Parameters
    ----------
    config : dict
        Configuration with a mandatory "class" entry and an optional "config"
        entry. If "config" is absent, an empty dict is stored under that key.
    global_config : dict
        Global configuration parameters (optional). Values for keys that are
        already present in the class config are taken from here.

    Returns
    -------
    (class_name, class_config) : (str, dict)
        Name of the class to instantiate and its configuration.
    """
    class_name = config["class"]
    # setdefault() also writes the empty dict back into `config`, so the
    # returned class config is the same object the caller's config refers to.
    class_config = config.setdefault("config", {})
    merge_parameters(dest=class_config, source=global_config)

    # Serialize the config only when debug logging is enabled.
    if _LOG.isEnabledFor(logging.DEBUG):
        pretty = json.dumps(class_config, indent=2)
        _LOG.debug("Instantiating: %s with config:\n%s", class_name, pretty)

    return (class_name, class_config)

190 

191 

def get_class_from_name(class_name: str) -> type:
    """
    Resolve a fully qualified class name to the class object.

    Parameters
    ----------
    class_name : str
        Fully qualified class name, i.e., ``package.module.ClassName``.

    Returns
    -------
    type
        Class object.
    """
    # Everything before the last dot is the module path; the rest is the class.
    module_name, _, class_id = class_name.rpartition(".")
    # Importing the module is what makes the factory methods work.
    resolved = getattr(importlib.import_module(module_name), class_id)
    assert isinstance(resolved, type)
    return resolved

215 

216 

217# FIXME: Technically, this should return a type "class_name" derived from "base_class". 

def instantiate_from_config(
    base_class: type[BaseTypeVar],
    class_name: str,
    *args: Any,
    **kwargs: Any,
) -> BaseTypeVar:
    """
    Build an instance of the class named in a config.

    Parameters
    ----------
    base_class : type
        Base type the instantiated class must derive from.
    class_name : str
        FQN of a Python class to instantiate, e.g.,
        "mlos_bench.environments.remote.HostEnv".
        Must be derived from the `base_class`.
    args : list
        Positional arguments to pass to the constructor.
    kwargs : dict
        Keyword arguments to pass to the constructor.

    Returns
    -------
    inst : BaseTypeVar
        An instance of the `class_name` class.
    """
    target_cls = get_class_from_name(class_name)
    _LOG.info("Instantiating: %s :: %s", class_name, target_cls)

    # Check the class before constructing it, then re-check the result.
    assert issubclass(target_cls, base_class)
    instance: BaseTypeVar = target_cls(*args, **kwargs)
    assert isinstance(instance, base_class)
    return instance

253 

254 

def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None:
    """
    Verify that the configuration contains every required parameter.

    Parameters
    ----------
    config : dict
        Free-format dictionary with the configuration
        of the service or benchmarking environment.
    required_params : Iterable[str]
        A collection of identifiers of the parameters that must be present
        in the configuration.

    Raises
    ------
    ValueError
        If any of the required parameters are missing from the configuration.
    """
    missing_params = set(required_params).difference(config)
    if not missing_params:
        return
    message = (
        "The following parameters must be provided in the configuration"
        + f" or as command line arguments: {missing_params}"
    )
    raise ValueError(message)

275 

276 

def get_git_info(path: str = __file__) -> tuple[str, str, str]:
    """
    Look up the git repository, commit hash, and repo-relative path of a file.

    Parameters
    ----------
    path : str
        Path to the file in git repository.

    Returns
    -------
    (git_repo, git_commit, git_path) : tuple[str, str, str]
        URL of the "origin" remote, hash of HEAD, and the file path relative
        to the repository root (with forward slashes).
    """
    dirname = os.path.dirname(path)

    def _run_git(*git_args: str) -> str:
        """Run a git command in the file's directory and return its stripped output."""
        return subprocess.check_output(["git", "-C", dirname, *git_args], text=True).strip()

    git_repo = _run_git("remote", "get-url", "origin")
    git_commit = _run_git("rev-parse", "HEAD")
    git_root = _run_git("rev-parse", "--show-toplevel")
    _LOG.debug("Current git branch: %s %s", git_repo, git_commit)

    abs_file = os.path.abspath(path)
    abs_root = os.path.abspath(git_root)
    rel_path = os.path.relpath(abs_file, abs_root)
    return (git_repo, git_commit, rel_path.replace("\\", "/"))

304 

305 

306# Note: to avoid circular imports, we don't specify TunableValue here. 

307def try_parse_val(val: str | None) -> int | float | str | None: 

308 """ 

309 Try to parse the value as an int or float, otherwise return the string. 

310 

311 This can help with config schema validation to make sure early on that 

312 the args we're expecting are the right type. 

313 

314 Parameters 

315 ---------- 

316 val : str 

317 The initial cmd line arg value. 

318 

319 Returns 

320 ------- 

321 TunableValue 

322 The parsed value. 

323 """ 

324 if val is None: 

325 return val 

326 try: 

327 val_float = float(val) 

328 try: 

329 val_int = int(val) 

330 return val_int if val_int == val_float else val_float 

331 except (ValueError, OverflowError): 

332 return val_float 

333 except ValueError: 

334 return str(val) 

335 

336 

337NullableT = TypeVar("NullableT") 

338"""A generic type variable for :py:func:`nullable` return types.""" 

339 

340 

341def nullable(func: Callable[..., NullableT], value: Any | None) -> NullableT | None: 

342 """ 

343 Poor man's Maybe monad: apply the function to the value if it's not None. 

344 

345 Parameters 

346 ---------- 

347 func : Callable 

348 Function to apply to the value. 

349 value : Any | None 

350 Value to apply the function to. 

351 

352 Returns 

353 ------- 

354 value : NullableT | None 

355 The result of the function application or None if the value is None. 

356 

357 Examples 

358 -------- 

359 >>> nullable(int, "1") 

360 1 

361 >>> nullable(int, None) 

362 ... 

363 >>> nullable(str, 1) 

364 '1' 

365 """ 

366 return None if value is None else func(value) 

367 

368 

def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime:
    """
    Return the timestamp as a timezone-aware datetime in UTC.

    Parameters
    ----------
    timestamp : datetime.datetime
        A timestamp to convert to UTC.
        Note: The original datetime may or may not have tzinfo associated with it.

    origin : Literal["utc", "local"]
        Whether a timestamp *without* tzinfo is considered to be in UTC or local time.
        In the case of loading data from storage, where we intentionally convert all
        timestamps to UTC, this can help us retrieve the original timezone when the
        storage backend doesn't explicitly store it.
        In the case of receiving data from a client or other source, this can help us
        convert the timestamp to UTC if it's not already.

    Returns
    -------
    datetime.datetime
        A datetime with zoneinfo in UTC.

    Raises
    ------
    ValueError
        If the timestamp has no tzinfo and `origin` is neither "utc" nor "local".
    """
    if timestamp.tzinfo is None and origin != "local":
        if origin != "utc":
            raise ValueError(f"Invalid origin: {origin}")
        # The naive timestamp already holds UTC wall-clock time, so only attach
        # the zone info. Using astimezone() here would shift the time whenever
        # the local timezone is not UTC, which we don't want.
        return timestamp.replace(tzinfo=pytz.UTC)
    # Either the timestamp is timezone-aware, or it is naive and to be read as
    # "local" time (e.g., according to the TZ environment variable). That could
    # be UTC or some other timezone; either way astimezone() converts it to be
    # explicitly UTC with zone info.
    return timestamp.astimezone(pytz.UTC)

405 

406 

def utcify_nullable_timestamp(
    timestamp: datetime | None,
    *,
    origin: Literal["utc", "local"],
) -> datetime | None:
    """
    A nullable version of utcify_timestamp.

    Returns None when the timestamp is None; otherwise the result of
    ``utcify_timestamp(timestamp, origin=origin)``.
    """
    if timestamp is None:
        return None
    return utcify_timestamp(timestamp, origin=origin)

414 

415 

# All timestamps in the telemetry data must be greater than this date
# (a very rough approximation for the start of this feature).
# The value is timezone-aware (UTC).
_MIN_TS = datetime(2024, 1, 1, 0, 0, 0, tzinfo=pytz.UTC)

419 

420 

def datetime_parser(
    datetime_col: pandas.Series,
    *,
    origin: Literal["utc", "local"],
) -> pandas.Series:
    """
    Parse a pandas column into timezone-aware UTC datetimes and validate it.

    Parameters
    ----------
    datetime_col : pandas.Series
        The column to convert.

    origin : Literal["utc", "local"]
        Whether to interpret naive timestamps as originating from UTC or local time.

    Returns
    -------
    pandas.Series
        The converted datetime column, in UTC.

    Raises
    ------
    ValueError
        On parse errors, on an invalid `origin` for a column without timezone
        data, on missing values, or if any timestamp is not later than the
        minimum allowed timestamp.
    """
    parsed = pandas.to_datetime(datetime_col, utc=False)

    if parsed.dt.tz is None:
        # No timezone data in the column: fall back to the declared origin.
        if origin == "local":
            assumed_tz = datetime.now().astimezone().tzinfo
        elif origin == "utc":
            assumed_tz = pytz.UTC
        else:
            raise ValueError(f"Invalid timezone origin: {origin}")
        parsed = parsed.dt.tz_localize(assumed_tz)
    assert parsed.dt.tz is not None

    # Whatever zone the column is in now, express it in UTC.
    parsed = parsed.dt.tz_convert("UTC")

    if parsed.isna().any():
        raise ValueError(f"Invalid date format in the data: {datetime_col}")
    if parsed.le(_MIN_TS).any():
        raise ValueError(f"Invalid date range in the data: {datetime_col}")
    return parsed