Coverage for mlos_bench/mlos_bench/optimizers/track_best_optimizer.py: 97%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Optimizer base class for mlos_bench that keeps track of the best score and configuration seen so far."""

6 

7import logging 

8from abc import ABCMeta 

9from typing import Dict, Optional, Tuple, Union 

10 

11from mlos_bench.environments.status import Status 

12from mlos_bench.optimizers.base_optimizer import Optimizer 

13from mlos_bench.services.base_service import Service 

14from mlos_bench.tunables.tunable import TunableValue 

15from mlos_bench.tunables.tunable_groups import TunableGroups 

16 

# Module-level logger, named after this module (standard `__name__` convention).
_LOG = logging.getLogger(__name__)

18 

19 

class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
    """
    Abstract Optimizer base that remembers the best observation registered so far.

    On every successful ``register`` call, the scores returned by the base class
    ``register`` are compared against the incumbent (lower wins, targets compared
    one after another). On improvement, those scores and a copy of the tunables
    become the new best, which ``get_best_observation`` later reports.
    """

    def __init__(
        self,
        tunables: TunableGroups,
        config: dict,
        global_config: Optional[dict] = None,
        service: Optional[Service] = None,
    ):
        super().__init__(tunables, config, global_config, service)
        # No successful trial has been seen yet; `register` fills both fields
        # in together once one arrives.
        self._best_score: Optional[Dict[str, float]] = None
        self._best_config: Optional[TunableGroups] = None

    def register(
        self,
        tunables: TunableGroups,
        status: Status,
        score: Optional[Dict[str, TunableValue]] = None,
    ) -> Optional[Dict[str, float]]:
        """
        Register a trial result and update the best observation if it improves on it.

        Parameters
        ----------
        tunables : TunableGroups
            The configuration that was evaluated.
        status : Status
            Final status of the trial. Only succeeded trials can become the best.
        score : Optional[Dict[str, TunableValue]]
            Metric values reported for the trial (may be None).

        Returns
        -------
        Optional[Dict[str, float]]
            The value returned by the base class ``register``, passed through unchanged.
        """
        result = super().register(tunables, status, score)
        if not status.is_succeeded():
            return result
        if self._is_better(result):
            self._best_score = result
            # Keep a copy (via `tunables.copy()`) rather than the caller's object,
            # presumably so later mutation of `tunables` cannot alter the recorded best.
            self._best_config = tunables.copy()
        return result

    def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool:
        """
        Decide whether `registered_score` beats the best scores recorded so far.

        Targets are visited in the iteration order of the stored best scores, and
        the first target whose values are strictly ordered decides the outcome
        (lower is better), i.e. a lexicographic comparison. A tie on every target
        is *not* an improvement, so the earliest best observation is kept. When
        nothing has been recorded yet, any score counts as better.
        """
        incumbent = self._best_score
        if incumbent is None:
            return True
        assert registered_score is not None
        for target in incumbent:
            challenger = registered_score[target]
            current_best = incumbent[target]
            # Test `<` and `>` separately (not `!=`) so that values that are
            # unordered relative to each other (e.g. NaN floats) fall through
            # to the next target, exactly like a tie.
            if challenger < current_best:
                return True
            if challenger > current_best:
                return False
        return False

    def get_best_observation(
        self,
    ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
        """
        Return the best observation recorded so far.

        Returns
        -------
        Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]
            ``(scores, config)`` for the best succeeded trial, or ``(None, None)``
            if no succeeded trial has been registered yet.

            - ``scores`` is the stored best score passed back through
              ``_get_scores`` with ``Status.SUCCEEDED``. Presumably this maps it
              back to the user-facing scale; confirm in the base class.
            - ``config`` is the stored ``TunableGroups`` object itself, not a
              fresh copy.
        """
        best_score = self._best_score
        if best_score is None:
            return (None, None)
        converted = self._get_scores(Status.SUCCEEDED, best_score)
        assert converted is not None
        best_config = self._best_config
        assert best_config is not None
        return converted, best_config