Coverage for mlos_bench/mlos_bench/dict_templater.py: 100%
29 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Simple class to help with nested dictionary $var templating."""
7from copy import deepcopy
8from string import Template
9from typing import Any, Dict, Optional
11from mlos_bench.os_environ import environ
class DictTemplater:  # pylint: disable=too-few-public-methods
    """Simple class to help with nested dictionary $var templating."""

    def __init__(self, source_dict: Dict[str, Any]):
        """
        Initialize the templater.

        Parameters
        ----------
        source_dict : Dict[str, Any]
            The template dict to use for source variables.
            A deep copy is stored, so neither later changes to the caller's
            dict nor later expansions modify the template.
        """
        # A copy of the initial data structure we were given with templates intact.
        self._template_dict = deepcopy(source_dict)
        # The source/target dictionary to expand.
        # It is rebuilt from the template on every expand_vars() call.
        self._dict: Dict[str, Any] = {}

    def expand_vars(
        self,
        *,
        extra_source_dict: Optional[Dict[str, Any]] = None,
        use_os_env: bool = False,
    ) -> Dict[str, Any]:
        """
        Expand the template variables in the destination dictionary.

        Each call starts from a fresh deep copy of the original template, so
        repeated calls with different sources do not affect each other.

        String values are expanded in up to three successive passes, each
        using ``string.Template.safe_substitute``. Unknown ``$vars`` are left
        as-is rather than raising.

        1. The top-level keys of the dictionary being expanded. Keys are
           visited in order, so values of earlier keys are already expanded
           when later keys reference them.
        2. ``extra_source_dict``, if given and non-empty.
        3. The OS environment, if ``use_os_env`` is set.

        Parameters
        ----------
        extra_source_dict : Optional[Dict[str, Any]]
            An optional extra source dictionary to use for expansion.
        use_os_env : bool
            Whether to use the os environment variables as a final fallback for expansion.

        Returns
        -------
        Dict[str, Any]
            The expanded dictionary.

        Raises
        ------
        ValueError
            If a (possibly nested) value is not a str, dict, list, int, float,
            bool, or None.
        """
        self._dict = deepcopy(self._template_dict)
        self._dict = self._expand_vars(self._dict, extra_source_dict, use_os_env)
        # Expanding a dict mutates and returns the same dict, so this always
        # holds for a dict template; it also narrows the type for checkers.
        assert isinstance(self._dict, dict)
        return self._dict

    def _expand_vars(
        self,
        value: Any,
        extra_source_dict: Optional[Dict[str, Any]],
        use_os_env: bool,
    ) -> Any:
        """Recursively expand $var strings in the currently operating dictionary."""
        # Snapshot the environment once per expansion instead of copying it
        # for every single string value encountered during the recursion.
        os_env: Optional[Dict[str, str]] = dict(environ) if use_os_env else None
        return self._expand_value(value, extra_source_dict, os_env)

    def _expand_value(
        self,
        value: Any,
        extra_source_dict: Optional[Dict[str, Any]],
        os_env: Optional[Dict[str, str]],
    ) -> Any:
        """Recursively expand one value; ``os_env`` is None when the env is not used."""
        if isinstance(value, str):
            # NOTE(review): each pass is a separate safe_substitute, so a "$$"
            # escape turned into "$" by an earlier pass can be expanded again
            # by a later pass. Kept as-is to preserve existing behavior.
            # First try to expand all $vars internally.
            value = Template(value).safe_substitute(self._dict)
            # Next, if there are any left, try to expand them from the extra source dict.
            if extra_source_dict:
                value = Template(value).safe_substitute(extra_source_dict)
            # Finally, fallback to the os environment.
            # This uses an explicit None check so an empty environment still
            # gets a pass, matching the original per-string behavior.
            if os_env is not None:
                value = Template(value).safe_substitute(os_env)
        elif isinstance(value, dict):
            # Note: we use a loop instead of dict comprehension in order to
            # allow secondary expansion of subsequent values immediately.
            # Only values are reassigned (no keys added or removed), so
            # mutating while iterating items() is safe.
            for key, val in value.items():
                value[key] = self._expand_value(val, extra_source_dict, os_env)
        elif isinstance(value, list):
            value = [self._expand_value(val, extra_source_dict, os_env) for val in value]
        elif not (isinstance(value, (int, float, bool)) or value is None):
            raise ValueError(f"Unexpected type {type(value)} for value {value}")
        return value