Coverage for mlos_core/mlos_core/tests/optimizers/data_class_test.py: 100%

88 statements  

Generated by coverage.py v7.6.9 at 2024-12-14 01:58 +0000.

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for Observation Data Class.""" 

6 

7 

8import pandas as pd 

9import pytest 

10 

11from mlos_core.data_classes import Observation, Observations, Suggestion 

12from mlos_core.util import compare_optional_series 

13 

14# pylint: disable=redefined-outer-name 

15 

16 

@pytest.fixture
def config() -> pd.Series:
    """Toy configuration (string ``y``, float ``x``, int ``z``) used to build data classes."""
    return pd.Series(
        data=["b", 0.4, 3],
        index=["y", "x", "z"],
    )

27 

28 

@pytest.fixture
def score() -> pd.Series:
    """Toy score with a main and a secondary metric, paired with ``config``."""
    return pd.Series(
        data=[0.1, 0.2],
        index=["main_score", "other_score"],
    )

38 

39 

@pytest.fixture
def score2() -> pd.Series:
    """Alternative toy score (same metric names as ``score``, different values)."""
    return pd.Series(
        data=[0.3, 0.1],
        index=["main_score", "other_score"],
    )

49 

50 

@pytest.fixture
def metadata() -> pd.Series:
    """Toy single-entry metadata series."""
    return pd.Series(data=["test"], index=["metadata"])

59 

60 

@pytest.fixture
def context() -> pd.Series:
    """Toy single-entry context series."""
    return pd.Series(data=["test"], index=["context"])

69 

70 

@pytest.fixture
def config2() -> pd.Series:
    """Second toy configuration: same keys as ``config`` but different values."""
    return pd.Series(
        data=["c", 0.7, 1],
        index=["y", "x", "z"],
    )

81 

82 

@pytest.fixture
def observation_with_context(
    config: pd.Series,
    score: pd.Series,
    metadata: pd.Series,
    context: pd.Series,
) -> Observation:
    """Toy observation with every component populated (including metadata and context)."""
    return Observation(
        context=context,
        metadata=metadata,
        score=score,
        config=config,
    )

97 

98 

99@pytest.fixture 

100def observation_without_context(config2: pd.Series, score2: pd.Series) -> Observation: 

101 """Toy observation used for tests.""" 

102 return Observation( 

103 config=config2, 

104 score=score2, 

105 ) 

106 

107 

@pytest.fixture
def observations_with_context(observation_with_context: Observation) -> Observations:
    """Toy ``Observations`` collection holding ``observation_with_context`` three times."""
    return Observations(observations=[observation_with_context] * 3)

114 

115 

@pytest.fixture
def observations_without_context(observation_without_context: Observation) -> Observations:
    """Toy ``Observations`` collection holding ``observation_without_context`` three times."""
    return Observations(observations=[observation_without_context] * 3)

126 

127 

@pytest.fixture
def suggestion_with_context(
    config: pd.Series,
    metadata: pd.Series,
    context: pd.Series,
) -> Suggestion:
    """Toy suggestion with config, metadata and context all supplied."""
    return Suggestion(context=context, metadata=metadata, config=config)

140 

141 

@pytest.fixture
def suggestion_without_context(config2: pd.Series) -> Suggestion:
    """Toy suggestion that supplies only a config (``config2``)."""
    return Suggestion(config=config2)

148 

149 

def test_observation_to_suggestion(
    observation_with_context: Observation,
    observation_without_context: Observation,
) -> None:
    """Test that ``Observation.to_suggestion`` carries over config, metadata and context.

    Checked both for an observation that has the optional components and for one
    that does not.
    """
    for observation in (observation_with_context, observation_without_context):
        suggestion = observation.to_suggestion()
        assert compare_optional_series(suggestion.config, observation.config)
        assert compare_optional_series(suggestion.metadata, observation.metadata)
        assert compare_optional_series(suggestion.context, observation.context)

160 

161 

def test_observation_equality_operators(
    observation_with_context: Observation,
    observation_without_context: Observation,
) -> None:
    """Check ``==``/``!=`` on Observation: each equals itself, the two differ."""
    # pylint: disable=comparison-with-itself
    with_ctx, without_ctx = observation_with_context, observation_without_context
    assert with_ctx == with_ctx
    assert with_ctx != without_ctx
    assert without_ctx == without_ctx

171 

172 

def test_observations_init_components(
    config: pd.Series,
    score: pd.Series,
    metadata: pd.Series,
    context: pd.Series,
) -> None:
    """Test building ``Observations`` from per-component DataFrames.

    Every component is given the same two rows, so the resulting collection
    should contain two observations.
    """
    observations = Observations(
        configs=pd.concat([config.to_frame().T, config.to_frame().T]),
        scores=pd.concat([score.to_frame().T, score.to_frame().T]),
        contexts=pd.concat([context.to_frame().T, context.to_frame().T]),
        metadata=pd.concat([metadata.to_frame().T, metadata.to_frame().T]),
    )
    # Not raising is not enough: make sure both rows were actually kept.
    assert len(observations) == 2

186 

187 

def test_observations_init_observations(observation_with_context: Observation) -> None:
    """Test building ``Observations`` from a list of ``Observation`` objects."""
    observations = Observations(
        observations=[observation_with_context, observation_with_context],
    )
    # Not raising is not enough: both list entries should be present.
    assert len(observations) == 2

193 

194 

def test_observations_init_components_fails(
    config: pd.Series,
    score: pd.Series,
    metadata: pd.Series,
    context: pd.Series,
) -> None:
    """Test that ``Observations`` rejects components with mismatched row counts.

    In each case every component gets two rows except one, which gets a single
    row; the constructor must raise ``AssertionError`` for every such mismatch.
    """

    def _rows(series: pd.Series, truncate: bool) -> pd.DataFrame:
        """Stack ``series`` as a one-row frame if ``truncate`` else as a two-row frame."""
        return pd.concat([series.to_frame().T] * (1 if truncate else 2))

    # Shorten exactly one component per case (fresh frames each time) so each
    # length mismatch is exercised independently of the others.
    for short_name in ("configs", "scores", "metadata", "contexts"):
        with pytest.raises(AssertionError):
            Observations(
                configs=_rows(config, short_name == "configs"),
                scores=_rows(score, short_name == "scores"),
                contexts=_rows(context, short_name == "contexts"),
                metadata=_rows(metadata, short_name == "metadata"),
            )

230 

231 

def test_observations_append(observation_with_context: Observation) -> None:
    """Appending the same observation twice to an empty collection yields length 2."""
    collection = Observations()
    for _ in range(2):
        collection.append(observation_with_context)
    assert len(collection) == 2

238 

239 

def test_observations_append_fails(
    observation_with_context: Observation,
    observation_without_context: Observation,
) -> None:
    """Appending an observation with different optional components must fail.

    After an observation with metadata/context has been stored, appending one
    without them is expected to raise ``AssertionError``.
    """
    collection = Observations()
    collection.append(observation_with_context)
    with pytest.raises(AssertionError):
        collection.append(observation_without_context)

249 

250 

def test_observations_filter_by_index(observations_with_context: Observations) -> None:
    """Filtering by the index label of the first row keeps exactly one observation."""
    first_row_index = observations_with_context.configs.index[[0]]
    filtered = observations_with_context.filter_by_index(first_row_index)
    assert len(filtered) == 1

259 

260 

def test_observations_to_list(observations_with_context: Observations) -> None:
    """Iterating an ``Observations`` yields one ``Observation`` per stored row."""
    assert len(list(observations_with_context)) == 3
    for item in observations_with_context:
        assert isinstance(item, Observation)

265 

266 

def test_observations_equality_test(
    observations_with_context: Observations,
    observations_without_context: Observations,
) -> None:
    """Check ``==``/``!=`` on Observations: each equals itself, the two differ."""
    # pylint: disable=comparison-with-itself
    with_ctx, without_ctx = observations_with_context, observations_without_context
    assert with_ctx == with_ctx
    assert with_ctx != without_ctx
    assert without_ctx == without_ctx

276 

277 

def test_suggestion_equality_test(
    suggestion_with_context: Suggestion,
    suggestion_without_context: Suggestion,
) -> None:
    """Check ``==``/``!=`` on Suggestion: each equals itself, the two differ."""
    # pylint: disable=comparison-with-itself
    with_ctx, without_ctx = suggestion_with_context, suggestion_without_context
    assert with_ctx == with_ctx
    assert with_ctx != without_ctx
    assert without_ctx == without_ctx

287 

288 

def test_complete_suggestion(
    suggestion_with_context: Suggestion,
    score: pd.Series,
    observation_with_context: Observation,
) -> None:
    """Completing a suggestion with a score yields the matching Observation."""
    completed = suggestion_with_context.complete(score)
    assert completed == observation_with_context