Coverage for mlos_core/mlos_core/spaces/converters/flaml.py: 96%
24 statements
coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Contains space converters for
6:py:class:`~mlos_core.optimizers.flaml_optimizer.FlamlOptimizer`
7"""
9import sys
10from typing import TYPE_CHECKING, Dict
12import ConfigSpace
13import flaml.tune
14import flaml.tune.sample
15import numpy as np
17if TYPE_CHECKING:
18 from ConfigSpace.hyperparameters import Hyperparameter
20if sys.version_info >= (3, 10):
21 from typing import TypeAlias
22else:
23 from typing_extensions import TypeAlias
FlamlDomain: TypeAlias = flaml.tune.sample.Domain
"""Flaml domain type alias: one dimension of a FLAML search space."""

FlamlSpace: TypeAlias = Dict[str, FlamlDomain]
"""Flaml space type alias - maps each parameter name to its `FlamlDomain`."""
def configspace_to_flaml_space(
    config_space: ConfigSpace.ConfigurationSpace,
) -> FlamlSpace:
    """
    Converts a ConfigSpace.ConfigurationSpace to a FLAML search space dict.

    Parameters
    ----------
    config_space : ConfigSpace.ConfigurationSpace
        Input configuration space.

    Returns
    -------
    flaml_space : FlamlSpace
        A dictionary of flaml.tune.sample.Domain objects keyed by parameter name.

    Raises
    ------
    ValueError
        If a hyperparameter is of an unsupported type, or is a categorical
        hyperparameter with non-uniform probabilities.
    """

    def _one_parameter_convert(parameter: "Hyperparameter") -> FlamlDomain:
        """Convert a single ConfigSpace hyperparameter to a FLAML domain."""
        if isinstance(parameter, ConfigSpace.UniformFloatHyperparameter):
            # Pick the sampler from the matched base class rather than looking it
            # up by ``type(parameter)``, so subclasses accepted by this isinstance
            # check don't fail with a KeyError.
            # FIXME: upper isn't included in the range (FLAML samples floats from
            # the half-open interval [lower, upper)).
            float_sampler = flaml.tune.loguniform if parameter.log else flaml.tune.uniform
            return float_sampler(parameter.lower, parameter.upper)
        if isinstance(parameter, ConfigSpace.UniformIntegerHyperparameter):
            # ConfigSpace integer bounds are inclusive, whereas FLAML's
            # randint/lograndint exclude ``upper``; hence the ``+ 1``.
            int_sampler = flaml.tune.lograndint if parameter.log else flaml.tune.randint
            return int_sampler(parameter.lower, parameter.upper + 1)
        if isinstance(parameter, ConfigSpace.CategoricalHyperparameter):
            # flaml.tune.choice takes no weights, so only uniformly weighted
            # categoricals can be represented faithfully.
            if len(np.unique(parameter.probabilities)) > 1:
                raise ValueError(
                    "FLAML doesn't support categorical parameters with non-uniform probabilities."
                )
            return flaml.tune.choice(parameter.choices)  # TODO: set order?
        raise ValueError(f"Type of parameter {parameter} ({type(parameter)}) not supported.")

    return {param.name: _one_parameter_convert(param) for param in config_space.values()}