Coverage for mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Check how the services get inherited and overridden in child environments.""" 

6import os 

7 

8import pytest 

9 

10from mlos_bench.environments.composite_env import CompositeEnv 

11from mlos_bench.services.config_persistence import ConfigPersistenceService 

12from mlos_bench.services.local.local_exec import LocalExecService 

13from mlos_bench.tunables.tunable_groups import TunableGroups 

14from mlos_bench.util import path_join 

15 

16# pylint: disable=redefined-outer-name 

17 

18 

@pytest.fixture
def composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
    """Test fixture for CompositeEnv with services included on multiple levels.

    The root environment is given a ``LocalExecService`` (``temp_dir`` set to
    ``_test_tmp_global``) whose parent is a ``ConfigPersistenceService`` that
    looks up configs under ``../config`` relative to this file.

    Of the three ``MockEnv`` children:
    - the first includes no services of its own;
    - the second and third each list an extra mock local-exec service config
      in ``include_services``.
    """
    # Build the service chain bottom-up: config lookup first, then the local
    # exec service that wraps it and is handed to the root environment.
    config_dir = path_join(os.path.dirname(__file__), "../config", abs_path=True)
    config_service = ConfigPersistenceService({"config_path": [config_dir]})
    root_service = LocalExecService(
        config={"temp_dir": "_test_tmp_global"},
        parent=config_service,
    )

    mock_env_class = "mlos_bench.environments.mock_env.MockEnv"
    children = [
        {
            "name": "Env 1 :: tmp_global",
            "class": mock_env_class,
        },
        {
            "name": "Env 2 :: tmp_other_2",
            "class": mock_env_class,
            "include_services": ["services/local/mock/mock_local_exec_service_2.jsonc"],
        },
        {
            "name": "Env 3 :: tmp_other_3",
            "class": mock_env_class,
            "include_services": ["services/local/mock/mock_local_exec_service_3.jsonc"],
        },
    ]

    return CompositeEnv(
        name="Root",
        config={"children": children},
        tunables=tunable_groups,
        service=root_service,
    )

54 

55 

def test_composite_services(composite_env: CompositeEnv) -> None:
    """Check that each environment gets its own instance of the services.

    The first child is expected to use the root service's ``temp_dir``
    (``_test_tmp_global``). The second and third children are expected to
    resolve to ``_test_tmp_other_2`` and ``_test_tmp_other_3``, respectively,
    via the service configs they include.

    Each temp directory is removed afterwards, including when an assertion
    fails, so failing runs do not leave directories in the working directory.
    """
    expected = ((0, "_test_tmp_global"), (1, "_test_tmp_other_2"), (2, "_test_tmp_other_3"))
    for i, path in expected:
        service = composite_env.children[i]._service  # pylint: disable=protected-access
        # hasattr() check doubles as type narrowing before calling the context manager.
        assert service is not None and hasattr(service, "temp_dir_context")
        try:
            with service.temp_dir_context() as temp_dir:
                assert os.path.samefile(temp_dir, path)
        finally:
            # Cleanup must run even if the assertion above fails. The isdir()
            # guard keeps a missing directory from raising here and masking the
            # original failure. rmdir (not rmtree) so that unexpected contents
            # still surface as an error rather than being silently deleted.
            if os.path.isdir(path):
                os.rmdir(path)