Coverage for mlos_bench/mlos_bench/optimizers/track_best_optimizer.py: 97%
37 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Base Optimizer class for mlos_bench that keeps track of the best score and configuration."""
7import logging
8from abc import ABCMeta
10from mlos_bench.environments.status import Status
11from mlos_bench.optimizers.base_optimizer import Optimizer
12from mlos_bench.services.base_service import Service
13from mlos_bench.tunables.tunable import TunableValue
14from mlos_bench.tunables.tunable_groups import TunableGroups
# Logger for this module, named after the module itself.
_LOG = logging.getLogger(name=__name__)
class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
    """Base Optimizer class that keeps track of the best score and configuration."""

    def __init__(
        self,
        tunables: TunableGroups,
        config: dict,
        global_config: dict | None = None,
        service: Service | None = None,
    ):
        super().__init__(tunables, config, global_config, service)
        # Copy of the tunables of the best successful trial so far; None until
        # the first successful registration.
        self._best_config: TunableGroups | None = None
        # Scores of that trial exactly as returned by the base class `register`.
        # `_is_better` compares them with lower-is-better semantics.
        # NOTE(review): presumably already sign-adjusted for maximization targets
        # by `Optimizer._get_scores` -- confirm in the base class.
        self._best_score: dict[str, float] | None = None

    def register(
        self,
        tunables: TunableGroups,
        status: Status,
        score: dict[str, TunableValue] | None = None,
    ) -> dict[str, float] | None:
        """
        Register the results of a trial and remember it if it is the best so far.

        Parameters
        ----------
        tunables : TunableGroups
            The configuration that was benchmarked.
        status : Status
            Final status of the trial. Only succeeded trials can become the best.
        score : dict[str, TunableValue] | None
            Benchmark results of the trial, passed through to the base class.

        Returns
        -------
        dict[str, float] | None
            The scores as returned by the base class `register`.
        """
        registered_score = super().register(tunables, status, score)
        if status.is_succeeded() and self._is_better(registered_score):
            self._best_score = registered_score
            # Store a copy so that later changes to `tunables` do not alter
            # the recorded best configuration.
            self._best_config = tunables.copy()
        return registered_score

    def _is_better(self, registered_score: dict[str, float] | None) -> bool:
        """
        Compare the optimization scores to the best ones so far lexicographically.

        Targets are compared in the iteration order of `self._best_score`, and
        lower values are better. NaN is treated as worse than any number:
        - a real score replaces a NaN best score;
        - a NaN score never replaces a real one;
        - two NaNs tie on that target.
        Plain `<` and `>` are both False for NaN, which would otherwise make a
        NaN best score impossible to beat.

        Returns
        -------
        bool
            True if there is no best score yet, or if `registered_score` is
            strictly better than it. False on a full tie.
        """
        import math  # local import: keeps the module's import block untouched

        if self._best_score is None:
            return True
        assert registered_score is not None
        for opt_target, best_score in self._best_score.items():
            score = registered_score[opt_target]
            score_is_nan = math.isnan(score)
            best_is_nan = math.isnan(best_score)
            if score_is_nan or best_is_nan:
                if score_is_nan and best_is_nan:
                    continue  # tie on this target; decide on the next one
                # Exactly one side is NaN: the real number wins.
                return best_is_nan
            if score < best_score:
                return True
            if score > best_score:
                return False
        return False

    def get_best_observation(
        self,
    ) -> tuple[dict[str, float], TunableGroups] | tuple[None, None]:
        """
        Get the best score and configuration registered so far.

        Returns
        -------
        (score, config) : tuple[dict[str, float], TunableGroups] | tuple[None, None]
            The best scores (passed through `_get_scores`) and the stored copy
            of the matching tunables. Returns `(None, None)` if no trial has
            succeeded yet.
        """
        if self._best_score is None:
            return (None, None)
        # NOTE(review): presumably maps the stored scores back to the
        # user-facing orientation -- see `Optimizer._get_scores`.
        score = self._get_scores(Status.SUCCEEDED, self._best_score)
        assert score is not None
        assert self._best_config is not None
        return (score, self._best_config)