Coverage for mlos_bench/mlos_bench/util.py: 89%
118 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Various helper functions for mlos_bench."""
7# NOTE: This has to be placed in the top-level mlos_bench package to avoid circular imports.
9import importlib
10import json
11import logging
12import os
13import subprocess
14from datetime import datetime
15from typing import (
16 TYPE_CHECKING,
17 Any,
18 Callable,
19 Dict,
20 Iterable,
21 Literal,
22 Mapping,
23 Optional,
24 Tuple,
25 Type,
26 TypeVar,
27 Union,
28)
30import pandas
31import pytz
33_LOG = logging.getLogger(__name__)
35if TYPE_CHECKING:
36 from mlos_bench.environments.base_environment import Environment
37 from mlos_bench.optimizers.base_optimizer import Optimizer
38 from mlos_bench.schedulers.base_scheduler import Scheduler
39 from mlos_bench.services.base_service import Service
40 from mlos_bench.storage.base_storage import Storage
# Constrained TypeVar used by the factory helpers in this module
# (see ``instantiate_from_config``) to tie the requested base class to the
# static return type. The constraints are given as string forward references
# because the classes are only imported under ``TYPE_CHECKING`` above.
BaseTypeVar = TypeVar("BaseTypeVar", "Environment", "Optimizer", "Scheduler", "Service", "Storage")
"""BaseTypeVar is a generic with a constraint of the main base classes (e.g.,
:py:class:`~mlos_bench.environments.base_environment.Environment`,
:py:class:`~mlos_bench.optimizers.base_optimizer.Optimizer`,
:py:class:`~mlos_bench.schedulers.base_scheduler.Scheduler`,
:py:class:`~mlos_bench.services.base_service.Service`,
:py:class:`~mlos_bench.storage.base_storage.Storage`, etc.).
"""

# Union of the same five base classes, for annotations that accept any of them.
BaseTypes = Union["Environment", "Optimizer", "Scheduler", "Service", "Storage"]
"""Similar to :py:data:`.BaseTypeVar`, BaseTypes is a Union of the main base classes."""
55# Adjusted from https://github.com/python/cpython/blob/v3.11.10/Lib/distutils/util.py#L308
56# See Also: https://github.com/microsoft/MLOS/issues/865
def strtobool(val: str) -> bool:
    """
    Interpret a human-friendly truth string as a bool.

    Parameters
    ----------
    val : str
        Text to interpret; matching is case-insensitive.
        'y', 'yes', 't', 'true', 'on', and '1' map to True;
        'n', 'no', 'f', 'false', 'off', and '0' map to False.

    Returns
    -------
    bool
        The truth value spelled by `val`.

    Raises
    ------
    ValueError
        If `val` is not one of the recognized spellings.
    """
    truthy = ("y", "yes", "t", "true", "on", "1")
    falsy = ("n", "no", "f", "false", "off", "0")

    lowered = val.lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    # The message reports the lower-cased form of the input.
    raise ValueError(f"Invalid Boolean value: '{lowered}'")
def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> dict:
    """
    Resolve ``$name`` placeholders in `dest` using the values found in `source`.

    Only top-level string values that start with ``$`` are considered, and a
    placeholder is substituted only when its name (the text after ``$``) is a
    key of `source`; everything else is left untouched.

    Parameters
    ----------
    dest : dict
        Destination config. It is updated in place.
    source : Optional[dict]
        Source config providing the replacement values.
        ``None`` is treated as an empty dict.

    Returns
    -------
    dest : dict
        The same `dest` object, after the substitutions.
    """
    lookup = {} if source is None else source
    # Only existing keys are reassigned, so the dict size does not change
    # while it is being iterated.
    for name, value in dest.items():
        if not isinstance(value, str) or not value.startswith("$"):
            continue
        ref = value[1:]
        if ref in lookup:
            dest[name] = lookup[ref]
    return dest
def merge_parameters(
    *,
    dest: dict,
    source: Optional[dict] = None,
    required_keys: Optional[Iterable[str]] = None,
) -> dict:
    """
    Overlay values from `source` onto `dest` for the keys `dest` already has.

    Keys that exist only in `source` are ignored, unless they are listed in
    `required_keys` and absent from `dest`, in which case they are copied over.

    Parameters
    ----------
    dest : dict
        Destination config. It is updated in place.
    source : Optional[dict]
        Source config. ``None`` is treated as an empty dict.
    required_keys : Optional[Iterable[str]]
        Keys that must end up in `dest`, either because they were already
        there or because `source` could supply them.

    Returns
    -------
    dest : dict
        The same `dest` object, after the merge.

    Raises
    ------
    ValueError
        If a required key is in neither `dest` nor `source`.
    """
    overrides = source if source is not None else {}

    # Refresh only the keys the destination already declares.
    dest.update({key: overrides[key] for key in dest if key in overrides})

    # Fill in the required keys that the destination is still missing.
    for key in required_keys or ():
        if key in dest:
            continue
        if key not in overrides:
            raise ValueError("Missing required parameter: " + key)
        dest[key] = overrides[key]

    return dest
def path_join(*args: str, abs_path: bool = False) -> str:
    """
    Join path components into one normalized path that uses forward slashes.

    Parameters
    ----------
    args : str
        Path components.

    abs_path : bool
        If True, the joined path is made absolute before normalization.

    Returns
    -------
    str
        The joined and normalized path, with every backslash replaced by ``/``.
    """
    joined = os.path.join(*args)
    resolved = os.path.abspath(joined) if abs_path else joined
    normalized = os.path.normpath(resolved)
    # Use forward slashes regardless of the platform separator.
    return "/".join(normalized.split("\\"))
def prepare_class_load(
    config: dict,
    global_config: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any]]:
    """
    Pull the class name and its constructor config out of a config dict.

    Parameters
    ----------
    config : dict
        Config with a mandatory "class" entry and an optional "config" entry.
        If "config" is missing, an empty dict is stored under that key.
    global_config : dict
        Global configuration parameters (optional). Values for keys that are
        already present in the class config override the local ones.

    Returns
    -------
    (class_name, class_config) : (str, dict)
        Name of the class to instantiate and its (merged) configuration.
    """
    target_class = config["class"]
    params = config.setdefault("config", {})

    # Globals may only override keys that the class config already declares.
    merge_parameters(dest=params, source=global_config)

    # Serialize the config only when the debug message will actually be emitted.
    if _LOG.isEnabledFor(logging.DEBUG):
        pretty_params = json.dumps(params, indent=2)
        _LOG.debug("Instantiating: %s with config:\n%s", target_class, pretty_params)

    return (target_class, params)
def get_class_from_name(class_name: str) -> type:
    """
    Import a module and return the class named by a fully qualified name.

    Parameters
    ----------
    class_name : str
        Fully qualified class name, i.e. ``package.module.ClassName``.

    Returns
    -------
    type
        Class object.

    Raises
    ------
    AssertionError
        If the resolved attribute is not a class.
    """
    # Everything before the last dot is the module; the rest is the class.
    module_path, _, attr_name = class_name.rpartition(".")
    resolved = getattr(importlib.import_module(module_path), attr_name)
    assert isinstance(resolved, type)
    return resolved
227# FIXME: Technically, this should return a type "class_name" derived from "base_class".
def instantiate_from_config(
    base_class: Type[BaseTypeVar],
    class_name: str,
    *args: Any,
    **kwargs: Any,
) -> BaseTypeVar:
    """
    Build an instance of the class named `class_name`, checked against `base_class`.

    Parameters
    ----------
    base_class : type
        Base type that the instantiated class must derive from
        (one of the classes allowed by ``BaseTypeVar``).
    class_name : str
        FQN of a Python class to instantiate, e.g.,
        "mlos_bench.environments.remote.HostEnv".
        Must be derived from the `base_class`.
    args : list
        Positional arguments to pass to the constructor.
    kwargs : dict
        Keyword arguments to pass to the constructor.

    Returns
    -------
    inst : BaseTypeVar
        An instance of the `class_name` class.
    """
    cls = get_class_from_name(class_name)
    _LOG.info("Instantiating: %s :: %s", class_name, cls)

    # Check the class before constructing it, then the instance as well.
    assert issubclass(cls, base_class)
    instance: BaseTypeVar = cls(*args, **kwargs)
    assert isinstance(instance, base_class)
    return instance
def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None:
    """
    Verify that every required parameter is a key of the configuration.

    Parameters
    ----------
    config : dict
        Free-format dictionary with the configuration
        of the service or benchmarking environment.
    required_params : Iterable[str]
        Identifiers of the parameters that must be present in `config`.

    Raises
    ------
    ValueError
        If one or more of the required parameters are missing;
        the message lists the missing names.
    """
    missing = {param for param in required_params if param not in config}
    if not missing:
        return
    raise ValueError(
        "The following parameters must be provided in the configuration"
        + f" or as command line arguments: {missing}"
    )
def get_git_info(path: str = __file__) -> Tuple[str, str, str]:
    """
    Query git for the repository, commit, and in-repo path of a file.

    Parameters
    ----------
    path : str
        Path to the file in git repository.

    Returns
    -------
    (git_repo, git_commit, git_path) : Tuple[str, str, str]
        URL of the "origin" remote, hash of the HEAD commit, and the file path
        relative to the repository root (using forward slashes).
    """
    dirname = os.path.dirname(path)

    def _git(*git_args: str) -> str:
        """Run a git command in the file's directory and return its trimmed output."""
        return subprocess.check_output(["git", "-C", dirname, *git_args], text=True).strip()

    git_repo = _git("remote", "get-url", "origin")
    git_commit = _git("rev-parse", "HEAD")
    git_root = _git("rev-parse", "--show-toplevel")
    _LOG.debug("Current git branch: %s %s", git_repo, git_commit)

    # Express the file location relative to the repository top-level directory.
    abs_file = os.path.abspath(path)
    abs_root = os.path.abspath(git_root)
    rel_path = os.path.relpath(abs_file, abs_root)
    return (git_repo, git_commit, rel_path.replace("\\", "/"))
