Coverage for mlos_bench/mlos_bench/tests/environments/base_env_test.py: 100%

19 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for base environment class functionality.""" 

6 

7from typing import Dict 

8 

9import pytest 

10 

11from mlos_bench.environments.base_environment import Environment 

12from mlos_bench.tunables.tunable import TunableValue 

13 

14_GROUPS = { 

15 "group": ["a", "b"], 

16 "list": ["c", "d"], 

17 "str": "efg", 

18 "empty": [], 

19 "other": ["h", "i", "j"], 

20} 

21 

22# pylint: disable=protected-access 

23 

24 

def test_expand_groups() -> None:
    """Verify that `$name` entries are replaced by the named group's members."""
    # Plain names pass through untouched; `$empty` is expected to vanish and
    # the string-valued `$str` group to become a single element.
    names = ["begin", "$list", "$empty", "$str", "end"]
    expected = ["begin", "c", "d", "efg", "end"]
    assert Environment._expand_groups(names, _GROUPS) == expected

43 

44 

def test_expand_groups_empty_input() -> None:
    """Expanding an empty list of names must produce an empty list."""
    expanded = Environment._expand_groups([], _GROUPS)
    assert expanded == []

48 

49 

def test_expand_groups_empty_list() -> None:
    """A reference to a group with no members must expand to nothing."""
    expanded = Environment._expand_groups(["$empty"], _GROUPS)
    assert not expanded

53 

54 

def test_expand_groups_unknown() -> None:
    """Referencing a `$GROUP` name that is not defined must raise KeyError."""
    # `$UNKNOWN` is the only reference without a matching key in _GROUPS.
    names = ["$list", "$UNKNOWN", "$str", "end"]
    with pytest.raises(KeyError):
        Environment._expand_groups(names, _GROUPS)

59 

60 

def test_expand_const_args() -> None:
    """Check `$var` substitution in const args via `_expand_vars`."""
    global_config: Dict[str, TunableValue] = {"bar": "blah"}
    const_args: Dict[str, TunableValue] = {
        "a": "b",
        "foo": "$bar/baz",
        "1": 1,
        # Refers to another const arg, which itself refers to the global config.
        "recursive": "$foo/expansion",
    }
    # Values without `$` references (including the integer) are expected to
    # come back unchanged.
    expected = {
        "a": "b",
        "foo": "blah/baz",
        "1": 1,
        "recursive": "blah/baz/expansion",
    }
    assert Environment._expand_vars(const_args, global_config) == expected