Coverage for mlos_core/mlos_core/spaces/converters/flaml.py: 95%
22 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Contains space converters for FLAML."""
7import sys
8from typing import TYPE_CHECKING, Dict
10import ConfigSpace
11import flaml.tune
12import flaml.tune.sample
13import numpy as np
15if TYPE_CHECKING:
16 from ConfigSpace.hyperparameters import Hyperparameter
18if sys.version_info >= (3, 10):
19 from typing import TypeAlias
20else:
21 from typing_extensions import TypeAlias
# The FLAML search-domain type describing a single tunable parameter.
FlamlDomain: TypeAlias = flaml.tune.sample.Domain
# A complete FLAML search space: a mapping of parameter name -> FlamlDomain.
FlamlSpace: TypeAlias = Dict[str, FlamlDomain]
def configspace_to_flaml_space(
    config_space: ConfigSpace.ConfigurationSpace,
) -> Dict[str, FlamlDomain]:
    """
    Converts a ConfigSpace.ConfigurationSpace to dict.

    Parameters
    ----------
    config_space : ConfigSpace.ConfigurationSpace
        Input configuration space.

    Returns
    -------
    flaml_space : dict
        A dictionary of flaml.tune.sample.Domain objects keyed by parameter name.

    Raises
    ------
    ValueError
        If a hyperparameter type is not supported, or if a categorical
        hyperparameter has non-uniform probabilities.
    """
    # Keyed on the *base* ConfigSpace class that the isinstance() checks below
    # match, rather than on type(parameter): a subclass of one of these
    # hyperparameter classes passes isinstance() and must still resolve here
    # instead of failing with a KeyError.
    flaml_numeric_type = {
        (ConfigSpace.UniformIntegerHyperparameter, False): flaml.tune.randint,
        (ConfigSpace.UniformIntegerHyperparameter, True): flaml.tune.lograndint,
        (ConfigSpace.UniformFloatHyperparameter, False): flaml.tune.uniform,
        (ConfigSpace.UniformFloatHyperparameter, True): flaml.tune.loguniform,
    }

    def _one_parameter_convert(parameter: "Hyperparameter") -> FlamlDomain:
        """Convert a single ConfigSpace hyperparameter to a FLAML domain."""
        if isinstance(parameter, ConfigSpace.UniformFloatHyperparameter):
            sampler = flaml_numeric_type[
                (ConfigSpace.UniformFloatHyperparameter, bool(parameter.log))
            ]
            # FIXME: upper isn't included in the range
            return sampler(parameter.lower, parameter.upper)
        if isinstance(parameter, ConfigSpace.UniformIntegerHyperparameter):
            sampler = flaml_numeric_type[
                (ConfigSpace.UniformIntegerHyperparameter, bool(parameter.log))
            ]
            # ConfigSpace integer bounds are inclusive while FLAML's integer
            # upper bound is exclusive, hence the +1.
            return sampler(parameter.lower, parameter.upper + 1)
        if isinstance(parameter, ConfigSpace.CategoricalHyperparameter):
            # FLAML's choice() samples uniformly, so weighted categoricals
            # cannot be represented faithfully.
            if len(np.unique(parameter.probabilities)) > 1:
                raise ValueError(
                    "FLAML doesn't support categorical parameters with non-uniform probabilities."
                )
            return flaml.tune.choice(parameter.choices)  # TODO: set order?
        raise ValueError(f"Type of parameter {parameter} ({type(parameter)}) not supported.")

    return {param.name: _one_parameter_convert(param) for param in config_space.values()}