Coverage for mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py: 100%

20 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

"""Unit tests for the mlos_bench optimizer using the LlamaTune space adapter with SMAC."""

6 

7import pytest 

8 

9from mlos_bench.environments.status import Status 

10from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer 

11from mlos_bench.tests import SEED 

12from mlos_bench.tunables.tunable_groups import TunableGroups 

13 

14# pylint: disable=redefined-outer-name 

15 

16 

@pytest.fixture
def llamatune_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer:
    """Test fixture for an mlos_core SMAC optimizer wrapped in a LlamaTune space adapter.

    Parameters
    ----------
    tunable_groups : TunableGroups
        The tunables to optimize (provided by another fixture).

    Returns
    -------
    MlosCoreOptimizer
        An optimizer that minimizes the "score" metric, seeded with ``SEED``.
    """
    return MlosCoreOptimizer(
        tunables=tunable_groups,
        service=None,
        config={
            "space_adapter_type": "LLAMATUNE",
            "space_adapter_config": {
                # Presumably the size of the low-dimensional search space that
                # LlamaTune projects the tunables into -- confirm in mlos_core.
                "num_low_dims": 2,
            },
            # The test expects the lowest registered score to be reported as best.
            "optimization_targets": {"score": "min"},
            # Kept above the number of mock scores so that not_converged() presumably
            # stays True for every iteration of the test loop.
            "max_suggestions": 10,
            "optimizer_type": "SMAC",
            "seed": SEED,
        },
    )

35 

36 

@pytest.fixture
def mock_scores() -> list:
    """Provide fake benchmark scores to feed into the optimizer under test.

    The lowest value (66.66) is the one a minimizing optimizer should report as best.
    """
    fake_scores = [88.88, 66.66, 99.99]
    return fake_scores

41 

42 

def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list) -> None:
    """Make sure that the LlamaTune+SMAC optimizer initializes and works correctly.

    Each mock score is registered as the result of a successful trial. The best
    observation the optimizer then reports must match the lowest score, because
    the fixture minimizes "score".
    """
    for score in mock_scores:
        assert llamatune_opt.not_converged()
        tunables = llamatune_opt.suggest()
        # FIXME: The suggested tunable values are not checked here. The optimizer's
        # suggestions may not be reproducible even with a fixed seed; confirm this
        # before pinning exact values.
        llamatune_opt.register(tunables, Status.SUCCEEDED, {"score": score})

    (best_score, best_tunables) = llamatune_opt.get_best_observation()
    # Separate asserts, so a failure reports which condition broke.
    assert best_score is not None
    assert len(best_score) == 1
    assert isinstance(best_tunables, TunableGroups)
    # "score" is minimized, so the best observation is the smallest mock score.
    assert best_score["score"] == pytest.approx(min(mock_scores), rel=0.01)

55 

56 

if __name__ == "__main__":
    # Convenience entry point: run only this test, verbosely (e.g. when debugging).
    pytest_args = ["-vv", "-n1", "-k", "test_llamatune_optimizer", __file__]
    pytest.main(pytest_args)