Coverage for mlos_bench/mlos_bench/util.py: 89%

118 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Various helper functions for mlos_bench.""" 

6 

7# NOTE: This has to be placed in the top-level mlos_bench package to avoid circular imports. 

8 

9import importlib 

10import json 

11import logging 

12import os 

13import subprocess 

14from datetime import datetime 

15from typing import ( 

16 TYPE_CHECKING, 

17 Any, 

18 Callable, 

19 Dict, 

20 Iterable, 

21 Literal, 

22 Mapping, 

23 Optional, 

24 Tuple, 

25 Type, 

26 TypeVar, 

27 Union, 

28) 

29 

30import pandas 

31import pytz 

32 

# Module-level logger, named after this module; used by the helpers below.
_LOG = logging.getLogger(__name__)

34 

35if TYPE_CHECKING: 

36 from mlos_bench.environments.base_environment import Environment 

37 from mlos_bench.optimizers.base_optimizer import Optimizer 

38 from mlos_bench.schedulers.base_scheduler import Scheduler 

39 from mlos_bench.services.base_service import Service 

40 from mlos_bench.storage.base_storage import Storage 

41 

# NOTE: The constraints are spelled as string forward references because the
# classes themselves are only imported under TYPE_CHECKING (to avoid circular
# imports), so the names do not exist at runtime.
BaseTypeVar = TypeVar("BaseTypeVar", "Environment", "Optimizer", "Scheduler", "Service", "Storage")
"""BaseTypeVar is a generic with a constraint of the main base classes (e.g.,
:py:class:`~mlos_bench.environments.base_environment.Environment`,
:py:class:`~mlos_bench.optimizers.base_optimizer.Optimizer`,
:py:class:`~mlos_bench.schedulers.base_scheduler.Scheduler`,
:py:class:`~mlos_bench.services.base_service.Service`,
:py:class:`~mlos_bench.storage.base_storage.Storage`, etc.).
"""

# Same set of classes as the TypeVar constraints above, but as a plain Union
# (again using string forward references).
BaseTypes = Union["Environment", "Optimizer", "Scheduler", "Service", "Storage"]
"""Similar to :py:data:`.BaseTypeVar`, BaseTypes is a Union of the main base classes."""

53 

54 

55# Adjusted from https://github.com/python/cpython/blob/v3.11.10/Lib/distutils/util.py#L308 

56# See Also: https://github.com/microsoft/MLOS/issues/865 

def strtobool(val: str) -> bool:
    """
    Interpret a human-friendly truth string as a Python bool.

    The comparison is case-insensitive.

    Parameters
    ----------
    val : str
        True values are 'y', 'yes', 't', 'true', 'on', and '1';
        False values are 'n', 'no', 'f', 'false', 'off', and '0'.

    Returns
    -------
    bool
        The boolean that the string stands for.

    Raises
    ------
    ValueError
        If 'val' is anything else.
    """
    truthy = ("y", "yes", "t", "true", "on", "1")
    falsy = ("n", "no", "f", "false", "off", "0")

    lowered = val.lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    # The message reports the lower-cased form of the input.
    raise ValueError(f"Invalid Boolean value: '{lowered}'")

79 

80 

def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> dict:
    """
    Substitute ``$name`` placeholders in the destination config with values looked up
    in the source config.

    Only top-level string values of the form ``$name`` are considered, and only
    when ``name`` is a key of `source`; everything else is left untouched.

    Parameters
    ----------
    dest : dict
        Destination config. Updated in place.
    source : Optional[dict]
        Source config to look the names up in. Treated as empty when omitted.

    Returns
    -------
    dest : dict
        A reference to the destination config after the preprocessing.
    """
    lookup = {} if source is None else source
    # Collect the replacements first, then apply them in a single update.
    substitutions = {
        key: lookup[val[1:]]
        for key, val in dest.items()
        if isinstance(val, str) and val.startswith("$") and val[1:] in lookup
    }
    dest.update(substitutions)
    return dest

104 

105 

def merge_parameters(
    *,
    dest: dict,
    source: Optional[dict] = None,
    required_keys: Optional[Iterable[str]] = None,
) -> dict:
    """
    Overlay values from the source config onto the destination config, taking from
    the source *ONLY* the keys the destination already has.

    Required keys are the one exception: a required key that is absent from the
    destination is copied over from the source if the source has it.

    Parameters
    ----------
    dest : dict
        Destination config. Updated in place.
    source : Optional[dict]
        Source config. Treated as empty when omitted.
    required_keys : Optional[Iterable[str]]
        An optional list of keys that must be present in the destination config
        once the merge is done.

    Returns
    -------
    dest : dict
        A reference to the destination config after the merge.

    Raises
    ------
    ValueError
        If a required key is found in neither the destination nor the source.
    """
    overrides = {} if source is None else source

    # Refresh only the entries that the destination already declares.
    dest.update({name: overrides[name] for name in dest if name in overrides})

    for name in required_keys or []:
        if name not in dest:
            if name not in overrides:
                raise ValueError("Missing required parameter: " + name)
            dest[name] = overrides[name]

    return dest

145 

146 

def path_join(*args: str, abs_path: bool = False) -> str:
    """
    Join path components into one normalized path that uses forward slashes.

    Parameters
    ----------
    args : str
        Path components.

    abs_path : bool
        If True, the joined path is made absolute before it is normalized.

    Returns
    -------
    str
        Joined path, with every backslash replaced by a forward slash.
    """
    joined = os.path.join(*args)
    resolved = os.path.abspath(joined) if abs_path else joined
    normalized = os.path.normpath(resolved)
    return normalized.replace("\\", "/")

168 

169 

def prepare_class_load(
    config: dict,
    global_config: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any]]:
    """
    Pull the class name and its constructor configuration out of a config dict.

    The "config" section is created (as an empty dict) inside `config` when it is
    missing, and values from `global_config` are merged into it via
    `merge_parameters`.

    Parameters
    ----------
    config : dict
        Configuration with a mandatory "class" entry and an optional "config" entry.
    global_config : dict
        Global configuration parameters (optional).

    Returns
    -------
    (class_name, class_config) : (str, dict)
        Name of the class to instantiate and its configuration.
    """
    # Look up "class" first so that a missing key fails before `config` is modified.
    cls_name = config["class"]
    cls_params = config.setdefault("config", {})
    merge_parameters(dest=cls_params, source=global_config)

    # Serialize the config only when the message will actually be emitted.
    if _LOG.isEnabledFor(logging.DEBUG):
        pretty_params = json.dumps(cls_params, indent=2)
        _LOG.debug("Instantiating: %s with config:\n%s", cls_name, pretty_params)

    return cls_name, cls_params

200 

201 

def get_class_from_name(class_name: str) -> type:
    """
    Resolve a fully qualified class name to the class object.

    Parameters
    ----------
    class_name : str
        Fully qualified class name, i.e., ``package.module.ClassName``.

    Returns
    -------
    type
        Class object.
    """
    # Everything before the last dot is the module path; the remainder is the class.
    module_path, _, attr_name = class_name.rpartition(".")
    # The module has to be imported before the class can be looked up on it.
    resolved = getattr(importlib.import_module(module_path), attr_name)
    assert isinstance(resolved, type)
    return resolved

225 

226 

227# FIXME: Technically, this should return a type "class_name" derived from "base_class". 

def instantiate_from_config(
    base_class: Type[BaseTypeVar],
    class_name: str,
    *args: Any,
    **kwargs: Any,
) -> BaseTypeVar:
    """
    Factory method: build an instance of the class named in the config.

    Parameters
    ----------
    base_class : type
        Base type of the class to instantiate.
        One of the types that `BaseTypeVar` is constrained to.
    class_name : str
        FQN of a Python class to instantiate, e.g.,
        "mlos_bench.environments.remote.HostEnv".
        Must be derived from the `base_class`.
    args : list
        Positional arguments to pass to the constructor.
    kwargs : dict
        Keyword arguments to pass to the constructor.

    Returns
    -------
    inst : BaseTypeVar
        An instance of the `class_name` class.
    """
    target_cls = get_class_from_name(class_name)
    _LOG.info("Instantiating: %s :: %s", class_name, target_cls)

    # Check the class before constructing it, and the resulting object afterwards.
    assert issubclass(target_cls, base_class)
    instance: BaseTypeVar = target_cls(*args, **kwargs)
    assert isinstance(instance, base_class)
    return instance

263 

264 

def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None:
    """
    Verify that the configuration contains every required parameter.

    Parameters
    ----------
    config : dict
        Free-format dictionary with the configuration
        of the service or benchmarking environment.
    required_params : Iterable[str]
        A collection of identifiers of the parameters that must be present
        in the configuration.

    Raises
    ------
    ValueError
        If any of the required parameters are missing from the configuration.
    """
    absent = set(required_params)
    absent.difference_update(config)
    if not absent:
        return
    raise ValueError(
        "The following parameters must be provided in the configuration"
        f" or as command line arguments: {absent}"
    )

285 

286 

def get_git_info(path: str = __file__) -> Tuple[str, str, str]:
    """
    Look up the git repository, commit hash, and repo-relative path of the given file.

    The information is obtained by running the ``git`` command line tool in the
    directory that contains the file.

    Parameters
    ----------
    path : str
        Path to the file in git repository.

    Returns
    -------
    (git_repo, git_commit, git_path) : Tuple[str, str, str]
        URL of the "origin" remote, hash of the HEAD commit, and the file path
        relative to the repository root (with forward slashes).
    """
    dirname = os.path.dirname(path)

    def _git(*git_args: str) -> str:
        """Run a git subcommand in the file's directory and return its trimmed output."""
        return subprocess.check_output(["git", "-C", dirname, *git_args], text=True).strip()

    git_repo = _git("remote", "get-url", "origin")
    git_commit = _git("rev-parse", "HEAD")
    git_root = _git("rev-parse", "--show-toplevel")
    _LOG.debug("Current git branch: %s %s", git_repo, git_commit)

    abs_file = os.path.abspath(path)
    abs_root = os.path.abspath(git_root)
    git_path = os.path.relpath(abs_file, abs_root).replace("\\", "/")
    return (git_repo, git_commit, git_path)

