Coverage for mlos_bench/mlos_bench/tests/environments/base_env_test.py: 100%

18 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Unit tests for base environment class functionality."""

6 

7 

8import pytest 

9 

10from mlos_bench.environments.base_environment import Environment 

11from mlos_bench.tunables.tunable import TunableValue 

12 

13_GROUPS = { 

14 "group": ["a", "b"], 

15 "list": ["c", "d"], 

16 "str": "efg", 

17 "empty": [], 

18 "other": ["h", "i", "j"], 

19} 

20 

# pylint: disable=protected-access

22 

23 

def test_expand_groups() -> None:
    """Verify `$name` references expand to the contents of the named tunable group.

    A list-valued group is spliced in element by element, an empty group
    contributes nothing, a string-valued group contributes a single item,
    and literal (non-`$`) entries pass through unchanged and in order.
    """
    params = ["begin", "$list", "$empty", "$str", "end"]
    expected = ["begin", "c", "d", "efg", "end"]
    assert Environment._expand_groups(params, _GROUPS) == expected

42 

43 

def test_expand_groups_empty_input() -> None:
    """Make sure an empty list of parameters expands to an empty list."""
    # Note: this exercises an empty *input* list; expansion of a group that
    # maps to an empty list is covered by test_expand_groups_empty_list.
    assert Environment._expand_groups([], _GROUPS) == []

47 

48 

def test_expand_groups_empty_list() -> None:
    """Make sure a group bound to an empty list expands to nothing."""
    # Compare against [] explicitly (like the sibling tests) rather than
    # `assert not ...`, which would also accept None, "" or () as results.
    assert Environment._expand_groups(["$empty"], _GROUPS) == []

52 

53 

def test_expand_groups_unknown() -> None:
    """Verify that referencing an undefined `$GROUP` name raises `KeyError`."""
    # "$UNKNOWN" is not a key of _GROUPS; the surrounding entries are valid.
    params = ["$list", "$UNKNOWN", "$str", "end"]
    with pytest.raises(KeyError):
        Environment._expand_groups(params, _GROUPS)

58 

59 

def test_expand_const_args() -> None:
    """Verify `_expand_vars` substitutes `$name` references in const args.

    A reference can be resolved from the global config ("$bar") or from
    another const arg ("$foo" inside "recursive", which picks up foo's
    expanded value). Non-string values and strings without a `$` reference
    are expected back unchanged.
    """
    global_config: dict[str, TunableValue] = {"bar": "blah"}
    const_args: dict[str, TunableValue] = {
        "a": "b",
        "foo": "$bar/baz",
        "1": 1,
        "recursive": "$foo/expansion",
    }
    expected = {
        "a": "b",
        "foo": "blah/baz",
        "1": 1,
        "recursive": "blah/baz/expansion",
    }
    assert Environment._expand_vars(const_args, global_config) == expected