Coverage for mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py: 100%
20 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Unit tests for mock mlos_bench optimizer."""
7import pytest
9from mlos_bench.environments.status import Status
10from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer
11from mlos_bench.tests import SEED
12from mlos_bench.tunables.tunable_groups import TunableGroups
14# pylint: disable=redefined-outer-name
@pytest.fixture
def llamatune_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer:
    """
    Test fixture for an mlos_core SMAC optimizer with a LlamaTune space adapter.

    The optimizer searches over ``tunable_groups`` through the "LLAMATUNE"
    space adapter (configured with ``num_low_dims=2``), minimizes the
    ``score`` metric, and is limited to 10 suggestions. ``SEED`` is passed
    in to make runs as reproducible as the backend allows.
    """
    return MlosCoreOptimizer(
        tunables=tunable_groups,
        service=None,
        config={
            "space_adapter_type": "LLAMATUNE",
            "space_adapter_config": {
                "num_low_dims": 2,
            },
            "optimization_targets": {"score": "min"},
            "max_suggestions": 10,
            "optimizer_type": "SMAC",
            "seed": SEED,
            # NOTE: "start_with_defaults" is intentionally not set here,
            # so the optimizer's own default for it applies.
        },
    )
@pytest.fixture
def mock_scores() -> list:
    """Provide a fixed list of fake benchmark scores to feed the optimizer under test."""
    fake_scores = [88.88, 66.66, 99.99]
    return fake_scores
def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list) -> None:
    """
    Make sure that the LlamaTune+SMAC optimizer initializes and works correctly.

    Each of ``mock_scores`` is registered as the result of a successful trial
    for a freshly suggested configuration. Because the fixture minimizes
    ``score``, the best observation must then be the smallest mock score.
    """
    for score in mock_scores:
        assert llamatune_opt.not_converged()
        tunables = llamatune_opt.suggest()
        # NOTE: The suggested tunables are deliberately not checked here, since
        # the optimizer's suggestions are not assumed to be deterministic.
        llamatune_opt.register(tunables, Status.SUCCEEDED, {"score": score})

    (best_score, best_tunables) = llamatune_opt.get_best_observation()
    # Separate asserts so a failure points at the exact broken expectation.
    assert best_score is not None
    assert len(best_score) == 1
    assert isinstance(best_tunables, TunableGroups)
    # The objective is "min", so the best observed score is the minimum mock score.
    assert best_score["score"] == pytest.approx(min(mock_scores), rel=0.01)
if __name__ == "__main__":
    # Run just this test verbosely on a single worker, e.g. to attach a debugger.
    PYTEST_ARGS = ["-vv", "-n1", "-k", "test_llamatune_optimizer", __file__]
    pytest.main(PYTEST_ARGS)