Coverage for mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py: 100%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for loading environment config examples.""" 

6import logging 

7 

8import pytest 

9 

10from mlos_bench.config.schemas.config_schemas import ConfigSchema 

11from mlos_bench.environments.base_environment import Environment 

12from mlos_bench.environments.composite_env import CompositeEnv 

13from mlos_bench.services.config_persistence import ConfigPersistenceService 

14from mlos_bench.tests.config import locate_config_examples 

15from mlos_bench.tunables.tunable_groups import TunableGroups 

16 

_LOG = logging.getLogger(__name__)
_LOG.setLevel(level=logging.DEBUG)  # emit debug output while these tests run


# Config type (a path under the built-in config directory) whose examples are tested below.
CONFIG_TYPE = "environments"

23 

24 

def filter_configs(configs_to_filter: list[str]) -> list[str]:
    """
    Drop config paths that do not belong to the environment tests.

    Parameters
    ----------
    configs_to_filter : list[str]
        Candidate config file paths.

    Returns
    -------
    list[str]
        A new list with every path ending in ``-tunables.jsonc`` removed,
        preserving the order of the remaining paths.
    """
    tunables_suffix = "-tunables.jsonc"
    return [path for path in configs_to_filter if not path.endswith(tunables_suffix)]

33 

34 

# All built-in environment config examples, with the tunables files filtered out.
configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs
)
# Fail at collection time if nothing was found, instead of parametrizing zero cases.
assert configs

41 

42 

@pytest.mark.parametrize("config_path", configs)
def test_load_environment_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> None:
    """
    Tests loading an environment config example.

    Every entry produced from the config must be an ``Environment``, and the
    config must produce at least one environment.
    """
    envs = load_environment_config_examples(config_loader_service, config_path)
    # Without this, an empty result would make the loop below a no-op and the
    # test would pass without checking anything.
    assert envs, f"No environments were loaded from {config_path}"
    for env in envs:
        assert env is not None
        assert isinstance(env, Environment)

53 

54 

def load_environment_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> list[Environment]:
    """
    Load the environment(s) defined in one example config file.

    Before loading, a test globals file is read (with a default ``trial_id``),
    and a fixed set of mock services is built and registered on the loader.
    This satisfies the environments' arguments and service dependencies.

    Parameters
    ----------
    config_loader_service : ConfigPersistenceService
        Loader used to read configs, build the mock services, and load the
        environments. The mock services are registered on it as a side effect.
    config_path : str
        Path of the environment config example to load.

    Returns
    -------
    list[Environment]
        The environments returned by ``load_environment_list``.
    """
    # The globals supply values for any "required_args" the examples declare.
    globals_config = config_loader_service.load_config(
        "experiments/experiment_test_config.jsonc",
        ConfigSchema.GLOBALS,
    )
    # Normally populated by the Launcher; only filled in if the file lacks it.
    globals_config.setdefault("trial_id", 1)

    # Register mock versions of the services the example environments may use.
    for service_config_file in (
        "services/local/mock/mock_local_exec_service.jsonc",
        "services/remote/mock/mock_fileshare_service.jsonc",
        "services/remote/mock/mock_network_service.jsonc",
        "services/remote/mock/mock_vm_service.jsonc",
        "services/remote/mock/mock_remote_exec_service.jsonc",
        "services/remote/mock/mock_auth_service.jsonc",
    ):
        service_config = config_loader_service.load_config(
            service_config_file,
            ConfigSchema.SERVICE,
        )
        mock_service = config_loader_service.build_service(
            config=service_config,
            parent=config_loader_service,
        )
        config_loader_service.register(mock_service.export())

    # A fresh base TunableGroups that the loaded environments build on.
    return config_loader_service.load_environment_list(
        config_path,
        TunableGroups(),
        globals_config,
        service=config_loader_service,
    )

98 

99 

# Root environment config examples; each is expected to load as a single CompositeEnv.
composite_configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/"
)
# Fail at collection time if nothing was found, instead of parametrizing zero cases.
assert composite_configs

105 

106 

@pytest.mark.parametrize("config_path", composite_configs)
def test_load_composite_env_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> None:
    """
    Tests loading a composite env config example.

    The config must load as exactly one ``CompositeEnv``. Every child
    environment's tunables (and their groups) must be the very same objects
    the composite env holds, so a change made through a child is visible
    through the composite. Each tunable is restored to its original value
    after the check. Without that, tunables shared by several children
    would be mutated cumulatively, e.g. incremented once per child.
    """
    envs = load_environment_config_examples(config_loader_service, config_path)
    assert len(envs) == 1
    assert isinstance(envs[0], CompositeEnv)
    composite_env: CompositeEnv = envs[0]

    for child_env in composite_env.children:
        assert child_env is not None
        assert isinstance(child_env, Environment)
        assert child_env.tunable_params is not None

        # Names of the groups already checked for identity with the composite's group.
        checked_child_env_groups: set[str] = set()
        for child_tunable, child_group in child_env.tunable_params:
            # Lookup that tunable in the composite env.
            assert child_tunable in composite_env.tunable_params
            (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable(
                child_tunable
            )
            # Check that the tunables are the same object.
            assert child_tunable is composite_tunable
            if child_group.name not in checked_child_env_groups:
                assert child_group is composite_group
                checked_child_env_groups.add(child_group.name)

            # Check that when we change a child env, its value is reflected in the
            # composite env as well.
            # That is to say, they refer to the same objects, despite having
            # potentially been loaded from separate configs.
            if child_tunable.is_categorical:
                old_cat_value = child_tunable.category
                assert child_tunable.value == old_cat_value
                assert child_group[child_tunable] == old_cat_value
                assert composite_env.tunable_params[child_tunable] == old_cat_value
                other_cat_values = [x for x in child_tunable.categories if x != old_cat_value]
                # Fail with a clear message instead of an IndexError if there is
                # nothing to switch to.
                assert other_cat_values, (
                    f"Categorical tunable {child_tunable} has no other category to switch to"
                )
                new_cat_value = other_cat_values[0]
                child_tunable.category = new_cat_value
                assert child_env.tunable_params[child_tunable] == new_cat_value
                assert composite_env.tunable_params[child_tunable] == child_tunable.category
                # Restore, so a tunable shared with a later child starts from its original value.
                child_tunable.category = old_cat_value
                assert composite_env.tunable_params[child_tunable] == old_cat_value
            elif child_tunable.is_numerical:
                old_num_value = child_tunable.numerical_value
                assert child_tunable.value == old_num_value
                assert child_group[child_tunable] == old_num_value
                assert composite_env.tunable_params[child_tunable] == old_num_value
                child_tunable.numerical_value += 1
                assert child_env.tunable_params[child_tunable] == old_num_value + 1
                assert composite_env.tunable_params[child_tunable] == child_tunable.numerical_value
                # Restore, so repeated visits via shared tunables don't accumulate increments.
                child_tunable.numerical_value = old_num_value
                assert composite_env.tunable_params[child_tunable] == old_num_value