314 

315 

316# Note: to avoid circular imports, we don't specify TunableValue here. 

def try_parse_val(val: Optional[str]) -> Optional[Union[int, float, str]]:
    """
    Try to parse the value as an int or float, otherwise return the string.

    This can help with config schema validation to make sure early on that
    the args we're expecting are the right type.

    Integer strings are always returned as exact ints, including values too large
    to be represented exactly as a float.

    Parameters
    ----------
    val : str
        The initial cmd line arg value.

    Returns
    -------
    TunableValue
        The parsed value: an int if the text is an integer literal, a float if it
        is any other numeric literal, the original string otherwise, or None if
        the input is None.
    """
    if val is None:
        return None
    try:
        val_float = float(val)
    except ValueError:
        # Not numeric at all: hand the text back unchanged.
        return str(val)
    try:
        val_int = int(val)
    except (ValueError, OverflowError):
        # E.g., "1.5", "1e3", or "inf": a valid float, but not an int.
        return val_float
    if isinstance(val, str):
        # The text is an integer literal, so the int is exact. The float may have
        # been rounded (anything beyond 2**53), so comparing the two and falling
        # back to the float would silently lose precision.
        return val_int
    # Numeric (non-str) input is tolerated for backward compatibility:
    # keep the int only if converting to int did not truncate anything.
    return val_int if val_int == val_float else val_float

