Coverage for mlos_bench/mlos_bench/dict_templater.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Simple class to help with nested dictionary $var templating.""" 

6 

7from copy import deepcopy 

8from string import Template 

9from typing import Any, Dict, Optional 

10 

11from mlos_bench.os_environ import environ 

12 

13 

class DictTemplater:  # pylint: disable=too-few-public-methods
    """Simple class to help with nested dictionary $var templating."""

    def __init__(self, source_dict: Dict[str, Any]):
        """
        Initialize the templater.

        Parameters
        ----------
        source_dict : Dict[str, Any]
            The template dict to use for source variables.
            It is deep-copied here, so later changes to the caller's dict do
            not affect this templater.
        """
        # A copy of the initial data structure we were given with templates intact.
        self._template_dict = deepcopy(source_dict)
        # The source/target dictionary to expand.
        self._dict: Dict[str, Any] = {}

    def expand_vars(
        self,
        *,
        extra_source_dict: Optional[Dict[str, Any]] = None,
        use_os_env: bool = False,
    ) -> Dict[str, Any]:
        """
        Expand the template variables in the destination dictionary.

        Every call starts again from a fresh deep copy of the original template,
        so repeated calls do not build on each other's results.

        For each string value, ``$var`` references are looked up in this order:

        1. the (partially expanded) dictionary itself,
        2. ``extra_source_dict`` (if given and non-empty),
        3. the OS environment (if ``use_os_env`` is True).

        References that cannot be resolved from any source are left in place
        unchanged (``string.Template.safe_substitute`` semantics).

        Parameters
        ----------
        extra_source_dict : Optional[Dict[str, Any]]
            An optional extra source dictionary to use for expansion.
        use_os_env : bool
            Whether to use the OS environment variables as a final fallback
            for expansion.

        Returns
        -------
        Dict[str, Any]
            The expanded dictionary.

        Raises
        ------
        ValueError
            On unsupported nested value types (anything other than str, dict,
            list, int, float, bool, or None).
        """
        self._dict = deepcopy(self._template_dict)
        self._dict = self._expand_vars(self._dict, extra_source_dict, use_os_env)
        assert isinstance(self._dict, dict)
        return self._dict

    def _expand_vars(
        self,
        value: Any,
        extra_source_dict: Optional[Dict[str, Any]],
        use_os_env: bool,
    ) -> Any:
        """
        Recursively expand $var strings in the currently operating dictionary.

        Parameters
        ----------
        value : Any
            The value to expand: a str, dict, list, int, float, bool, or None.
        extra_source_dict : Optional[Dict[str, Any]]
            Optional secondary source for substitutions.
        use_os_env : bool
            Whether to fall back to the OS environment for substitutions.

        Returns
        -------
        Any
            The expanded value. Dicts are updated in place and returned;
            lists are rebuilt; scalars are returned unchanged.

        Raises
        ------
        ValueError
            On unsupported value types.
        """
        if isinstance(value, str):
            # NOTE(review): each source below is a separate substitution pass, so a
            # `$$` escape only survives the first pass -- the resulting `$name` can
            # be expanded by a later pass. Kept as-is to preserve chained expansion.
            # First try to expand all $vars internally.
            value = Template(value).safe_substitute(self._dict)
            # Next, if there are any left, try to expand them from the extra source dict.
            if extra_source_dict:
                value = Template(value).safe_substitute(extra_source_dict)
            # Finally, fall back to the OS environment.
            if use_os_env:
                value = Template(value).safe_substitute(dict(environ))
        elif isinstance(value, dict):
            # Note: we use a loop instead of dict comprehension in order to
            # allow secondary expansion of subsequent values immediately
            # (later keys see the already-expanded earlier values via self._dict).
            # Reassigning existing keys does not change the dict size, so
            # mutating while iterating is safe here.
            for key, val in value.items():
                value[key] = self._expand_vars(val, extra_source_dict, use_os_env)
        elif isinstance(value, list):
            value = [self._expand_vars(val, extra_source_dict, use_os_env) for val in value]
        elif value is not None and not isinstance(value, (int, float)):
            # bool is a subclass of int, so it is accepted by the check above.
            raise ValueError(f"Unexpected type {type(value)} for value {value}")
        return value