Coverage for mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py: 100%
69 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Tests for loading environment config examples."""
6import logging
8import pytest
10from mlos_bench.config.schemas.config_schemas import ConfigSchema
11from mlos_bench.environments.base_environment import Environment
12from mlos_bench.environments.composite_env import CompositeEnv
13from mlos_bench.services.config_persistence import ConfigPersistenceService
14from mlos_bench.tests.config import locate_config_examples
15from mlos_bench.tunables.tunable_groups import TunableGroups
# Module logger, forced to DEBUG so config-loading details show up in test output.
_LOG = logging.getLogger(__name__)
_LOG.setLevel(logging.DEBUG)


# Name of the example-config category exercised by this module.
CONFIG_TYPE = "environments"
def filter_configs(configs_to_filter: list[str]) -> list[str]:
    """
    Drop config paths that should not be loaded by this test module.

    Any path ending in ``-tunables.jsonc`` is excluded; every other path is
    kept, in its original order. The input list is left untouched and a new
    list is returned.
    """
    excluded_suffix = "-tunables.jsonc"
    kept_paths: list[str] = []
    for candidate in configs_to_filter:
        if candidate.endswith(excluded_suffix):
            continue
        kept_paths.append(candidate)
    return kept_paths
# Discover the example configs of type CONFIG_TYPE under the built-in config
# path, passing filter_configs so "-tunables.jsonc" files are left out.
configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH,
    CONFIG_TYPE,
    filter_configs,
)
# Fail at collection time if discovery found nothing to parametrize over.
assert configs
@pytest.mark.parametrize("config_path", configs)
def test_load_environment_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> None:
    """
    Check that an example environment config loads into Environment objects.

    Runs once per path in ``configs``; every object returned by
    ``load_environment_config_examples`` must be a non-None ``Environment``.
    """
    loaded_envs = load_environment_config_examples(config_loader_service, config_path)
    for loaded_env in loaded_envs:
        assert loaded_env is not None
        assert isinstance(loaded_env, Environment)
def load_environment_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> list[Environment]:
    """
    Load the environments described by an example config file.

    Parameters
    ----------
    config_loader_service : ConfigPersistenceService
        Service used to load configs; the mock services listed below are
        registered on it as a side effect.
    config_path : str
        Path of the environment config to load.

    Returns
    -------
    list[Environment]
        Whatever ``load_environment_list`` produces for ``config_path``.
    """
    # Global parameters, so that any "required_args" of the envs are provided.
    globals_for_test = config_loader_service.load_config(
        "experiments/experiment_test_config.jsonc",
        ConfigSchema.GLOBALS,
    )
    # The Launcher normally populates this; supply a value if it is missing.
    globals_for_test.setdefault("trial_id", 1)

    # Mock services that stand in for the ones the envs being loaded require.
    mock_service_paths = (
        "services/local/mock/mock_local_exec_service.jsonc",
        "services/remote/mock/mock_fileshare_service.jsonc",
        "services/remote/mock/mock_network_service.jsonc",
        "services/remote/mock/mock_vm_service.jsonc",
        "services/remote/mock/mock_remote_exec_service.jsonc",
        "services/remote/mock/mock_auth_service.jsonc",
    )

    # Base tunable groups that all others get built on.
    base_tunables = TunableGroups()

    for service_path in mock_service_paths:
        service_json = config_loader_service.load_config(
            service_path,
            ConfigSchema.SERVICE,
        )
        mock_service = config_loader_service.build_service(
            config=service_json,
            parent=config_loader_service,
        )
        # Expose the mock service's exported methods through the loader service.
        config_loader_service.register(mock_service.export())

    return config_loader_service.load_environment_list(
        config_path,
        base_tunables,
        globals_for_test,
        service=config_loader_service,
    )
# Discover the example configs under "environments/root/" (no filter applied);
# the composite-env test below is parametrized over them.
composite_configs = locate_config_examples(
    ConfigPersistenceService.BUILTIN_CONFIG_PATH,
    "environments/root/",
)
# Fail at collection time if discovery found nothing to parametrize over.
assert composite_configs
@pytest.mark.parametrize("config_path", composite_configs)
def test_load_composite_env_config_examples(
    config_loader_service: ConfigPersistenceService,
    config_path: str,
) -> None:
    """
    Check that a composite env config shares tunables with its children.

    The config must load as exactly one ``CompositeEnv``. For each child env,
    every tunable must be the very same object as the one held by the
    composite env, and a value change made through the child's tunable must
    be visible through the composite env's tunable params.
    """
    loaded_envs = load_environment_config_examples(config_loader_service, config_path)
    assert len(loaded_envs) == 1
    assert isinstance(loaded_envs[0], CompositeEnv)
    root_env: CompositeEnv = loaded_envs[0]

    for child_env in root_env.children:
        assert child_env is not None
        assert isinstance(child_env, Environment)
        assert child_env.tunable_params is not None

        # Group names whose identity was already verified for this child env.
        verified_group_names = set()
        for child_tunable, child_group in child_env.tunable_params:
            # The composite env must know about this tunable.
            assert child_tunable in root_env.tunable_params
            root_tunable, root_group = root_env.tunable_params.get_tunable(child_tunable)
            # Identity, not just equality: both sides hold the same object.
            assert child_tunable is root_tunable
            if child_group.name not in verified_group_names:
                assert child_group is root_group
                verified_group_names.add(child_group.name)

            # Changing a tunable through the child env must be reflected in the
            # composite env as well, i.e. both refer to the same objects even
            # though they may have been loaded from separate configs.
            if child_tunable.is_categorical:
                old_cat_value = child_tunable.category
                assert child_tunable.value == old_cat_value
                assert child_group[child_tunable] == old_cat_value
                assert root_env.tunable_params[child_tunable] == old_cat_value
                # Switch to the first category that differs from the current one.
                other_categories = [
                    category
                    for category in child_tunable.categories
                    if category != old_cat_value
                ]
                new_cat_value = other_categories[0]
                child_tunable.category = new_cat_value
                assert child_env.tunable_params[child_tunable] == new_cat_value
                assert root_env.tunable_params[child_tunable] == child_tunable.category
            elif child_tunable.is_numerical:
                old_num_value = child_tunable.numerical_value
                assert child_tunable.value == old_num_value
                assert child_group[child_tunable] == old_num_value
                assert root_env.tunable_params[child_tunable] == old_num_value
                # Bump the value by one and confirm both views observe it.
                child_tunable.numerical_value += 1
                assert child_env.tunable_params[child_tunable] == old_num_value + 1
                assert root_env.tunable_params[child_tunable] == child_tunable.numerical_value