Coverage for mlos_bench/mlos_bench/tests/environments/base_env_test.py: 100%
19 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Unit tests for base environment class functionality."""
7from typing import Dict
9import pytest
11from mlos_bench.environments.base_environment import Environment
12from mlos_bench.tunables.tunable import TunableValue
14_GROUPS = {
15 "group": ["a", "b"],
16 "list": ["c", "d"],
17 "str": "efg",
18 "empty": [],
19 "other": ["h", "i", "j"],
20}
22# pylint: disable=protected-access
def test_expand_groups() -> None:
    """Verify that ``$name`` entries are replaced by the matching tunable groups."""
    # Mix of literal names and "$"-prefixed references into _GROUPS.
    names = ["begin", "$list", "$empty", "$str", "end"]
    # "$list" contributes two items, "$empty" contributes none, and the
    # string-valued "$str" contributes exactly one item.
    expected = ["begin", "c", "d", "efg", "end"]
    expanded = Environment._expand_groups(names, _GROUPS)
    assert expanded == expected
def test_expand_groups_empty_input() -> None:
    """Expanding an empty list of names must produce an empty list."""
    no_names: list = []
    expanded = Environment._expand_groups(no_names, _GROUPS)
    assert expanded == []
def test_expand_groups_empty_list() -> None:
    """Make sure expanding a reference to an empty group yields an empty list."""
    expanded = Environment._expand_groups(["$empty"], _GROUPS)
    # Compare against [] explicitly instead of relying on truthiness: a bare
    # `assert not ...` would also pass for None, "" or 0, hiding a regression
    # in the return type. This matches the empty-input test's `== []` check.
    assert expanded == []
def test_expand_groups_unknown() -> None:
    """Expanding a ``$GROUP`` name that is not defined must raise ``KeyError``."""
    # "$UNKNOWN" has no entry in _GROUPS; the other names are valid.
    names = ["$list", "$UNKNOWN", "$str", "end"]
    with pytest.raises(KeyError):
        Environment._expand_groups(names, _GROUPS)
def test_expand_const_args() -> None:
    """Check that ``$var`` references in const args are resolved by ``_expand_vars``."""
    # Only "bar" is defined globally; "foo" must be resolved through it.
    global_config: Dict[str, TunableValue] = {"bar": "blah"}
    const_args: Dict[str, TunableValue] = {
        "a": "b",
        "foo": "$bar/baz",
        "1": 1,
        # References another const arg, which itself references a global.
        "recursive": "$foo/expansion",
    }
    # Strings without "$" and non-string values are expected to be unchanged.
    expected = {
        "a": "b",
        "foo": "blah/baz",
        "1": 1,
        "recursive": "blah/baz/expansion",
    }
    assert Environment._expand_vars(const_args, global_config) == expected