Coverage for mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py: 100% (52 statements)

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Tests for optimizer schema validation."""

from os import path

import pytest

from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.tests import try_resolve_class_name
from mlos_bench.tests.config.schemas import (
    check_test_case_against_schema,
    check_test_case_config_with_extra_param,
    get_schema_test_cases,
)
from mlos_core.optimizers import OptimizerType
from mlos_core.spaces.adapters import SpaceAdapterType
from mlos_core.tests import get_all_concrete_subclasses

# General testing strategy:
# - hand code a set of good/bad configs (useful to test editor schema checking)
# - enumerate the optimizer and space adapter types and check that we've covered all the cases
# - for each config, load it and validate it against the expected schema

TEST_CASES = get_schema_test_cases(path.join(path.dirname(__file__), "test-cases"))
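
# From its usage below, TEST_CASES appears to expose the loaded configs indexed
# several ways: by_path, by_type ("good"/"bad"), and by_subtype, with each test
# case carrying its parsed config in test_case.config. (Inferred from this file;
# see mlos_bench.tests.config.schemas for the actual structure.)
#
# For illustration only, a hypothetical "good" test case config for this schema
# might look roughly like the sketch below (the exact contents and layout of the
# files under test-cases/ are assumptions; see the good/ and bad/ subdirectories
# there for the real examples):
#
#   {
#       "class": "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer",
#       "config": {
#           "optimizer_type": "FLAML"
#       }
#   }
#
# A "bad" counterpart would typically misspell a property name or use the wrong
# value type so that schema validation is expected to fail.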


# Dynamically enumerate some of the cases we want to make sure we cover.

expected_mlos_bench_optimizer_class_names = [
    subclass.__module__ + "." + subclass.__name__
    for subclass in get_all_concrete_subclasses(
        Optimizer,  # type: ignore[type-abstract]
        pkg_name="mlos_bench",
    )
]
assert expected_mlos_bench_optimizer_class_names
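
# For reference, this enumeration is expected to pick up concrete Optimizer
# subclasses such as the following (illustrative, not an exhaustive or
# guaranteed list):
#
#   mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer
#   mlos_bench.optimizers.grid_search_optimizer.GridSearchOptimizer
#   mlos_bench.optimizers.mock_optimizer.MockOptimizer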


# Also make sure that we check for configs where the optimizer_type or
# space_adapter_type are left unspecified (None).

expected_mlos_core_optimizer_types = list(OptimizerType) + [None]
assert expected_mlos_core_optimizer_types

expected_mlos_core_space_adapter_types = list(SpaceAdapterType) + [None]
assert expected_mlos_core_space_adapter_types
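
# At the time of writing, OptimizerType and SpaceAdapterType include members such
# as FLAML/SMAC/RANDOM and LLAMATUNE/IDENTITY respectively (illustrative only;
# the enums in mlos_core are the source of truth), so the coverage checks below
# expect a test case for each member plus the unspecified (None) case.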


# Do the full cross product of all the test cases and all the optimizer types.
@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("mlos_bench_optimizer_type", expected_mlos_bench_optimizer_class_names)
def test_case_coverage_mlos_bench_optimizer_type(
    test_case_subtype: str,
    mlos_bench_optimizer_type: str,
) -> None:
    """Checks to see if there is a given type of test case for the given mlos_bench
    optimizer type.
    """
    for test_case in TEST_CASES.by_subtype[test_case_subtype].values():
        if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_optimizer_type:
            return
    raise NotImplementedError(
        f"Missing test case for subtype {test_case_subtype} "
        f"for Optimizer class {mlos_bench_optimizer_type}"
    )
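
# If one of these coverage checks fails, the fix is usually to add a new config
# file under test-cases/ rather than to change this test. To run just this check
# locally, something like the following should work (pytest's -k filter; the
# exact expression is only a suggestion):
#
#   pytest mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py \
#       -k test_case_coverage_mlos_bench_optimizer_type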


# Being a little lazy for the moment and relaxing the requirement that we have
# a subtype test case for each optimizer and space adapter combo.


@pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type))
# @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("mlos_core_optimizer_type", expected_mlos_core_optimizer_types)
def test_case_coverage_mlos_core_optimizer_type(
    test_case_type: str,
    mlos_core_optimizer_type: OptimizerType | None,
) -> None:
    """Checks to see if there is a given type of test case for the given mlos_core
    optimizer type.
    """
    optimizer_name = None if mlos_core_optimizer_type is None else mlos_core_optimizer_type.name
    for test_case in TEST_CASES.by_type[test_case_type].values():
        if (
            try_resolve_class_name(test_case.config.get("class"))
            == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer"
        ):
            optimizer_type = None
            if test_case.config.get("config"):
                optimizer_type = test_case.config["config"].get("optimizer_type", None)
            if optimizer_type == optimizer_name:
                return
    raise NotImplementedError(
        f"Missing test case for type {test_case_type} "
        f"for MlosCore Optimizer type {mlos_core_optimizer_type}"
    )


@pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type))
# @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("mlos_core_space_adapter_type", expected_mlos_core_space_adapter_types)
def test_case_coverage_mlos_core_space_adapter_type(
    test_case_type: str,
    mlos_core_space_adapter_type: SpaceAdapterType | None,
) -> None:
    """Checks to see if there is a given type of test case for the given mlos_core space
    adapter type.
    """
    space_adapter_name = (
        None if mlos_core_space_adapter_type is None else mlos_core_space_adapter_type.name
    )
    for test_case in TEST_CASES.by_type[test_case_type].values():
        if (
            try_resolve_class_name(test_case.config.get("class"))
            == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer"
        ):
            space_adapter_type = None
            if test_case.config.get("config"):
                space_adapter_type = test_case.config["config"].get("space_adapter_type", None)
            if space_adapter_type == space_adapter_name:
                return
    raise NotImplementedError(
        f"Missing test case for type {test_case_type} "
        f"for SpaceAdapter type {mlos_core_space_adapter_type}"
    )


# Now we actually perform all of those validation tests.


@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_optimizer_configs_against_schema(test_case_name: str) -> None:
    """Checks that the optimizer config validates against the schema."""
    check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.OPTIMIZER)
    check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED)


@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_optimizer_configs_with_extra_param(test_case_name: str) -> None:
    """Checks that the optimizer config fails to validate if extra params are present in
    certain places.
    """
    check_test_case_config_with_extra_param(
        TEST_CASES.by_type["good"][test_case_name],
        ConfigSchema.OPTIMIZER,
    )
    check_test_case_config_with_extra_param(
        TEST_CASES.by_type["good"][test_case_name],
        ConfigSchema.UNIFIED,
    )
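
# Roughly the idea behind these helpers, as a minimal sketch only (the real
# implementations live in mlos_bench.tests.config.schemas and may differ):
#
#   import jsonschema
#
#   # positive check (for "good" test cases): the config should validate
#   jsonschema.validate(test_case.config, schema)
#
#   # negative check: an unexpected extra property should be rejected
#   config = dict(test_case.config, extra_param="unexpected")
#   with pytest.raises(jsonschema.exceptions.ValidationError):
#       jsonschema.validate(config, schema)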


if __name__ == "__main__":
    pytest.main(args=["-n0", "-k", "grid_search_optimizer"])
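
# When run directly as a script, the block above invokes pytest on this file with
# "-n0" (which, assuming pytest-xdist is installed, disables parallel workers) and
# a "-k grid_search_optimizer" filter so that only the grid search optimizer test
# cases are exercised, which is handy for quick local debugging.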