Coverage for mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py: 100%

70 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for loading environment config examples.""" 

6import logging 

7from typing import List 

8 

9import pytest 

10 

11from mlos_bench.config.schemas.config_schemas import ConfigSchema 

12from mlos_bench.environments.base_environment import Environment 

13from mlos_bench.environments.composite_env import CompositeEnv 

14from mlos_bench.services.config_persistence import ConfigPersistenceService 

15from mlos_bench.tests.config import locate_config_examples 

16from mlos_bench.tunables.tunable_groups import TunableGroups 

17 

_LOG = logging.getLogger(__name__)
_LOG.setLevel(logging.DEBUG)


# Kind of config examples collected by this module (passed to ``locate_config_examples``).
CONFIG_TYPE = "environments"

24 

25 

def filter_configs(configs_to_filter: List[str]) -> List[str]:
    """Drop tunable-definition files from a list of candidate config paths.

    Files whose path ends in ``-tunables.jsonc`` are not meant for the
    environment loading tests in this module, so they are excluded.

    Parameters
    ----------
    configs_to_filter
        Candidate config file paths.

    Returns
    -------
    A new list with the remaining paths, in their original order.
    The input list is not modified.
    """
    tunables_suffix = "-tunables.jsonc"
    return list(filter(lambda path: not path.endswith(tunables_suffix), configs_to_filter))

34 

35 

# Collect the built-in environment config examples, minus tunables-only files.
# This runs at module import, so an empty result aborts test collection early.
configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs
)
assert configs

42 

43 

@pytest.mark.parametrize("config_path", configs)
def test_load_environment_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> None:
    """Tests loading an environment config example.

    Each example file must yield at least one environment, and every object
    loaded from it must be an :py:class:`Environment` instance.
    """
    envs = load_environment_config_examples(config_loader_service, config_path)
    # Without this check the loop below would pass vacuously on an empty result.
    assert envs, f"No environments were loaded from {config_path}"
    for env in envs:
        assert env is not None
        assert isinstance(env, Environment)

54 

55 

# Mock service configs registered before loading an example, so that the
# services the example environments use are available.
_MOCK_SERVICE_CONFIGS = (
    "services/local/mock/mock_local_exec_service.jsonc",
    "services/remote/mock/mock_fileshare_service.jsonc",
    "services/remote/mock/mock_network_service.jsonc",
    "services/remote/mock/mock_vm_service.jsonc",
    "services/remote/mock/mock_remote_exec_service.jsonc",
    "services/remote/mock/mock_auth_service.jsonc",
)


def _register_mock_services(config_loader_service: ConfigPersistenceService) -> None:
    """Build each mock service (in order) and register its exports with the loader."""
    for service_config_path in _MOCK_SERVICE_CONFIGS:
        service_config = config_loader_service.load_config(
            service_config_path,
            ConfigSchema.SERVICE,
        )
        mock_service = config_loader_service.build_service(
            config=service_config,
            parent=config_loader_service,
        )
        config_loader_service.register(mock_service.export())


def load_environment_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> List[Environment]:
    """Load the environment(s) defined in ``config_path`` with mock services available.

    Parameters
    ----------
    config_loader_service
        Loader used both to read configs and as the parent/host of the mock services.
    config_path
        Path of the environment config example to load.

    Returns
    -------
    The list of environments produced by ``load_environment_list``.
    """
    # Globals from the test experiment config satisfy any "required_args".
    global_config = config_loader_service.load_config(
        "experiments/experiment_test_config.jsonc",
        ConfigSchema.GLOBALS,
    )
    # The Launcher normally supplies the trial id; fake one if absent.
    global_config.setdefault("trial_id", 1)

    _register_mock_services(config_loader_service)

    # A fresh, empty TunableGroups is the base that the loaded envs build upon.
    return config_loader_service.load_environment_list(
        config_path,
        TunableGroups(),
        global_config,
        service=config_loader_service,
    )

99 

100 

# Root environment config examples; each is expected to load as one CompositeEnv.
# Evaluated at import time, so an empty result aborts test collection early.
composite_configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/"
)
assert composite_configs

106 

107 

@pytest.mark.parametrize("config_path", composite_configs)
def test_load_composite_env_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> None:
    """Tests loading a composite env config example.

    The example must load as exactly one :py:class:`CompositeEnv`. For each of
    its children, every child tunable (and its tunable group) must be the very
    same object held by the composite env. Changing a value through the child
    must therefore be visible through the composite env.
    """
    envs = load_environment_config_examples(config_loader_service, config_path)
    assert len(envs) == 1
    composite_env = envs[0]
    assert isinstance(composite_env, CompositeEnv)

    for child_env in composite_env.children:
        assert child_env is not None
        assert isinstance(child_env, Environment)
        assert child_env.tunable_params is not None

        # Names of this child's groups already checked for identity with the
        # composite env's groups, so each group is compared only once.
        checked_child_env_groups = set()
        for child_tunable, child_group in child_env.tunable_params:
            # Look up that tunable in the composite env.
            assert child_tunable in composite_env.tunable_params
            (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable(
                child_tunable
            )
            # Check that the tunables are the same object.
            assert child_tunable is composite_tunable
            if child_group.name not in checked_child_env_groups:
                assert child_group is composite_group
                checked_child_env_groups.add(child_group.name)

            # Check that when we change a child env's tunable, its value is
            # reflected in the composite env as well.
            # That is to say, they refer to the same objects, despite having
            # potentially been loaded from separate configs.
            if child_tunable.is_categorical:
                old_cat_value = child_tunable.category
                assert child_tunable.value == old_cat_value
                assert child_group[child_tunable] == old_cat_value
                assert composite_env.tunable_params[child_tunable] == old_cat_value
                other_cat_values = [x for x in child_tunable.categories if x != old_cat_value]
                if not other_cat_values:
                    # Single-category tunable: there is no other value to switch
                    # to (indexing [0] would raise IndexError); object identity
                    # was already asserted above.
                    continue
                new_cat_value = other_cat_values[0]
                child_tunable.category = new_cat_value
                assert child_env.tunable_params[child_tunable] == new_cat_value
                assert composite_env.tunable_params[child_tunable] == child_tunable.category
            elif child_tunable.is_numerical:
                old_num_value = child_tunable.numerical_value
                assert child_tunable.value == old_num_value
                assert child_group[child_tunable] == old_num_value
                assert composite_env.tunable_params[child_tunable] == old_num_value
                # NOTE(review): assumes old_num_value + 1 is an acceptable value for
                # this tunable (e.g. within its range, if assignments are
                # range-checked) -- confirm when adding new example configs.
                child_tunable.numerical_value += 1
                assert child_env.tunable_params[child_tunable] == old_num_value + 1
                assert composite_env.tunable_params[child_tunable] == child_tunable.numerical_value