Coverage for mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py: 100%
23 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Unit tests for converting tunable parameters with explicitly specified distributions
6to ConfigSpace.
7"""
9import pytest
10from ConfigSpace import (
11 BetaFloatHyperparameter,
12 BetaIntegerHyperparameter,
13 CategoricalHyperparameter,
14 NormalFloatHyperparameter,
15 NormalIntegerHyperparameter,
16 UniformFloatHyperparameter,
17 UniformIntegerHyperparameter,
18)
20from mlos_bench.optimizers.convert_configspace import (
21 special_param_names,
22 tunable_groups_to_configspace,
23)
24from mlos_bench.tunables.tunable import DistributionName
25from mlos_bench.tunables.tunable_groups import TunableGroups
# Expected ConfigSpace hyperparameter class for each (tunable type, distribution name)
# pair. Each distribution lists its (float, int) classes; the resulting keys are
# inserted float-then-int per distribution, in the order beta, normal, uniform.
_CS_HYPERPARAMETER = {
    (tunable_type, distribution): hp_class
    for distribution, (float_class, int_class) in (
        ("beta", (BetaFloatHyperparameter, BetaIntegerHyperparameter)),
        ("normal", (NormalFloatHyperparameter, NormalIntegerHyperparameter)),
        ("uniform", (UniformFloatHyperparameter, UniformIntegerHyperparameter)),
    )
    for tunable_type, hp_class in (("float", float_class), ("int", int_class))
}
@pytest.mark.parametrize("param_type", ["int", "float"])
@pytest.mark.parametrize(
    ("distr_name", "distr_params"),
    [
        ("normal", {"mu": 0.0, "sigma": 1.0}),
        ("beta", {"alpha": 2, "beta": 5}),
        ("uniform", {}),
    ],
)
def test_convert_numerical_distributions(
    param_type: str,
    distr_name: DistributionName,
    distr_params: dict,
) -> None:
    """Check that a numerical tunable with an explicit distribution converts to the
    matching ConfigSpace hyperparameter.

    The tunable also declares special values, so the converted space is expected to
    hold the numeric hyperparameter plus the two auxiliary categorical
    hyperparameters whose names come from ``special_param_names``.
    """
    name = "x"
    param_config = {
        "type": param_type,
        "range": [0, 100],
        "special": [-1, 0],
        "special_weights": [0.1, 0.2],
        "range_weight": 0.7,
        "distribution": {"type": distr_name, "params": distr_params},
        "default": 0,
    }
    groups = TunableGroups(
        {
            "tunable_group": {
                "cost": 1,
                "params": {name: param_config},
            }
        }
    )

    # The parsed tunable must report the distribution it was declared with.
    tunable, _ = groups.get_tunable(name)
    assert tunable.distribution == distr_name
    assert tunable.distribution_params == distr_params

    space = tunable_groups_to_configspace(groups)

    # Exactly three hyperparameters: the numeric one and its two categorical companions.
    special_name, type_name = special_param_names(name)
    assert set(space.keys()) == {name, type_name, special_name}
    for companion in (special_name, type_name):
        assert isinstance(space[companion], CategoricalHyperparameter)

    # The numeric hyperparameter class and its parameters must follow the distribution.
    numeric_hp = space[name]
    assert isinstance(numeric_hp, _CS_HYPERPARAMETER[param_type, distr_name])
    for attr_name, expected in distr_params.items():
        assert getattr(numeric_hp, attr_name) == expected