Coverage for mlos_bench/mlos_bench/util.py: 89%
121 statements
coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Various helper functions for mlos_bench."""

# NOTE: This has to be placed in the top-level mlos_bench package to avoid circular imports.

import importlib
import json
import logging
import os
import subprocess
from collections.abc import Callable, Iterable, Mapping
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union

import pandas
import pytz

_LOG = logging.getLogger(__name__)

if TYPE_CHECKING:
    from mlos_bench.environments.base_environment import Environment
    from mlos_bench.optimizers.base_optimizer import Optimizer
    from mlos_bench.schedulers.base_scheduler import Scheduler
    from mlos_bench.services.base_service import Service
    from mlos_bench.storage.base_storage import Storage

BaseTypeVar = TypeVar("BaseTypeVar", "Environment", "Optimizer", "Scheduler", "Service", "Storage")
"""BaseTypeVar is a generic with a constraint of the main base classes (e.g.,
:py:class:`~mlos_bench.environments.base_environment.Environment`,
:py:class:`~mlos_bench.optimizers.base_optimizer.Optimizer`,
:py:class:`~mlos_bench.schedulers.base_scheduler.Scheduler`,
:py:class:`~mlos_bench.services.base_service.Service`,
:py:class:`~mlos_bench.storage.base_storage.Storage`, etc.).
"""

BaseTypes = Union[  # pylint: disable=consider-alternative-union-syntax
    "Environment", "Optimizer", "Scheduler", "Service", "Storage"
]
"""Similar to :py:data:`.BaseTypeVar`, BaseTypes is a Union of the main base classes."""
# Adjusted from https://github.com/python/cpython/blob/v3.11.10/Lib/distutils/util.py#L308
# See Also: https://github.com/microsoft/MLOS/issues/865
def strtobool(val: str) -> bool:
    """
    Convert a string representation of truth to ``True`` or ``False``.

    Parameters
    ----------
    val : str
        True values are 'y', 'yes', 't', 'true', 'on', and '1';
        False values are 'n', 'no', 'f', 'false', 'off', and '0'.

    Raises
    ------
    ValueError
        If 'val' is anything else.
    """
    val = val.lower()
    if val in {"y", "yes", "t", "true", "on", "1"}:
        return True
    elif val in {"n", "no", "f", "false", "off", "0"}:
        return False
    else:
        raise ValueError(f"Invalid Boolean value: '{val}'")
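# Example usage (illustrative, based on the value sets above):
#
#     strtobool("Yes")    # -> True
#     strtobool("off")    # -> False
#     strtobool("maybe")  # raises ValueError("Invalid Boolean value: 'maybe'")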
def preprocess_dynamic_configs(*, dest: dict, source: dict | None = None) -> dict:
    """
    Replaces all ``$name`` values in the destination config with the corresponding value
    from the source config.

    Parameters
    ----------
    dest : dict
        Destination config.
    source : dict | None
        Source config.

    Returns
    -------
    dest : dict
        A reference to the destination config after the preprocessing.
    """
    if source is None:
        source = {}
    for key, val in dest.items():
        if isinstance(val, str) and val.startswith("$") and val[1:] in source:
            dest[key] = source[val[1:]]
    return dest
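# Example (illustrative; the "vmName" key is a hypothetical placeholder):
#
#     preprocess_dynamic_configs(
#         dest={"vm_name": "$vmName", "vm_size": "Standard_B2s"},
#         source={"vmName": "my-vm"},
#     )
#     # -> {"vm_name": "my-vm", "vm_size": "Standard_B2s"}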
def merge_parameters(
    *,
    dest: dict,
    source: dict | None = None,
    required_keys: Iterable[str] | None = None,
) -> dict:
    """
    Merge the source config dict into the destination config. Pick from the source
    configs *ONLY* the keys that are already present in the destination config.

    Parameters
    ----------
    dest : dict
        Destination config.
    source : dict | None
        Source config.
    required_keys : Iterable[str] | None
        An optional collection of keys that must be present in the destination config
        after the merge (copied from ``source`` if missing from ``dest``).

    Returns
    -------
    dest : dict
        A reference to the destination config after the merge.
    """
    if source is None:
        source = {}

    for key in set(dest).intersection(source):
        dest[key] = source[key]

    for key in required_keys or []:
        if key in dest:
            continue
        if key in source:
            dest[key] = source[key]
        else:
            raise ValueError("Missing required parameter: " + key)

    return dest
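# Example (illustrative):
#
#     merge_parameters(
#         dest={"a": 1, "b": 2},
#         source={"b": 20, "c": 30},
#         required_keys=["c"],
#     )
#     # -> {"a": 1, "b": 20, "c": 30}
#     # "b" is overwritten because it already exists in dest;
#     # "c" is pulled in only because it is listed in required_keys.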
def path_join(*args: str, abs_path: bool = False) -> str:
    """
    Joins the path components and normalizes the path.

    Parameters
    ----------
    args : str
        Path components.

    abs_path : bool
        If True, the path is converted to be absolute.

    Returns
    -------
    str
        Joined path.
    """
    path = os.path.join(*args)
    if abs_path:
        path = os.path.abspath(path)
    return os.path.normpath(path).replace("\\", "/")
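# Example (illustrative; file names are placeholders): the result always uses
# forward slashes, even on Windows.
#
#     path_join("configs", "envs", "..", "services", "local.jsonc")
#     # -> "configs/services/local.jsonc"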
def prepare_class_load(
    config: dict,
    global_config: dict[str, Any] | None = None,
) -> tuple[str, dict[str, Any]]:
    """
    Extract the class instantiation parameters from the configuration.

    Parameters
    ----------
    config : dict
        Configuration of the class to instantiate (e.g., an optimizer).
    global_config : dict
        Global configuration parameters (optional).

    Returns
    -------
    (class_name, class_config) : (str, dict)
        Name of the class to instantiate and its configuration.
    """
    class_name = config["class"]
    class_config = config.setdefault("config", {})

    merge_parameters(dest=class_config, source=global_config)

    if _LOG.isEnabledFor(logging.DEBUG):
        _LOG.debug(
            "Instantiating: %s with config:\n%s", class_name, json.dumps(class_config, indent=2)
        )

    return (class_name, class_config)
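# Example (illustrative; the "trial_id" key is a hypothetical placeholder):
#
#     prepare_class_load(
#         {
#             "class": "mlos_bench.environments.remote.HostEnv",
#             "config": {"trial_id": 0},
#         },
#         global_config={"trial_id": 42, "unrelated": "ignored"},
#     )
#     # -> ("mlos_bench.environments.remote.HostEnv", {"trial_id": 42})
#     # Only keys already present in the class config are merged from global_config.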
def get_class_from_name(class_name: str) -> type:
    """
    Gets the class from the fully qualified name.

    Parameters
    ----------
    class_name : str
        Fully qualified class name.

    Returns
    -------
    type
        Class object.
    """
    # We need to import mlos_bench to make the factory methods work.
    class_name_split = class_name.split(".")
    module_name = ".".join(class_name_split[:-1])
    class_id = class_name_split[-1]

    module = importlib.import_module(module_name)
    cls = getattr(module, class_id)
    assert isinstance(cls, type)
    return cls
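# Example (illustrative, using a stdlib class):
#
#     get_class_from_name("datetime.datetime")
#     # -> <class 'datetime.datetime'>
#     # i.e., imports the "datetime" module and returns its "datetime" attribute.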
# FIXME: Technically, this should return a type "class_name" derived from "base_class".
def instantiate_from_config(
    base_class: type[BaseTypeVar],
    class_name: str,
    *args: Any,
    **kwargs: Any,
) -> BaseTypeVar:
    """
    Factory method for a new class instantiated from config.

    Parameters
    ----------
    base_class : type
        Base type of the class to instantiate.
        Currently it's one of {Environment, Optimizer, Scheduler, Service, Storage}.
    class_name : str
        FQN of a Python class to instantiate, e.g.,
        "mlos_bench.environments.remote.HostEnv".
        Must be derived from the `base_class`.
    args : list
        Positional arguments to pass to the constructor.
    kwargs : dict
        Keyword arguments to pass to the constructor.

    Returns
    -------
    inst : Union[Environment, Optimizer, Scheduler, Service, Storage]
        An instance of the `class_name` class.
    """
    impl = get_class_from_name(class_name)
    _LOG.info("Instantiating: %s :: %s", class_name, impl)

    assert issubclass(impl, base_class)
    ret: BaseTypeVar = impl(*args, **kwargs)
    assert isinstance(ret, base_class)
    return ret
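# Illustrative sketch (constructor arguments depend on the concrete class and are
# only placeholders here):
#
#     env = instantiate_from_config(
#         Environment,                               # base class to type-check against
#         "mlos_bench.environments.remote.HostEnv",  # FQN of the concrete class
#         *env_args,                                 # positional args for its constructor
#         **env_kwargs,                              # keyword args for its constructor
#     )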
def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None:
    """
    Check if all required parameters are present in the configuration. Raise ValueError
    if any of the parameters are missing.

    Parameters
    ----------
    config : dict
        Free-format dictionary with the configuration
        of the service or benchmarking environment.
    required_params : Iterable[str]
        A collection of identifiers of the parameters that must be present
        in the configuration.
    """
    missing_params = set(required_params).difference(config)
    if missing_params:
        raise ValueError(
            "The following parameters must be provided in the configuration"
            + f" or as command line arguments: {missing_params}"
        )
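# Example (illustrative; parameter names are hypothetical):
#
#     check_required_params({"host": "vm1"}, ["host", "port"])
#     # raises ValueError mentioning the missing {'port'} parameter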
def get_git_info(path: str = __file__) -> tuple[str, str, str]:
    """
    Get the git repository, commit hash, and local path of the given file.

    Parameters
    ----------
    path : str
        Path to the file in git repository.

    Returns
    -------
    (git_repo, git_commit, git_path) : tuple[str, str, str]
        Git repository URL, last commit hash, and relative file path.
    """
    dirname = os.path.dirname(path)
    git_repo = subprocess.check_output(
        ["git", "-C", dirname, "remote", "get-url", "origin"], text=True
    ).strip()
    git_commit = subprocess.check_output(
        ["git", "-C", dirname, "rev-parse", "HEAD"], text=True
    ).strip()
    git_root = subprocess.check_output(
        ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True
    ).strip()
    _LOG.debug("Current git branch: %s %s", git_repo, git_commit)
    rel_path = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root))
    return (git_repo, git_commit, rel_path.replace("\\", "/"))
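# Example (illustrative): shells out to `git`, so the file must live inside a clone
# with an "origin" remote, or subprocess.CalledProcessError is raised.
#
#     (git_repo, git_commit, rel_path) = get_git_info(__file__)
#     # -> ("<origin remote URL>", "<commit sha>", "mlos_bench/mlos_bench/util.py")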
# Note: to avoid circular imports, we don't specify TunableValue here.
def try_parse_val(val: str | None) -> int | float | str | None:
    """
    Try to parse the value as an int or float, otherwise return the string.

    This can help with config schema validation to make sure early on that
    the args we're expecting are the right type.

    Parameters
    ----------
    val : str
        The initial cmd line arg value.

    Returns
    -------
    TunableValue
        The parsed value.
    """
    if val is None:
        return val
    try:
        val_float = float(val)
        try:
            val_int = int(val)
            return val_int if val_int == val_float else val_float
        except (ValueError, OverflowError):
            return val_float
    except ValueError:
        return str(val)
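# Examples (illustrative):
#
#     try_parse_val("10")    # -> 10     (int)
#     try_parse_val("1.5")   # -> 1.5    (float)
#     try_parse_val("true")  # -> "true" (left as a string; see strtobool above)
#     try_parse_val(None)    # -> None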
NullableT = TypeVar("NullableT")
"""A generic type variable for :py:func:`nullable` return types."""


def nullable(func: Callable[..., NullableT], value: Any | None) -> NullableT | None:
    """
    Poor man's Maybe monad: apply the function to the value if it's not None.

    Parameters
    ----------
    func : Callable
        Function to apply to the value.
    value : Any | None
        Value to apply the function to.

    Returns
    -------
    value : NullableT | None
        The result of the function application or None if the value is None.

    Examples
    --------
    >>> nullable(int, "1")
    1
    >>> nullable(int, None)
    ...
    >>> nullable(str, 1)
    '1'
    """
    return None if value is None else func(value)
def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime:
    """
    Augment a timestamp with zoneinfo if missing and convert it to UTC.

    Parameters
    ----------
    timestamp : datetime.datetime
        A timestamp to convert to UTC.
        Note: The original datetime may or may not have tzinfo associated with it.

    origin : Literal["utc", "local"]
        Whether the source timestamp is considered to be in UTC or local time.
        In the case of loading data from storage, where we intentionally convert all
        timestamps to UTC, this can help us retrieve the original timezone when the
        storage backend doesn't explicitly store it.
        In the case of receiving data from a client or other source, this can help us
        convert the timestamp to UTC if it's not already.

    Returns
    -------
    datetime.datetime
        A datetime with zoneinfo in UTC.
    """
    if timestamp.tzinfo is not None or origin == "local":
        # A timestamp with no zoneinfo is interpreted as "local" time
        # (e.g., according to the TZ environment variable).
        # That could be UTC or some other timezone, but either way we convert it to
        # be explicitly UTC with zone info.
        return timestamp.astimezone(pytz.UTC)
    elif origin == "utc":
        # If the timestamp is already in UTC, we just add the zoneinfo without conversion.
        # Converting with astimezone() when the local time is *not* UTC would cause
        # a timestamp conversion which we don't want.
        return timestamp.replace(tzinfo=pytz.UTC)
    else:
        raise ValueError(f"Invalid origin: {origin}")
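# Examples (illustrative):
#
#     naive = datetime(2025, 1, 21, 1, 50)
#     utcify_timestamp(naive, origin="utc")
#     # -> datetime(2025, 1, 21, 1, 50, tzinfo=UTC)  (tzinfo attached, no conversion)
#     utcify_timestamp(naive, origin="local")
#     # -> the same wall-clock time converted from the local timezone to UTC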
def utcify_nullable_timestamp(
    timestamp: datetime | None,
    *,
    origin: Literal["utc", "local"],
) -> datetime | None:
    """A nullable version of utcify_timestamp."""
    return utcify_timestamp(timestamp, origin=origin) if timestamp is not None else None
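# Example (illustrative):
#
#     utcify_nullable_timestamp(None, origin="utc")  # -> None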
# All timestamps in the telemetry data must be greater than this date
# (a very rough approximation for the start of this feature).
_MIN_TS = datetime(2024, 1, 1, 0, 0, 0, tzinfo=pytz.UTC)
def datetime_parser(
    datetime_col: pandas.Series,
    *,
    origin: Literal["utc", "local"],
) -> pandas.Series:
    """
    Attempt to convert a pandas column to a datetime format.

    Parameters
    ----------
    datetime_col : pandas.Series
        The column to convert.

    origin : Literal["utc", "local"]
        Whether to interpret naive timestamps as originating from UTC or local time.

    Returns
    -------
    pandas.Series
        The converted datetime column.

    Raises
    ------
    ValueError
        On parse errors.
    """
    new_datetime_col = pandas.to_datetime(datetime_col, utc=False)
    # If timezone data is missing, assume the provided origin timezone.
    if new_datetime_col.dt.tz is None:
        if origin == "local":
            tzinfo = datetime.now().astimezone().tzinfo
        elif origin == "utc":
            tzinfo = pytz.UTC
        else:
            raise ValueError(f"Invalid timezone origin: {origin}")
        new_datetime_col = new_datetime_col.dt.tz_localize(tzinfo)
    assert new_datetime_col.dt.tz is not None
    # And convert it to UTC.
    new_datetime_col = new_datetime_col.dt.tz_convert("UTC")
    if new_datetime_col.isna().any():
        raise ValueError(f"Invalid date format in the data: {datetime_col}")
    if new_datetime_col.le(_MIN_TS).any():
        raise ValueError(f"Invalid date range in the data: {datetime_col}")
    return new_datetime_col
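# Example (illustrative): naive strings are localized to the origin timezone and then
# converted to UTC; values at or before _MIN_TS (2024-01-01 UTC) are rejected.
#
#     datetime_parser(pandas.Series(["2025-01-21 01:50:00"]), origin="utc")
#     # -> 0   2025-01-21 01:50:00+00:00
#     #    dtype: datetime64[ns, UTC]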