Coverage for mlos_bench/mlos_bench/util.py: 89%
109 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Various helper functions for mlos_bench."""
7# NOTE: This has to be placed in the top-level mlos_bench package to avoid circular imports.
9import importlib
10import json
11import logging
12import os
13import subprocess
14from datetime import datetime
15from typing import (
16 TYPE_CHECKING,
17 Any,
18 Callable,
19 Dict,
20 Iterable,
21 Literal,
22 Mapping,
23 Optional,
24 Tuple,
25 Type,
26 TypeVar,
27 Union,
28)
30import pandas
31import pytz
# Module-level logger, named after this module so log output can be filtered per package.
_LOG = logging.getLogger(__name__)
35if TYPE_CHECKING:
36 from mlos_bench.environments.base_environment import Environment
37 from mlos_bench.optimizers.base_optimizer import Optimizer
38 from mlos_bench.schedulers.base_scheduler import Scheduler
39 from mlos_bench.services.base_service import Service
40 from mlos_bench.storage.base_storage import Storage
# BaseTypeVar is a TypeVar constrained to the five framework base classes;
# it lets factory helpers (e.g., instantiate_from_config) preserve the caller's
# concrete base type in their return annotation.
BaseTypeVar = TypeVar("BaseTypeVar", "Environment", "Optimizer", "Scheduler", "Service", "Storage")
# Union form of the same set of base classes, for annotations that accept any one of them.
BaseTypes = Union["Environment", "Optimizer", "Scheduler", "Service", "Storage"]
def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> dict:
    """
    Replaces all $name values in the destination config with the corresponding value
    from the source config.

    Parameters
    ----------
    dest : dict
        Destination config.
    source : Optional[dict]
        Source config.

    Returns
    -------
    dest : dict
        A reference to the destination config after the preprocessing.
    """
    lookup = {} if source is None else source
    for name, value in dest.items():
        # Only string values of the form "$ref" are candidates for substitution.
        if not isinstance(value, str) or not value.startswith("$"):
            continue
        ref = value[1:]
        if ref in lookup:
            dest[name] = lookup[ref]
    return dest
def merge_parameters(
    *,
    dest: dict,
    source: Optional[dict] = None,
    required_keys: Optional[Iterable[str]] = None,
) -> dict:
    """
    Merge the source config dict into the destination config. Pick from the source
    configs *ONLY* the keys that are already present in the destination config.

    Parameters
    ----------
    dest : dict
        Destination config.
    source : Optional[dict]
        Source config.
    required_keys : Optional[Iterable[str]]
        An optional list of keys that must be present in the destination config.

    Returns
    -------
    dest : dict
        A reference to the destination config after the merge.
    """
    src = {} if source is None else source

    # Overwrite only the keys that the destination already defines.
    for key in dest.keys() & src.keys():
        dest[key] = src[key]

    # Required keys may additionally be pulled in from the source.
    for key in required_keys or []:
        if key not in dest:
            if key not in src:
                raise ValueError("Missing required parameter: " + key)
            dest[key] = src[key]

    return dest
def path_join(*args: str, abs_path: bool = False) -> str:
    """
    Joins the path components and normalizes the path.

    Parameters
    ----------
    args : str
        Path components.

    abs_path : bool
        If True, the path is converted to be absolute.

    Returns
    -------
    str
        Joined path.
    """
    joined = os.path.join(*args)
    if abs_path:
        joined = os.path.abspath(joined)
    # Normalize and always use forward slashes, regardless of the host OS.
    return os.path.normpath(joined).replace("\\", "/")
def prepare_class_load(
    config: dict,
    global_config: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any]]:
    """
    Extract the class instantiation parameters from the configuration.

    Parameters
    ----------
    config : dict
        Configuration of the optimizer.
    global_config : dict
        Global configuration parameters (optional).

    Returns
    -------
    (class_name, class_config) : (str, dict)
        Name of the class to instantiate and its configuration.
    """
    name = config["class"]
    # Ensure the "config" sub-dict exists so the merge below mutates it in place.
    params = config.setdefault("config", {})

    merge_parameters(dest=params, source=global_config)

    if _LOG.isEnabledFor(logging.DEBUG):
        _LOG.debug(
            "Instantiating: %s with config:\n%s", name, json.dumps(params, indent=2)
        )

    return (name, params)
def get_class_from_name(class_name: str) -> type:
    """
    Gets the class from the fully qualified name.

    Parameters
    ----------
    class_name : str
        Fully qualified class name.

    Returns
    -------
    type
        Class object.

    Raises
    ------
    ModuleNotFoundError
        If the module portion of `class_name` cannot be imported.
    AttributeError
        If the module does not define the named attribute.
    TypeError
        If the named attribute exists but is not a class.
    """
    # We need to import mlos_bench to make the factory methods work.
    (module_name, _, class_id) = class_name.rpartition(".")

    module = importlib.import_module(module_name)
    cls = getattr(module, class_id)
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # and a non-class attribute here is a configuration error, not an internal bug.
    if not isinstance(cls, type):
        raise TypeError(f"{class_name} is not a class: {cls}")
    return cls
# FIXME: Technically, this should return a type "class_name" derived from "base_class".
def instantiate_from_config(
    base_class: Type[BaseTypeVar],
    class_name: str,
    *args: Any,
    **kwargs: Any,
) -> BaseTypeVar:
    """
    Factory method for a new class instantiated from config.

    Parameters
    ----------
    base_class : type
        Base type of the class to instantiate.
        Currently it's one of {Environment, Service, Optimizer}.
    class_name : str
        FQN of a Python class to instantiate, e.g.,
        "mlos_bench.environments.remote.HostEnv".
        Must be derived from the `base_class`.
    args : list
        Positional arguments to pass to the constructor.
    kwargs : dict
        Keyword arguments to pass to the constructor.

    Returns
    -------
    inst : Union[Environment, Service, Optimizer, Storage]
        An instance of the `class_name` class.

    Raises
    ------
    TypeError
        If `class_name` does not resolve to a subclass of `base_class`,
        or the constructed object is not an instance of `base_class`.
    """
    impl = get_class_from_name(class_name)
    _LOG.info("Instantiating: %s :: %s", class_name, impl)

    # Explicit validation instead of `assert` (stripped under `python -O`):
    # a mismatched class in the user's config must always be caught.
    if not issubclass(impl, base_class):
        raise TypeError(f"Class {class_name} is not derived from {base_class}")
    ret: BaseTypeVar = impl(*args, **kwargs)
    # Guard against exotic metaclasses whose __call__ returns something else.
    if not isinstance(ret, base_class):
        raise TypeError(f"Instantiated object is not a {base_class}: {ret}")
    return ret
def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None:
    """
    Check if all required parameters are present in the configuration. Raise ValueError
    if any of the parameters are missing.

    Parameters
    ----------
    config : dict
        Free-format dictionary with the configuration
        of the service or benchmarking environment.
    required_params : Iterable[str]
        A collection of identifiers of the parameters that must be present
        in the configuration.
    """
    missing_params = {param for param in required_params if param not in config}
    if missing_params:
        raise ValueError(
            "The following parameters must be provided in the configuration"
            + f" or as command line arguments: {missing_params}"
        )
def get_git_info(path: str = __file__) -> Tuple[str, str, str]:
    """
    Get the git repository, commit hash, and local path of the given file.

    Parameters
    ----------
    path : str
        Path to the file in git repository.

    Returns
    -------
    (git_repo, git_commit, git_path) : Tuple[str, str, str]
        Git repository URL, last commit hash, and relative file path.
    """
    dirname = os.path.dirname(path)

    def _git_output(*cmd: str) -> str:
        # Run a git command rooted at the file's directory; strip the trailing newline.
        return subprocess.check_output(["git", "-C", dirname, *cmd], text=True).strip()

    git_repo = _git_output("remote", "get-url", "origin")
    git_commit = _git_output("rev-parse", "HEAD")
    git_root = _git_output("rev-parse", "--show-toplevel")
    _LOG.debug("Current git branch: %s %s", git_repo, git_commit)
    rel_path = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root))
    # Always report the path with forward slashes, even on Windows.
    return (git_repo, git_commit, rel_path.replace("\\", "/"))
# Note: to avoid circular imports, we don't specify TunableValue here.
def try_parse_val(val: Optional[str]) -> Optional[Union[int, float, str]]:
    """
    Try to parse the value as an int or float, otherwise return the string.

    This can help with config schema validation to make sure early on that
    the args we're expecting are the right type.

    Parameters
    ----------
    val : str
        The initial cmd line arg value.

    Returns
    -------
    TunableValue
        The parsed value.
    """
    if val is None:
        return None
    try:
        as_float = float(val)
    except ValueError:
        # Not numeric at all: pass it through as a string.
        return str(val)
    try:
        as_int = int(val)
    except (ValueError, OverflowError):
        return as_float
    # Prefer the int form only when it represents the exact same value.
    return as_int if as_int == as_float else as_float
def nullable(func: Callable, value: Optional[Any]) -> Optional[Any]:
    """
    Poor man's Maybe monad: apply the function to the value if it's not None.

    Parameters
    ----------
    func : Callable
        Function to apply to the value.
    value : Optional[Any]
        Value to apply the function to.

    Returns
    -------
    value : Optional[Any]
        The result of the function application or None if the value is None.
    """
    if value is None:
        return None
    return func(value)
def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime:
    """
    Augment a timestamp with zoneinfo if missing and convert it to UTC.

    Parameters
    ----------
    timestamp : datetime
        A timestamp to convert to UTC.
        Note: The original datetime may or may not have tzinfo associated with it.

    origin : Literal["utc", "local"]
        Whether the source timestamp is considered to be in UTC or local time.
        In the case of loading data from storage, where we intentionally convert all
        timestamps to UTC, this can help us retrieve the original timezone when the
        storage backend doesn't explicitly store it.
        In the case of receiving data from a client or other source, this can help us
        convert the timestamp to UTC if it's not already.

    Returns
    -------
    datetime
        A datetime with zoneinfo in UTC.
    """
    if timestamp.tzinfo is None and origin == "utc":
        # A naive timestamp declared as UTC: attach the zoneinfo without conversion.
        # Converting with astimezone() when the local time is *not* UTC would cause
        # a timestamp shift which we don't want.
        return timestamp.replace(tzinfo=pytz.UTC)
    if timestamp.tzinfo is not None or origin == "local":
        # A timestamp with no zoneinfo is interpreted as "local" time
        # (e.g., according to the TZ environment variable).
        # That could be UTC or some other timezone, but either way we convert it to
        # be explicitly UTC with zone info.
        return timestamp.astimezone(pytz.UTC)
    raise ValueError(f"Invalid origin: {origin}")
def utcify_nullable_timestamp(
    timestamp: Optional[datetime],
    *,
    origin: Literal["utc", "local"],
) -> Optional[datetime]:
    """A nullable version of utcify_timestamp."""
    if timestamp is None:
        return None
    return utcify_timestamp(timestamp, origin=origin)
# All timestamps in the telemetry data must be greater than this date
# (a very rough approximation for the start of this feature).
_MIN_TS: datetime = datetime(2024, 1, 1, 0, 0, 0, tzinfo=pytz.UTC)
def datetime_parser(
    datetime_col: pandas.Series,
    *,
    origin: Literal["utc", "local"],
) -> pandas.Series:
    """
    Attempt to convert a pandas column to a datetime format.

    Parameters
    ----------
    datetime_col : pandas.Series
        The column to convert.

    origin : Literal["utc", "local"]
        Whether to interpret naive timestamps as originating from UTC or local time.

    Returns
    -------
    pandas.Series
        The converted datetime column.

    Raises
    ------
    ValueError
        On parse errors.
    """
    converted = pandas.to_datetime(datetime_col, utc=False)
    if converted.dt.tz is None:
        # Timezone info was missing: localize to the declared origin timezone first.
        if origin == "utc":
            tzinfo = pytz.UTC
        elif origin == "local":
            tzinfo = datetime.now().astimezone().tzinfo
        else:
            raise ValueError(f"Invalid timezone origin: {origin}")
        converted = converted.dt.tz_localize(tzinfo)
    assert converted.dt.tz is not None
    # Normalize everything to UTC.
    converted = converted.dt.tz_convert("UTC")
    if converted.isna().any():
        raise ValueError(f"Invalid date format in the data: {datetime_col}")
    if converted.le(_MIN_TS).any():
        raise ValueError(f"Invalid date range in the data: {datetime_col}")
    return converted