Coverage for mlos_bench/mlos_bench/environments/mock_env.py: 95%
40 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Scheduler-side environment to mock the benchmark results."""
7import logging
8import random
9from datetime import datetime
10from typing import Dict, Optional, Tuple
12import numpy
14from mlos_bench.environments.base_environment import Environment
15from mlos_bench.environments.status import Status
16from mlos_bench.services.base_service import Service
17from mlos_bench.tunables import Tunable, TunableGroups, TunableValue
19_LOG = logging.getLogger(__name__)
class MockEnv(Environment):
    """Scheduler-side environment to mock the benchmark results."""

    _NOISE_VAR = 0.2
    """Variance of the Gaussian noise added to the benchmark value."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new environment that produces mock benchmark data.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment configuration.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
            Optional arguments are `mock_env_seed`, `mock_env_range`, and `mock_env_metrics`.
            Set `mock_env_seed` to -1 for deterministic behavior, 0 for default randomness.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service: Service
            An optional service object. Not used by this class.
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )
        # seed < 0 disables noise entirely (deterministic scores);
        # seed == 0 uses OS-level default randomness; seed > 0 seeds the RNG.
        seed = int(self.config.get("mock_env_seed", -1))
        self._random = random.Random(seed or None) if seed >= 0 else None
        # Optional (low, high) pair to rescale the [0, 1] score into.
        self._range = self.config.get("mock_env_range")
        # Names of the output metrics; all receive the same mock score.
        self._metrics = self.config.get("mock_env_metrics", ["score"])
        self._is_ready = True

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Produce mock benchmark data for one experiment.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            The keys of the `output` dict are the names of the metrics
            specified in the config; by default it's just one metric
            named "score". All output metrics have the same value.
        """
        (status, timestamp, _) = result = super().run()
        if not status.is_ready():
            return result

        # Simple convex function of all tunable parameters.
        score = numpy.mean(
            numpy.square([self._normalized(tunable) for (tunable, _group) in self._tunable_params])
        )

        # Add noise and shift the benchmark value from [0, 1] to a given range.
        noise = self._random.gauss(0, self._NOISE_VAR) if self._random else 0
        score = numpy.clip(score + noise, 0, 1)
        if self._range:
            score = self._range[0] + score * (self._range[1] - self._range[0])

        return (Status.SUCCEEDED, timestamp, {metric: score for metric in self._metrics})

    @staticmethod
    def _normalized(tunable: Tunable) -> float:
        """
        Get the NORMALIZED value of a tunable.

        That is, map current value to the [0, 1] range.

        Raises
        ------
        ValueError
            If the tunable is neither categorical nor numerical.
        """
        val = None
        if tunable.is_categorical:
            num_categories = len(tunable.categories)
            if num_categories <= 1:
                # A single-category tunable has no spread; avoid division by zero.
                val = 0.0
            else:
                val = tunable.categories.index(tunable.category) / float(num_categories - 1)
        elif tunable.is_numerical:
            span = tunable.range[1] - tunable.range[0]
            if span == 0:
                # Degenerate range (min == max); avoid division by zero.
                val = 0.0
            else:
                val = (tunable.numerical_value - tunable.range[0]) / float(span)
        else:
            raise ValueError("Invalid parameter type: " + tunable.type)
        # Explicitly clip the value in case of numerical errors.
        ret: float = numpy.clip(val, 0, 1)
        return ret