Coverage for mlos_bench/mlos_bench/optimizers/track_best_optimizer.py: 97%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

"""Base optimizer class for mlos_bench that keeps track of the best score and configuration."""

6 

7import logging 

8from abc import ABCMeta 

9 

10from mlos_bench.environments.status import Status 

11from mlos_bench.optimizers.base_optimizer import Optimizer 

12from mlos_bench.services.base_service import Service 

13from mlos_bench.tunables.tunable import TunableValue 

14from mlos_bench.tunables.tunable_groups import TunableGroups 

15 

# Module-level logger, named after this module's import path.
_LOG = logging.getLogger(name=__name__)

17 

18 

class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
    """
    Base optimizer class that remembers the best observation registered so far.

    On top of ``Optimizer``, instances keep the best score dictionary seen among
    succeeded trials and a copy of the tunables that produced it. Both are
    exposed through ``get_best_observation``.
    """

    def __init__(
        self,
        tunables: TunableGroups,
        config: dict,
        global_config: dict | None = None,
        service: Service | None = None,
    ):
        """
        Create the optimizer with no best observation recorded yet.

        Parameters
        ----------
        tunables : TunableGroups
            Forwarded unchanged to ``Optimizer.__init__``.
        config : dict
            Forwarded unchanged to ``Optimizer.__init__``.
        global_config : dict | None
            Forwarded unchanged to ``Optimizer.__init__``.
        service : Service | None
            Forwarded unchanged to ``Optimizer.__init__``.
        """
        super().__init__(tunables, config, global_config, service)
        # No trial has been registered yet, so there is no incumbent.
        self._best_score: dict[str, float] | None = None
        self._best_config: TunableGroups | None = None

    def register(
        self,
        tunables: TunableGroups,
        status: Status,
        score: dict[str, TunableValue] | None = None,
    ) -> dict[str, float] | None:
        """
        Register an observation and update the incumbent if it is an improvement.

        The base-class ``register`` runs first. Its return value is what gets
        compared against the current best and, when it wins, stored as the new
        best score together with a copy of ``tunables``.

        Parameters
        ----------
        tunables : TunableGroups
            The configuration that was evaluated.
        status : Status
            Trial status; only trials where ``status.is_succeeded()`` is true
            can become the new best.
        score : dict[str, TunableValue] | None
            Scores for the trial, passed through to ``Optimizer.register``.

        Returns
        -------
        dict[str, float] | None
            Exactly what ``Optimizer.register`` returned for this observation.
        """
        scores = super().register(tunables, status, score)
        if not status.is_succeeded() or not self._is_better(scores):
            return scores
        # Keep our own copy (via ``tunables.copy()``) rather than a reference
        # to the caller's object.
        self._best_score = scores
        self._best_config = tunables.copy()
        return scores

    def _is_better(self, registered_score: dict[str, float] | None) -> bool:
        """
        Decide whether ``registered_score`` beats the current best score.

        Targets are visited in the iteration order of the stored best-score
        dict. The first target where one side is strictly lower decides, and
        the lower value wins. If no target decides (all tie), the incumbent is
        kept and ``False`` is returned. With no incumbent yet, any score counts
        as better.
        """
        incumbent = self._best_score
        if incumbent is None:
            return True
        assert registered_score is not None
        for target, best_value in incumbent.items():
            candidate = registered_score[target]
            if candidate < best_value:
                return True
            if candidate > best_value:
                return False
        # Every target tied (or was incomparable): keep the existing best.
        return False

    def get_best_observation(
        self,
    ) -> tuple[dict[str, float], TunableGroups] | tuple[None, None]:
        """
        Return the best score and configuration recorded so far.

        Returns
        -------
        tuple[dict[str, float], TunableGroups] | tuple[None, None]
            ``(None, None)`` if nothing has been recorded yet. Otherwise, a
            pair of:

            - the stored best score passed through
              ``self._get_scores(Status.SUCCEEDED, ...)``;
            - the stored best tunables object itself (not a fresh copy).
        """
        stored = self._best_score
        if stored is None:
            return (None, None)
        # NOTE(review): presumably _get_scores maps the internally stored score
        # back to the caller-facing form; confirm against Optimizer._get_scores.
        best_score = self._get_scores(Status.SUCCEEDED, stored)
        assert best_score is not None
        assert self._best_config is not None
        return (best_score, self._best_config)