316# Note: to avoid circular imports, we don't specify TunableValue here.
def try_parse_val(val: Optional[str]) -> Optional[Union[int, float, str]]:
    """
    Try to parse the value as an int or float, otherwise return the string.

    This can help with config schema validation to make sure early on that
    the args we're expecting are the right type.

    Integer strings are returned as exact ints, even when they are too large
    to be represented exactly as a float (e.g., "9007199254740993").

    Parameters
    ----------
    val : str
        The initial cmd line arg value.

    Returns
    -------
    TunableValue
        ``None`` if `val` is None; an int if `val` spells an integer;
        a float if it spells any other number; otherwise `val` as a string.
    """
    if val is None:
        return None

    try:
        val_float = float(val)
    except ValueError:
        # Not numeric at all.
        return str(val)

    try:
        val_int = int(val)
    except (ValueError, OverflowError):
        # Numeric, but not an integer literal (e.g., "1.5", "1e3", "inf").
        return val_float

    # When int() accepts a *string*, the text is an integer literal, so the int
    # is the exact value. Comparing it with the float would wrongly reject
    # integers beyond the float precision and return a lossy float instead.
    # For non-string inputs (e.g., 1.5) int() truncates, so keep the check.
    if isinstance(val, str) or val_int == val_float:
        return val_int
    return val_float
