Coverage for mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py: 100%

19 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Check how the services get inherited and overridden in child environments.""" 

6import os 

7 

8import pytest 

9 

10from mlos_bench.environments.composite_env import CompositeEnv 

11from mlos_bench.services.config_persistence import ConfigPersistenceService 

12from mlos_bench.services.local.local_exec import LocalExecService 

13from mlos_bench.services.types.local_exec_type import SupportsLocalExec 

14from mlos_bench.tunables.tunable_groups import TunableGroups 

15from mlos_bench.util import path_join 

16 

17# pylint: disable=redefined-outer-name 

18 

19 

@pytest.fixture
def composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
    """
    Build a CompositeEnv whose services are declared at several levels.

    The root environment is given a LocalExecService (temp dir
    ``_test_tmp_global``), whose parent is a ConfigPersistenceService that
    searches the ``../config`` directory next to this test file.
    Of the three MockEnv children, the first declares no services of its own.
    The second and third each list a mock local-exec service config in
    ``include_services``.
    """
    mock_env_class = "mlos_bench.environments.mock_env.MockEnv"

    # Service chain attached to the root environment.
    config_dir = path_join(os.path.dirname(__file__), "../config", abs_path=True)
    config_service = ConfigPersistenceService({"config_path": [config_dir]})
    root_service = LocalExecService(
        config={"temp_dir": "_test_tmp_global"},
        parent=config_service,
    )

    child_configs = [
        # No "include_services" here: only the root-level service is available.
        {
            "name": "Env 1 :: tmp_global",
            "class": mock_env_class,
        },
        {
            "name": "Env 2 :: tmp_other_2",
            "class": mock_env_class,
            "include_services": ["services/local/mock/mock_local_exec_service_2.jsonc"],
        },
        {
            "name": "Env 3 :: tmp_other_3",
            "class": mock_env_class,
            "include_services": ["services/local/mock/mock_local_exec_service_3.jsonc"],
        },
    ]

    return CompositeEnv(
        name="Root",
        config={"children": child_configs},
        tunables=tunable_groups,
        service=root_service,
    )

55 

56 

def test_composite_services(composite_env: CompositeEnv) -> None:
    """
    Check that each environment gets its own instance of the services.

    For each child of the composite environment, the service's
    ``temp_dir_context()`` is expected to yield a specific directory:

    - child 0 yields ``_test_tmp_global`` (the root-level service);
    - child 1 yields ``_test_tmp_other_2``;
    - child 2 yields ``_test_tmp_other_3``.

    The directories are relative to the current working directory. They are
    removed afterwards whether or not the assertions pass.
    """
    expected_temp_dirs = ("_test_tmp_global", "_test_tmp_other_2", "_test_tmp_other_3")
    try:
        for i, path in enumerate(expected_temp_dirs):
            service = composite_env.children[i]._service  # pylint: disable=protected-access
            assert service is not None and hasattr(service, "temp_dir_context")
            assert isinstance(service, SupportsLocalExec)
            with service.temp_dir_context() as temp_dir:
                assert os.path.samefile(temp_dir, path)
    finally:
        # Clean up in `finally` so a failing assertion does not leave stale
        # `_test_tmp_*` directories in the CWD for later test runs.
        # The isdir() guard avoids masking the real failure with a
        # FileNotFoundError for directories that were never created.
        for path in expected_temp_dirs:
            if os.path.isdir(path):
                os.rmdir(path)