Coverage for mlos_core/mlos_core/spaces/converters/flaml.py: 96%

24 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Contains space converters for 

6:py:class:`~mlos_core.optimizers.flaml_optimizer.FlamlOptimizer` 

7""" 

8 

9import sys 

10from typing import TYPE_CHECKING, Dict 

11 

12import ConfigSpace 

13import flaml.tune 

14import flaml.tune.sample 

15import numpy as np 

16 

17if TYPE_CHECKING: 

18 from ConfigSpace.hyperparameters import Hyperparameter 

19 

20if sys.version_info >= (3, 10): 

21 from typing import TypeAlias 

22else: 

23 from typing_extensions import TypeAlias 

24 

25 

FlamlDomain: TypeAlias = flaml.tune.sample.Domain
"""Alias for a single FLAML search-space domain (``flaml.tune.sample.Domain``)."""

FlamlSpace: TypeAlias = Dict[str, FlamlDomain]
"""Alias for a complete FLAML search space: a ``Dict[str, FlamlDomain]`` keyed by parameter name."""

31 

32 

def configspace_to_flaml_space(
    config_space: ConfigSpace.ConfigurationSpace,
) -> Dict[str, FlamlDomain]:
    """
    Converts a ConfigSpace.ConfigurationSpace to a FLAML search-space dict.

    Only the hyperparameters themselves are converted; conditions and forbidden
    clauses attached to ``config_space`` are not consulted.

    Parameters
    ----------
    config_space : ConfigSpace.ConfigurationSpace
        Input configuration space.

    Returns
    -------
    flaml_space : dict
        A dictionary of flaml.tune.sample.Domain objects keyed by parameter name.

    Raises
    ------
    ValueError
        If a hyperparameter is not an exact ``UniformFloatHyperparameter``,
        an exact ``UniformIntegerHyperparameter``, or a ``CategoricalHyperparameter``,
        or if a categorical hyperparameter has non-uniform probabilities.
    """
    # FLAML sampler for each supported numeric hyperparameter, keyed by its
    # *exact* ConfigSpace class and whether it is log-scaled.
    flaml_numeric_type = {
        (ConfigSpace.UniformIntegerHyperparameter, False): flaml.tune.randint,
        (ConfigSpace.UniformIntegerHyperparameter, True): flaml.tune.lograndint,
        (ConfigSpace.UniformFloatHyperparameter, False): flaml.tune.uniform,
        (ConfigSpace.UniformFloatHyperparameter, True): flaml.tune.loguniform,
    }

    def _one_parameter_convert(parameter: "Hyperparameter") -> FlamlDomain:
        """Convert a single ConfigSpace hyperparameter to a FLAML domain."""
        if isinstance(parameter, ConfigSpace.CategoricalHyperparameter):
            # Weighted categoricals cannot be expressed here, so reject them
            # rather than silently dropping the weights.
            if len(np.unique(parameter.probabilities)) > 1:
                raise ValueError(
                    "FLAML doesn't support categorical parameters with non-uniform probabilities."
                )
            return flaml.tune.choice(parameter.choices)  # TODO: set order?

        # Look up by exact type (not isinstance): a subclass may carry a
        # different distribution, and any type missing from the table must be
        # reported as unsupported rather than escaping as a KeyError.
        numeric_key = (type(parameter), getattr(parameter, "log", None))
        sampler = flaml_numeric_type.get(numeric_key)
        if sampler is None:
            raise ValueError(f"Type of parameter {parameter} ({type(parameter)}) not supported.")

        if isinstance(parameter, ConfigSpace.UniformIntegerHyperparameter):
            # ConfigSpace integer bounds are inclusive, while FLAML's
            # randint/lograndint treat ``upper`` as exclusive -- shift by one.
            return sampler(parameter.lower, parameter.upper + 1)

        # FIXME: upper isn't included in the range (FLAML samples floats in
        # the half-open interval [lower, upper)).
        return sampler(parameter.lower, parameter.upper)

    return {param.name: _one_parameter_convert(param) for param in config_space.values()}