Coverage for mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py: 100%
17 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Check how the services get inherited and overridden in child environments."""
6import os
8import pytest
10from mlos_bench.environments.composite_env import CompositeEnv
11from mlos_bench.services.config_persistence import ConfigPersistenceService
12from mlos_bench.services.local.local_exec import LocalExecService
13from mlos_bench.tunables.tunable_groups import TunableGroups
14from mlos_bench.util import path_join
16# pylint: disable=redefined-outer-name
@pytest.fixture
def composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
    """
    Build a "Root" CompositeEnv with three MockEnv children and layered services.

    The root is given a LocalExecService (temp dir ``_test_tmp_global``) whose
    parent is a ConfigPersistenceService pointed at the tests' ``../config``
    directory. The first child declares no services of its own, while the second
    and third each list a mock local exec service config in ``include_services``.
    """
    mock_env_class = "mlos_bench.environments.mock_env.MockEnv"

    # Child definitions: only Env 2 and Env 3 pull in extra service configs.
    child_configs = [
        {
            "name": "Env 1 :: tmp_global",
            "class": mock_env_class,
        },
        {
            "name": "Env 2 :: tmp_other_2",
            "class": mock_env_class,
            "include_services": ["services/local/mock/mock_local_exec_service_2.jsonc"],
        },
        {
            "name": "Env 3 :: tmp_other_3",
            "class": mock_env_class,
            "include_services": ["services/local/mock/mock_local_exec_service_3.jsonc"],
        },
    ]

    # Config search path is resolved relative to this test module's location.
    config_dir = path_join(os.path.dirname(__file__), "../config", abs_path=True)
    config_loader = ConfigPersistenceService({"config_path": [config_dir]})

    # Service handed to the root environment.
    root_service = LocalExecService(
        config={"temp_dir": "_test_tmp_global"},
        parent=config_loader,
    )

    return CompositeEnv(
        name="Root",
        config={"children": child_configs},
        tunables=tunable_groups,
        service=root_service,
    )
def test_composite_services(composite_env: CompositeEnv) -> None:
    """
    Verify that every child environment's service yields its expected temp dir.

    For each child (by position) the service's ``temp_dir_context()`` must
    produce a path that refers to the same file as the expected relative
    directory name. The directory is removed after each check.
    """
    expected_dirs = ("_test_tmp_global", "_test_tmp_other_2", "_test_tmp_other_3")
    for idx, expected_dir in enumerate(expected_dirs):
        child_service = composite_env.children[idx]._service  # pylint: disable=protected-access
        assert child_service is not None
        assert hasattr(child_service, "temp_dir_context")
        with child_service.temp_dir_context() as actual_dir:
            assert os.path.samefile(actual_dir, expected_dir)
        # Clean up the directory once the context has been exited.
        os.rmdir(expected_dir)