def nullable(func: Callable, value: Optional[Any]) -> Optional[Any]:
    """
    Apply `func` to `value`, passing ``None`` through untouched.

    Parameters
    ----------
    func : Callable
        Function to apply to the value.
    value : Optional[Any]
        Value to apply the function to.

    Returns
    -------
    value : Optional[Any]
        ``func(value)``, or None if `value` is None.
    """
    # Only None short-circuits; other falsy values (0, "", []) are passed to func.
    if value is None:
        return None
    return func(value)
def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime:
    """
    Return a timezone-aware UTC version of the timestamp.

    Parameters
    ----------
    timestamp : datetime.datetime
        A timestamp to convert to UTC.
        Note: The original datetime may or may not have tzinfo associated with it.

    origin : Literal["utc", "local"]
        How to interpret a timestamp that has no tzinfo: as UTC or as local time.
        When loading data from storage, where all timestamps are intentionally
        converted to UTC, this lets us restore the timezone if the storage
        backend doesn't explicitly store it.
        When receiving data from a client or other source, this lets us convert
        the timestamp to UTC if it's not already.
        It is not consulted for timestamps that already carry tzinfo.

    Returns
    -------
    datetime.datetime
        A datetime with zoneinfo in UTC.

    Raises
    ------
    ValueError
        If the timestamp has no tzinfo and `origin` is neither "utc" nor "local".
    """
    is_naive = timestamp.tzinfo is None

    if is_naive and origin == "utc":
        # The value is already UTC wall-clock time, so only attach the zone.
        # astimezone() would first treat the naive value as local time and
        # shift it whenever the local zone is not UTC.
        return timestamp.replace(tzinfo=pytz.UTC)

    if is_naive and origin != "local":
        raise ValueError(f"Invalid origin: {origin}")

    # Either the timestamp already has a zone, or it is naive local time, which
    # astimezone() interprets as local (e.g., according to the TZ environment
    # variable). Both cases are converted to explicit UTC.
    return timestamp.astimezone(pytz.UTC)
