Coverage for mlos_bench/mlos_bench/services/base_service.py: 94%
104 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base class for the service mix-ins.

Defines :py:class:`Service`, which lets Services register the methods of other
Services with themselves (mix-ins) and enters/exits the contexts of all the
Service instances whose bound methods were registered with it.
"""
7import json
8import logging
9from types import TracebackType
10from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
12from typing_extensions import Literal
14from mlos_bench.config.schemas import ConfigSchema
15from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
16from mlos_bench.util import instantiate_from_config
# Module-scoped logger, named after this module.
_LOG: logging.Logger = logging.getLogger(name=__name__)
class Service:
    """
    An abstract base of all Environment Services and used to build up mix-ins.

    Services expose their public methods via :py:meth:`export` and pick up the
    methods of other Services via :py:meth:`register`, so that one Service
    instance can act as a mix-in of several others. Entering/exiting a Service
    context (``with svc: ...``) enters/exits the contexts of all the distinct
    Service instances whose bound methods were registered with it.
    """

    @classmethod
    def new(
        cls,
        class_name: str,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional["Service"] = None,
    ) -> "Service":
        """
        Factory method for a new service with a given config.

        Parameters
        ----------
        class_name : str
            FQN of a Python class to instantiate, e.g.,
            "mlos_bench.services.remote.azure.AzureVMService".
            Must be derived from the `Service` class.
        config : dict
            Free-format dictionary that contains the service configuration.
            It will be passed as a constructor parameter of the class
            specified by `class_name`.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            A parent service that can provide mixin functions.

        Returns
        -------
        svc : Service
            An instance of the `Service` class initialized with `config`.
        """
        assert issubclass(cls, Service)
        return instantiate_from_config(cls, class_name, config, global_config, parent)

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional["Service"] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new service with a given config.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the service configuration.
            It is validated against the Service JSON schema (skipped for a bare
            `Service` instance, which requires an empty config) and stored as
            `self.config`.
        global_config : dict
            Free-format dictionary of global parameters.
            The base class only uses it for debug logging.
        parent : Service
            An optional parent service that can provide mixin functions.
            All of its exported methods are registered with this instance, and
            if it supports config loading it becomes this instance's
            `config_loader_service`.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
            They are registered after the parent's methods, so they override
            any parent method with the same name.
        """
        self.config = config or {}
        self._validate_json_config(self.config)
        self._parent = parent
        self._service_methods: Dict[str, Callable] = {}
        self._services: Set[Service] = set()
        self._service_contexts: List[Service] = []
        self._in_context = False

        if parent:
            self.register(parent.export())
        if methods:
            self.register(methods)

        # Only assigned here when the parent can load configs; otherwise (unless
        # a subclass assigns it) reading `config_loader_service` raises
        # AttributeError.
        self._config_loader_service: SupportsConfigLoading
        if parent and isinstance(parent, SupportsConfigLoading):
            self._config_loader_service = parent

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Service: %s Config:\n%s", self, json.dumps(self.config, indent=2))
            _LOG.debug("Service: %s Globals:\n%s", self, json.dumps(global_config or {}, indent=2))
            _LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None)

    @staticmethod
    def merge_methods(
        ext_methods: Union[Dict[str, Callable], List[Callable], None],
        local_methods: Union[Dict[str, Callable], List[Callable]],
    ) -> Dict[str, Callable]:
        """
        Merge methods from the external caller with the local ones.

        This function is usually called by the derived class constructor just before
        invoking the constructor of the base class.

        Parameters
        ----------
        ext_methods : dict, list, or None
            Methods supplied by the caller. Lists are keyed by each function's
            ``__name__``. These win over `local_methods` on name collisions.
        local_methods : dict or list
            The derived class's own methods. A dict argument is copied, not
            modified.

        Returns
        -------
        methods : dict
            A new dictionary of name -> function pairs.
        """
        if isinstance(local_methods, dict):
            local_methods = local_methods.copy()
        else:
            local_methods = {svc.__name__: svc for svc in local_methods}

        if not ext_methods:
            return local_methods

        if not isinstance(ext_methods, dict):
            ext_methods = {svc.__name__: svc for svc in ext_methods}

        local_methods.update(ext_methods)
        return local_methods

    def __enter__(self) -> "Service":
        """
        Enter the Service mix-in context.

        Calls the _enter_context() method of all the Services registered under this one.
        Entering an already-entered Service is a no-op. If any child Service fails
        to enter, the children entered so far are exited again (in reverse order)
        and the original exception is re-raised.
        """
        if self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            # NOTE: `_service_contexts` is legitimately empty when no child
            # Services are bound (e.g., only stand-alone functions were
            # registered), so compare sizes rather than testing truthiness.
            assert len(self._service_contexts) == len(self._services)
            assert all(svc._in_context for svc in self._services)
            return self
        entered: List[Service] = []
        try:
            for svc in self._services:
                entered.append(svc._enter_context())
        except BaseException as ex:
            # Roll back the children already entered; otherwise they would stay
            # flagged as in-context and any later __enter__ would fail.
            # Errors during rollback are logged by the helper; the original
            # exception is the one propagated.
            self._exit_service_contexts(entered, type(ex), ex, ex.__traceback__)
            raise
        self._service_contexts = entered
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exit the Service mix-in context.

        Calls the _exit_context() method of all the Services registered under this one,
        in the reverse of the order they were entered. Exiting a Service that is not
        in a context is a no-op. If any child raises while exiting, the error is
        logged, the remaining children are still exited, this Service's context
        state is reset, and then the last such exception is re-raised.
        Never suppresses an exception raised inside the ``with`` body.
        """
        if not self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            assert not self._service_contexts
            assert all(not svc._in_context for svc in self._services)
            return False
        ex_throw = self._exit_service_contexts(self._service_contexts, ex_type, ex_val, ex_tb)
        # Reset our own state *before* re-raising so a failing child does not
        # leave this Service permanently stuck "in context".
        self._service_contexts = []
        self._in_context = False
        if ex_throw is not None:
            raise ex_throw
        return False

    @staticmethod
    def _exit_service_contexts(
        services: List["Service"],
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Optional[Exception]:
        """
        Call _exit_context() on each of `services` in reverse order.

        Exceptions raised by individual Services are logged and do not stop the
        remaining Services from being exited.

        Returns
        -------
        ex_throw : Optional[Exception]
            The last exception raised by any Service, or None if all exited cleanly.
        """
        ex_throw: Optional[Exception] = None
        for svc in reversed(services):
            try:
                svc._exit_context(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting Service context '%s': %s", svc, ex)
                ex_throw = ex
        return ex_throw

    def _enter_context(self) -> "Service":
        """
        Enters the context for this particular Service instance.

        Called by the base __enter__ method of the Service class so it can be used with
        mix-ins and overridden by subclasses.
        """
        assert not self._in_context
        self._in_context = True
        return self

    def _exit_context(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exits the context for this particular Service instance.

        Called by the base __exit__ method of the Service class so it can be used with
        mix-ins and overridden by subclasses.
        """
        # pylint: disable=unused-argument
        assert self._in_context
        self._in_context = False
        return False

    def _validate_json_config(self, config: dict) -> None:
        """
        Reconstructs a basic json config that this class might have been
        instantiated from in order to validate configs provided outside the file
        loading mechanism.
        """
        if self.__class__ is Service:
            # Skip over the case where instantiate a bare base Service class in
            # order to build up a mix-in.
            assert config == {}
            return
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config
        ConfigSchema.SERVICE.validate(json_config)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}@{hex(id(self))}"

    def pprint(self) -> str:
        """Produce a human-readable string listing all public methods of the service."""
        return f"{self} ::\n" + "\n".join(
            f'  "{key}": {getattr(val, "__self__", "stand-alone")}'
            for (key, val) in self._service_methods.items()
        )

    @property
    def config_loader_service(self) -> SupportsConfigLoading:
        """
        Return a config loader service.

        Returns
        -------
        config_loader_service : SupportsConfigLoading
            A config loader service.
        """
        return self._config_loader_service

    def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None:
        """
        Register new mix-in services.

        Parameters
        ----------
        services : dict or list
            A dictionary of string -> function pairs, or a list of functions
            (each registered under its ``__name__``). Entries override any
            previously registered method with the same name.
        """
        # All service loading must happen prior to entering a context.
        # Check this before mutating any state.
        assert not self._in_context
        assert not self._service_contexts

        if not isinstance(services, dict):
            services = {svc.__name__: svc for svc in services}

        self._service_methods.update(services)
        self.__dict__.update(self._service_methods)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Added methods to: %s", self.pprint())

        # In order to get a list of all child contexts, we need to look at only
        # the bound methods that were not overridden by another mixin.
        # Then we inspect the internally bound __self__ variable to discover
        # which Service instance each method belongs to.
        self._services = {
            # Unfortunately, by creating a set, we may destroy the ability to
            # preserve the context enter/exit order, but hopefully it doesn't
            # matter.
            svc_method.__self__
            for svc_method in self._service_methods.values()
            # Note: some methods are actually stand alone functions, so we need
            # to filter them out.
            if hasattr(svc_method, "__self__") and isinstance(svc_method.__self__, Service)
        }

    def export(self) -> Dict[str, Callable]:
        """
        Return a dictionary of functions available in this service.

        Returns
        -------
        services : dict
            A dictionary of string -> function pairs.
            This is the Service's internal registry itself, not a copy.
        """
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Export methods from: %s", self.pprint())

        return self._service_methods