345 

346 

def nullable(func: Callable, value: Optional[Any]) -> Optional[Any]:
    """
    Poor man's Maybe monad: call the function on the value unless the value is None.

    Parameters
    ----------
    func : Callable
        Function to apply to the value.
    value : Optional[Any]
        Value to apply the function to.

    Returns
    -------
    value : Optional[Any]
        The result of the function application or None if the value is None.
    """
    if value is None:
        return None
    return func(value)

364 

365 

def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime:
    """
    Return a timezone-aware UTC version of the given timestamp.

    Parameters
    ----------
    timestamp : datetime.datetime
        A timestamp to convert to UTC.
        Note: The original datetime may or may not have tzinfo associated with it.

    origin : Literal["utc", "local"]
        Whether a timestamp without tzinfo is considered to be in UTC or local time.
        In the case of loading data from storage, where we intentionally convert all
        timestamps to UTC, this can help us retrieve the original timezone when the
        storage backend doesn't explicitly store it.
        In the case of receiving data from a client or other source, this can help us
        convert the timestamp to UTC if it's not already.

    Returns
    -------
    datetime.datetime
        A datetime with zoneinfo in UTC.

    Raises
    ------
    ValueError
        If the timestamp has no tzinfo and `origin` is neither "utc" nor "local".
    """
    is_aware = timestamp.tzinfo is not None
    if is_aware or origin == "local":
        # A timestamp with no zoneinfo is interpreted as "local" time
        # (e.g., according to the TZ environment variable).
        # That could be UTC or some other timezone, but either way the result
        # is explicitly UTC with zone info.
        return timestamp.astimezone(pytz.UTC)
    if origin != "utc":
        raise ValueError(f"Invalid origin: {origin}")
    # The naive timestamp is already expressed in UTC, so only attach the zoneinfo.
    # Going through astimezone() here would shift the value whenever the local
    # time zone is *not* UTC, which is not wanted.
    return timestamp.replace(tzinfo=pytz.UTC)

