Coverage for mlos_bench/mlos_bench/util.py: 89%

109 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Various helper functions for mlos_bench.""" 

6 

7# NOTE: This has to be placed in the top-level mlos_bench package to avoid circular imports. 

8 

9import importlib 

10import json 

11import logging 

12import os 

13import subprocess 

14from datetime import datetime 

15from typing import ( 

16 TYPE_CHECKING, 

17 Any, 

18 Callable, 

19 Dict, 

20 Iterable, 

21 Literal, 

22 Mapping, 

23 Optional, 

24 Tuple, 

25 Type, 

26 TypeVar, 

27 Union, 

28) 

29 

30import pandas 

31import pytz 

32 

33_LOG = logging.getLogger(__name__) 

34 

35if TYPE_CHECKING: 

36 from mlos_bench.environments.base_environment import Environment 

37 from mlos_bench.optimizers.base_optimizer import Optimizer 

38 from mlos_bench.schedulers.base_scheduler import Scheduler 

39 from mlos_bench.services.base_service import Service 

40 from mlos_bench.storage.base_storage import Storage 

41 

42# BaseTypeVar is a generic type variable constrained to the five base classes listed below. 

43BaseTypeVar = TypeVar("BaseTypeVar", "Environment", "Optimizer", "Scheduler", "Service", "Storage") 

44BaseTypes = Union["Environment", "Optimizer", "Scheduler", "Service", "Storage"] 

45 

46 

def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> dict:
    """
    Substitute ``$name`` placeholders in `dest` with values looked up in `source`.

    Only top-level string values that start with ``$`` are considered, and a
    placeholder is replaced only when its name (without the ``$``) is a key of
    `source`. Everything else is left untouched.

    Parameters
    ----------
    dest : dict
        Destination config. It is modified in place.
    source : Optional[dict]
        Source config that supplies the replacement values.
        `None` is treated as an empty dict.

    Returns
    -------
    dest : dict
        The very same `dest` object, after the substitution.
    """
    lookup = {} if source is None else source
    for name, value in dest.items():
        if not isinstance(value, str) or not value.startswith("$"):
            continue
        ref = value[1:]
        if ref in lookup:
            # Re-assigning an existing key is safe while iterating over the dict.
            dest[name] = lookup[ref]
    return dest

70 

71 

def merge_parameters(
    *,
    dest: dict,
    source: Optional[dict] = None,
    required_keys: Optional[Iterable[str]] = None,
) -> dict:
    """
    Copy values from `source` into `dest` for the keys `dest` already has.

    Keys that exist only in `source` are ignored, unless they are listed in
    `required_keys` and absent from `dest`, in which case they are copied too.

    Parameters
    ----------
    dest : dict
        Destination config. It is modified in place.
    source : Optional[dict]
        Source config. `None` is treated as an empty dict.
    required_keys : Optional[Iterable[str]]
        Keys that must end up in `dest`, either because they were already there
        or because `source` can provide them.

    Returns
    -------
    dest : dict
        The very same `dest` object, after the merge.

    Raises
    ------
    ValueError
        If a required key is found in neither `dest` nor `source`.
    """
    overrides = {} if source is None else source

    # Overwrite only the keys that the destination already declares.
    dest.update({name: overrides[name] for name in dest if name in overrides})

    for name in required_keys or ():
        if name in dest:
            continue
        if name not in overrides:
            raise ValueError("Missing required parameter: " + name)
        dest[name] = overrides[name]

    return dest

111 

112 

def path_join(*args: str, abs_path: bool = False) -> str:
    """
    Join path components, normalize the result, and use forward slashes.

    Parameters
    ----------
    args : str
        Path components.

    abs_path : bool
        If True, the joined path is made absolute before normalization.

    Returns
    -------
    str
        The joined and normalized path with every backslash replaced by "/".
    """
    joined = os.path.join(*args)
    resolved = os.path.abspath(joined) if abs_path else joined
    normalized = os.path.normpath(resolved)
    return normalized.replace("\\", "/")

134 

135 

def prepare_class_load(
    config: dict,
    global_config: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any]]:
    """
    Pull the class name and its constructor parameters out of a configuration.

    Parameters
    ----------
    config : dict
        Configuration with a mandatory "class" entry and an optional "config"
        entry. If "config" is missing, an empty dict is stored under that key.
    global_config : dict
        Global configuration parameters (optional). Their values override the
        matching keys of the "config" section.

    Returns
    -------
    (class_name, class_config) : (str, dict)
        Name of the class to instantiate and its (merged) configuration.
        `class_config` is the same dict object that is stored in `config`.
    """
    target = config["class"]
    params = config.setdefault("config", {})

    merge_parameters(dest=params, source=global_config)

    # Serialize the config only when the message will actually be emitted.
    if _LOG.isEnabledFor(logging.DEBUG):
        pretty = json.dumps(params, indent=2)
        _LOG.debug("Instantiating: %s with config:\n%s", target, pretty)

    return (target, params)

166 

167 

def get_class_from_name(class_name: str) -> type:
    """
    Resolve a fully qualified class name into the class object.

    Parameters
    ----------
    class_name : str
        Fully qualified class name, i.e., "<module path>.<ClassName>".

    Returns
    -------
    type
        Class object.
    """
    # Everything before the last dot is the module to import;
    # the remainder is the attribute to fetch from that module.
    module_path, _, attr_name = class_name.rpartition(".")
    found = getattr(importlib.import_module(module_path), attr_name)
    assert isinstance(found, type)
    return found

191 

192 

# FIXME: Technically, this should return a type "class_name" derived from "base_class".
def instantiate_from_config(
    base_class: Type[BaseTypeVar],
    class_name: str,
    *args: Any,
    **kwargs: Any,
) -> BaseTypeVar:
    """
    Create an instance of the class named in the config.

    Parameters
    ----------
    base_class : type
        Base type that the instantiated class must derive from
        (one of the types admitted by `BaseTypeVar`).
    class_name : str
        FQN of a Python class to instantiate, e.g.,
        "mlos_bench.environments.remote.HostEnv".
        Must be derived from the `base_class`.
    args : list
        Positional arguments to pass to the constructor.
    kwargs : dict
        Keyword arguments to pass to the constructor.

    Returns
    -------
    inst : BaseTypeVar
        An instance of the `class_name` class.
    """
    concrete_cls = get_class_from_name(class_name)
    _LOG.info("Instantiating: %s :: %s", class_name, concrete_cls)

    # Check the class before constructing it, and the result afterwards.
    assert issubclass(concrete_cls, base_class)
    instance: BaseTypeVar = concrete_cls(*args, **kwargs)
    assert isinstance(instance, base_class)
    return instance

229 

230 

def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None:
    """
    Verify that the configuration contains every required parameter.

    Parameters
    ----------
    config : dict
        Free-format dictionary with the configuration
        of the service or benchmarking environment.
    required_params : Iterable[str]
        A collection of identifiers of the parameters that must be present
        in the configuration.

    Raises
    ------
    ValueError
        If at least one of the required parameters is not a key of `config`.
    """
    absent = set(required_params).difference(config)
    if not absent:
        return
    message = (
        "The following parameters must be provided in the configuration"
        f" or as command line arguments: {absent}"
    )
    raise ValueError(message)

251 

252 

def get_git_info(path: str = __file__) -> Tuple[str, str, str]:
    """
    Look up the git remote, HEAD commit, and repo-relative path for a file.

    The information is obtained by running the `git` command line tool in the
    directory that contains `path`.

    Parameters
    ----------
    path : str
        Path to the file in git repository.

    Returns
    -------
    (git_repo, git_commit, git_path) : Tuple[str, str, str]
        URL of the "origin" remote, hash of the HEAD commit, and the file path
        relative to the repository root (with forward slashes).
    """
    repo_dir = os.path.dirname(path)

    def _git(*cmd: str) -> str:
        """Run a git subcommand inside `repo_dir` and return its stripped output."""
        return subprocess.check_output(["git", "-C", repo_dir, *cmd], text=True).strip()

    git_repo = _git("remote", "get-url", "origin")
    git_commit = _git("rev-parse", "HEAD")
    git_root = _git("rev-parse", "--show-toplevel")
    _LOG.debug("Current git branch: %s %s", git_repo, git_commit)

    rel_to_root = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root))
    return (git_repo, git_commit, rel_to_root.replace("\\", "/"))

280 

281 

# Note: to avoid circular imports, we don't specify TunableValue here.
def try_parse_val(val: Optional[str]) -> Optional[Union[int, float, str]]:
    """
    Try to parse the value as an int or float, otherwise return the string.

    This can help with config schema validation to make sure early on that
    the args we're expecting are the right type.

    Parameters
    ----------
    val : str
        The initial cmd line arg value.

    Returns
    -------
    Optional[Union[int, float, str]]
        `None` if `val` is None; an `int` if the value is an integer literal
        (of any magnitude); a `float` if it is any other number;
        otherwise the value converted to `str`.
    """
    if val is None:
        return val
    try:
        val_float = float(val)
    except ValueError:
        # Not a number at all.
        return str(val)
    try:
        val_int = int(val)
    except (ValueError, OverflowError):
        # A number, but not an integer literal (e.g., "1.5", "1e3", "inf").
        return val_float
    if isinstance(val, str):
        # int() accepts only integer literals from a string, so the parse is exact.
        # Comparing it with the float would wrongly reject large integers that
        # a float cannot represent exactly (e.g., "9007199254740993").
        return val_int
    # Non-string numeric input: keep the int only if nothing was truncated.
    return val_int if val_int == val_float else val_float

311 

312 

def nullable(func: Callable, value: Optional[Any]) -> Optional[Any]:
    """
    Apply `func` to `value` unless the value is None (a minimal "Maybe" helper).

    Parameters
    ----------
    func : Callable
        Function to apply to the value.
    value : Optional[Any]
        Value to apply the function to.

    Returns
    -------
    value : Optional[Any]
        `func(value)`, or None if `value` is None (`func` is not called then).
    """
    if value is None:
        return None
    return func(value)

330 

331 

def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime:
    """
    Return the timestamp as a timezone-aware datetime in UTC.

    Parameters
    ----------
    timestamp : datetime
        A timestamp to convert to UTC.
        Note: The original datetime may or may not have tzinfo associated with it.

    origin : Literal["utc", "local"]
        How to interpret a timestamp that carries no tzinfo: as UTC or as local time.
        When loading data from storage, where all timestamps are intentionally
        converted to UTC, this helps recover the right zone if the storage
        backend does not store it explicitly.
        When receiving data from a client or another source, this helps convert
        the timestamp to UTC if it is not in UTC already.
        The value is not consulted when the timestamp already has tzinfo.

    Returns
    -------
    datetime
        A datetime with zoneinfo in UTC.

    Raises
    ------
    ValueError
        If the timestamp has no tzinfo and `origin` is neither "utc" nor "local".
    """
    is_aware = timestamp.tzinfo is not None
    if is_aware or origin == "local":
        # A timestamp with no zoneinfo is interpreted as "local" time
        # (e.g., according to the TZ environment variable).
        # That could be UTC or some other timezone, but either way it is
        # converted to be explicitly UTC with zone info.
        return timestamp.astimezone(pytz.UTC)
    if origin != "utc":
        raise ValueError(f"Invalid origin: {origin}")
    # The naive timestamp is already in UTC, so only attach the zone info.
    # Using astimezone() here would shift the value whenever the local
    # timezone is *not* UTC, which is not wanted.
    return timestamp.replace(tzinfo=pytz.UTC)

368 

369 

def utcify_nullable_timestamp(
    timestamp: Optional[datetime],
    *,
    origin: Literal["utc", "local"],
) -> Optional[datetime]:
    """Same as utcify_timestamp, except that a None timestamp is passed through as None."""
    if timestamp is None:
        return None
    return utcify_timestamp(timestamp, origin=origin)

377 

378 

# Lower bound (exclusive) for every timestamp in the telemetry data:
# midnight UTC on 2024-01-01, a very rough approximation for the start of this feature.
_MIN_TS = datetime(2024, 1, 1, tzinfo=pytz.UTC)

382 

383 

def datetime_parser(
    datetime_col: pandas.Series,
    *,
    origin: Literal["utc", "local"],
) -> pandas.Series:
    """
    Parse a pandas column into timezone-aware datetimes in UTC and validate them.

    Parameters
    ----------
    datetime_col : pandas.Series
        The column to convert.

    origin : Literal["utc", "local"]
        Whether to interpret naive timestamps as originating from UTC or local time.
        It is consulted only when the parsed column has no timezone.

    Returns
    -------
    pandas.Series
        The converted datetime column.

    Raises
    ------
    ValueError
        On parse errors, on an unknown `origin` for naive data, on missing
        values, or when a timestamp is not later than `_MIN_TS`.
    """
    parsed = pandas.to_datetime(datetime_col, utc=False)

    # If timezone data is missing, assume the provided origin timezone.
    if parsed.dt.tz is None:
        if origin == "local":
            assumed_tz = datetime.now().astimezone().tzinfo
        elif origin == "utc":
            assumed_tz = pytz.UTC
        else:
            raise ValueError(f"Invalid timezone origin: {origin}")
        parsed = parsed.dt.tz_localize(assumed_tz)
    assert parsed.dt.tz is not None

    # From here on everything is expressed in UTC.
    as_utc = parsed.dt.tz_convert("UTC")

    if as_utc.isna().any():
        raise ValueError(f"Invalid date format in the data: {datetime_col}")
    if as_utc.le(_MIN_TS).any():
        raise ValueError(f"Invalid date range in the data: {datetime_col}")
    return as_utc