Coverage for mlos_core/mlos_core/spaces/converters/flaml.py: 95%

22 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Contains space converters for FLAML.""" 

6 

7import sys 

8from typing import TYPE_CHECKING, Dict 

9 

10import ConfigSpace 

11import flaml.tune 

12import flaml.tune.sample 

13import numpy as np 

14 

15if TYPE_CHECKING: 

16 from ConfigSpace.hyperparameters import Hyperparameter 

17 

18if sys.version_info >= (3, 10): 

19 from typing import TypeAlias 

20else: 

21 from typing_extensions import TypeAlias 

22 

23 

# One dimension of a FLAML search space: the sampling-domain object produced by
# the flaml.tune samplers (randint, uniform, choice, ...).
FlamlDomain: TypeAlias = flaml.tune.sample.Domain
# A whole FLAML search space: parameter name -> that parameter's sampling domain.
FlamlSpace: TypeAlias = Dict[str, FlamlDomain]

26 

27 

def configspace_to_flaml_space(
    config_space: ConfigSpace.ConfigurationSpace,
) -> FlamlSpace:
    """
    Converts a ConfigSpace.ConfigurationSpace to a FLAML search-space dict.

    Parameters
    ----------
    config_space : ConfigSpace.ConfigurationSpace
        Input configuration space.

    Returns
    -------
    flaml_space : dict
        A dictionary of flaml.tune.sample.Domain objects keyed by parameter name.

    Raises
    ------
    ValueError
        If a hyperparameter is not a uniform integer, uniform float, or
        categorical hyperparameter, or if a categorical hyperparameter has
        non-uniform probabilities.
    """
    # FLAML sampler factory keyed by (hyperparameter family, is log-scaled).
    flaml_numeric_type = {
        (ConfigSpace.UniformIntegerHyperparameter, False): flaml.tune.randint,
        (ConfigSpace.UniformIntegerHyperparameter, True): flaml.tune.lograndint,
        (ConfigSpace.UniformFloatHyperparameter, False): flaml.tune.uniform,
        (ConfigSpace.UniformFloatHyperparameter, True): flaml.tune.loguniform,
    }

    def _one_parameter_convert(parameter: "Hyperparameter") -> FlamlDomain:
        """Convert a single ConfigSpace hyperparameter to a FLAML domain."""
        # The lookup keys use the class that isinstance() just matched, not
        # type(parameter). With type(parameter), a subclass instance would pass
        # the isinstance check and then raise a KeyError on the lookup.
        if isinstance(parameter, ConfigSpace.UniformFloatHyperparameter):
            key = (ConfigSpace.UniformFloatHyperparameter, bool(parameter.log))
            # FIXME: upper isn't included in the range
            return flaml_numeric_type[key](parameter.lower, parameter.upper)
        if isinstance(parameter, ConfigSpace.UniformIntegerHyperparameter):
            key = (ConfigSpace.UniformIntegerHyperparameter, bool(parameter.log))
            # ConfigSpace integer bounds are inclusive. FLAML's randint/lograndint
            # treat upper as exclusive, hence the +1.
            return flaml_numeric_type[key](parameter.lower, parameter.upper + 1)
        if isinstance(parameter, ConfigSpace.CategoricalHyperparameter):
            # flaml.tune.choice takes no weights, so only uniformly weighted
            # categoricals can be represented faithfully.
            if len(np.unique(parameter.probabilities)) > 1:
                raise ValueError(
                    "FLAML doesn't support categorical parameters with non-uniform probabilities."
                )
            return flaml.tune.choice(parameter.choices)  # TODO: set order?
        raise ValueError(f"Type of parameter {parameter} ({type(parameter)}) not supported.")

    return {param.name: _one_parameter_convert(param) for param in config_space.values()}