def utcify_nullable_timestamp(
    timestamp: Optional[datetime],
    *,
    origin: Literal["utc", "local"],
) -> Optional[datetime]:
    """
    A nullable version of utcify_timestamp.

    Returns None when `timestamp` is None; otherwise the result of
    ``utcify_timestamp(timestamp, origin=origin)``.
    """
    if timestamp is None:
        return None
    return utcify_timestamp(timestamp, origin=origin)
# Lower bound (exclusive) for telemetry timestamps: midnight UTC on 2024-01-01,
# a very rough approximation for the start of this feature.
_MIN_TS = datetime(year=2024, month=1, day=1, tzinfo=pytz.UTC)
def datetime_parser(
    datetime_col: pandas.Series,
    *,
    origin: Literal["utc", "local"],
) -> pandas.Series:
    """
    Parse a pandas column into timezone-aware UTC datetimes.

    Parameters
    ----------
    datetime_col : pandas.Series
        The column to convert.

    origin : Literal["utc", "local"]
        Whether to interpret naive timestamps as originating from UTC or local time.
        Only consulted when the parsed column has no timezone.

    Returns
    -------
    pandas.Series
        The converted datetime column, in UTC.

    Raises
    ------
    ValueError
        On parse errors, on an unknown `origin` for a naive column, if any value
        is missing after the conversion, or if any value is not later than
        ``_MIN_TS``.
    """
    parsed = pandas.to_datetime(datetime_col, utc=False)

    if parsed.dt.tz is None:
        # No timezone in the data: attach the one implied by `origin`.
        if origin == "utc":
            zone = pytz.UTC
        elif origin == "local":
            zone = datetime.now().astimezone().tzinfo
        else:
            raise ValueError(f"Invalid timezone origin: {origin}")
        parsed = parsed.dt.tz_localize(zone)
    assert parsed.dt.tz is not None

    # From here on every value is expressed in UTC.
    parsed = parsed.dt.tz_convert("UTC")

    if parsed.isna().any():
        raise ValueError(f"Invalid date format in the data: {datetime_col}")
    if (parsed <= _MIN_TS).any():
        raise ValueError(f"Invalid date range in the data: {datetime_col}")
    return parsed