Coverage for mlos_bench/mlos_bench/optimizers/mock_optimizer.py: 97%
35 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Mock optimizer for mlos_bench."""
7import logging
8import random
9from typing import Callable, Dict, Optional, Sequence
11from mlos_bench.environments.status import Status
12from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
13from mlos_bench.services.base_service import Service
14from mlos_bench.tunables.tunable import Tunable, TunableValue
15from mlos_bench.tunables.tunable_groups import TunableGroups
# Module-level logger, named after this module, shared by the mock optimizer.
_LOG: logging.Logger = logging.getLogger(name=__name__)
class MockOptimizer(TrackBestOptimizer):
    """
    Mock optimizer to test the Environment API.

    Suggestions are random draws from each tunable's domain, produced by a
    single ``random.Random`` instance seeded with ``self.seed`` (so a fixed
    seed yields a repeatable sequence of draws).
    """

    def __init__(
        self,
        tunables: TunableGroups,
        config: dict,
        global_config: Optional[dict] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new mock optimizer.

        Parameters
        ----------
        tunables : TunableGroups
            The tunables to optimize.
        config : dict
            Optimizer configuration, forwarded to the base class.
        global_config : Optional[dict]
            Global configuration parameters, forwarded to the base class.
        service : Optional[Service]
            Parent service, forwarded to the base class.
        """
        super().__init__(tunables, config, global_config, service)
        rng = random.Random(self.seed)

        # One sampler per tunable type; all of them share the seeded generator.
        def _sample_categorical(tunable: Tunable) -> TunableValue:
            return rng.choice(tunable.categories)

        def _sample_float(tunable: Tunable) -> TunableValue:
            return rng.uniform(*tunable.range)

        def _sample_int(tunable: Tunable) -> TunableValue:
            return rng.randint(*tunable.range)

        # Maps a tunable's ``type`` string to the callable that draws a value for it.
        self._random: Dict[str, Callable[[Tunable], TunableValue]] = {
            "categorical": _sample_categorical,
            "float": _sample_float,
            "int": _sample_int,
        }

    def bulk_register(
        self,
        configs: Sequence[dict],
        scores: Sequence[Optional[Dict[str, TunableValue]]],
        status: Optional[Sequence[Status]] = None,
    ) -> bool:
        """
        Register a batch of previously observed configurations.

        Parameters
        ----------
        configs : Sequence[dict]
            Tunable parameter values, one dict per trial.
        scores : Sequence[Optional[Dict[str, TunableValue]]]
            Score of each trial; ``None`` entries are allowed.
        status : Optional[Sequence[Status]]
            Status of each trial. When omitted, every trial is treated as
            ``Status.SUCCEEDED``.

        Returns
        -------
        is_not_empty : bool
            False if the base-class ``bulk_register`` returned a falsy value
            (nothing is registered here in that case); True otherwise.
        """
        if not super().bulk_register(configs, scores, status):
            return False
        statuses = [Status.SUCCEEDED] * len(configs) if status is None else status
        # Entries are paired positionally via zip(), so any surplus items in a
        # longer sequence are not registered by this loop.
        for config, trial_score, trial_status in zip(configs, scores, statuses):
            trial_tunables = self._tunables.copy().assign(config)
            self.register(trial_tunables, trial_status, trial_score)
        # Only query the best observation when the debug message will be emitted.
        if _LOG.isEnabledFor(logging.DEBUG):
            best_score, _best_config = self.get_best_observation()
            _LOG.debug("Bulk register END: %s = %s", self, best_score)
        return True

    def suggest(self) -> TunableGroups:
        """
        Generate the next (random) suggestion.

        While ``self._start_with_defaults`` is set, the values returned by the
        base-class ``suggest()`` are left untouched (logged as using the
        defaults) and the flag is cleared. Otherwise, every tunable's value is
        replaced by a draw from the sampler registered for its type.
        """
        tunables = super().suggest()
        if not self._start_with_defaults:
            for tunable, _group in tunables:
                sampler = self._random[tunable.type]
                tunable.value = sampler(tunable)
        else:
            _LOG.info("Use default tunable values")
            self._start_with_defaults = False
        _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
        return tunables