402 

403 

def utcify_nullable_timestamp(
    timestamp: Optional[datetime],
    *,
    origin: Literal["utc", "local"],
) -> Optional[datetime]:
    """
    A nullable version of utcify_timestamp.

    Returns None when `timestamp` is None; otherwise the timestamp and `origin`
    are passed on to utcify_timestamp and its result is returned.
    """
    if timestamp is None:
        return None
    return utcify_timestamp(timestamp, origin=origin)

411 

412 

# All timestamps in the telemetry data must be greater than this date
# (a very rough approximation for the start of this feature).
# It is timezone-aware (UTC), and datetime_parser() rejects any value that is
# less than or equal to it.
_MIN_TS = datetime(2024, 1, 1, 0, 0, 0, tzinfo=pytz.UTC)

416 

417 

def datetime_parser(
    datetime_col: pandas.Series,
    *,
    origin: Literal["utc", "local"],
) -> pandas.Series:
    """
    Convert a pandas column to timezone-aware datetimes in UTC and validate it.

    Parameters
    ----------
    datetime_col : pandas.Series
        The column to convert.

    origin : Literal["utc", "local"]
        Whether to interpret naive timestamps as originating from UTC or local time.
        Only consulted when the parsed column carries no timezone data.

    Returns
    -------
    pandas.Series
        The converted datetime column.

    Raises
    ------
    ValueError
        On parse errors, on an unrecognized `origin` for a column without timezone
        data, on missing values, and on timestamps that are not later than `_MIN_TS`.
    """
    parsed = pandas.to_datetime(datetime_col, utc=False)

    # If timezone data is missing, assume the provided origin timezone.
    if parsed.dt.tz is None:
        if origin == "utc":
            assumed_tz = pytz.UTC
        elif origin == "local":
            assumed_tz = datetime.now().astimezone().tzinfo
        else:
            raise ValueError(f"Invalid timezone origin: {origin}")
        parsed = parsed.dt.tz_localize(assumed_tz)
    assert parsed.dt.tz is not None

    # Whatever the zone was, hand back the column in UTC.
    parsed = parsed.dt.tz_convert("UTC")

    # Validate: no missing values, and nothing at or before the cutoff date.
    has_missing = parsed.isna().any()
    if has_missing:
        raise ValueError(f"Invalid date format in the data: {datetime_col}")
    has_too_early = parsed.le(_MIN_TS).any()
    if has_too_early:
        raise ValueError(f"Invalid date range in the data: {datetime_col}")
    return parsed