Coverage for mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py: 100%

23 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for converting tunable parameters with explicitly specified distributions 

6to ConfigSpace. 

7""" 

8 

9import pytest 

10from ConfigSpace import ( 

11 BetaFloatHyperparameter, 

12 BetaIntegerHyperparameter, 

13 CategoricalHyperparameter, 

14 NormalFloatHyperparameter, 

15 NormalIntegerHyperparameter, 

16 UniformFloatHyperparameter, 

17 UniformIntegerHyperparameter, 

18) 

19 

20from mlos_bench.optimizers.convert_configspace import ( 

21 special_param_names, 

22 tunable_groups_to_configspace, 

23) 

24from mlos_bench.tunables.tunable import DistributionName 

25from mlos_bench.tunables.tunable_groups import TunableGroups 

26 

# Lookup table used by the test below: for each (tunable type, distribution
# name) pair, the ConfigSpace hyperparameter class the converted tunable is
# expected to be an instance of.
_CS_HYPERPARAMETER = {
    # Beta-distributed tunables.
    ("float", "beta"): BetaFloatHyperparameter,
    ("int", "beta"): BetaIntegerHyperparameter,
    # Normally-distributed tunables.
    ("float", "normal"): NormalFloatHyperparameter,
    ("int", "normal"): NormalIntegerHyperparameter,
    # Uniformly-distributed tunables.
    ("float", "uniform"): UniformFloatHyperparameter,
    ("int", "uniform"): UniformIntegerHyperparameter,
}

35 

36 

@pytest.mark.parametrize("param_type", ["int", "float"])
@pytest.mark.parametrize(
    ("distr_name", "distr_params"),
    [
        ("normal", {"mu": 0.0, "sigma": 1.0}),
        ("beta", {"alpha": 2, "beta": 5}),
        ("uniform", {}),
    ],
)
def test_convert_numerical_distributions(
    param_type: str,
    distr_name: DistributionName,
    distr_params: dict,
) -> None:
    """Convert a numerical Tunable with explicit distribution to ConfigSpace.

    The resulting space is expected to hold exactly three hyperparameters:
    - the one for the tunable itself, whose class is looked up in
      ``_CS_HYPERPARAMETER`` by (type, distribution);
    - two auxiliary ``CategoricalHyperparameter`` entries named by
      ``special_param_names``.

    Each distribution parameter must also be readable as an attribute of the
    converted hyperparameter.
    """
    param_name = "x"
    # Single numerical tunable with special values, weights and an explicit
    # distribution block.
    param_config = {
        "type": param_type,
        "range": [0, 100],
        "special": [-1, 0],
        "special_weights": [0.1, 0.2],
        "range_weight": 0.7,
        "distribution": {"type": distr_name, "params": distr_params},
        "default": 0,
    }
    groups = TunableGroups(
        {
            "tunable_group": {
                "cost": 1,
                "params": {param_name: param_config},
            }
        }
    )

    # The distribution settings should be reflected on the parsed Tunable.
    tunable, _ = groups.get_tunable(param_name)
    assert tunable.distribution == distr_name
    assert tunable.distribution_params == distr_params

    space = tunable_groups_to_configspace(groups)

    special_name, type_name = special_param_names(param_name)
    assert set(space.keys()) == {param_name, type_name, special_name}
    for aux_name in (special_name, type_name):
        assert isinstance(space[aux_name], CategoricalHyperparameter)

    hyperparam = space[param_name]
    expected_cls = _CS_HYPERPARAMETER[param_type, distr_name]
    assert isinstance(hyperparam, expected_cls)
    # Distribution parameters are compared by attribute name on the
    # ConfigSpace object (e.g. ``mu``/``sigma`` or ``alpha``/``beta``).
    for attr, expected in distr_params.items():
        assert getattr(hyperparam, attr) == expected