Coverage for mlos_bench/mlos_bench/services/base_service.py: 94%
107 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Base class for the service mix-ins."""
7from __future__ import annotations
9import json
10import logging
11from collections.abc import Callable
12from contextlib import AbstractContextManager as ContextManager
13from types import TracebackType
14from typing import Any, Literal
16from mlos_bench.config.schemas import ConfigSchema
17from mlos_bench.services.types.bound_method import BoundMethod
18from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
19from mlos_bench.util import instantiate_from_config
# Module-level logger, named after this module so its output can be filtered per module.
_LOG: logging.Logger = logging.getLogger(name=__name__)
class Service(ContextManager):
    """
    An abstract base of all Environment Services and used to build up mix-ins.

    A Service keeps a dictionary of named methods (see `export`) that child
    Services can inherit through `register`. It also acts as a re-entrant context
    manager that enters/exits every Service instance whose bound methods were
    registered with it.
    """

    @classmethod
    def new(
        cls,
        class_name: str,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
    ) -> Service:
        """
        Factory method for a new service with a given config.

        Parameters
        ----------
        class_name: str
            FQN of a Python class to instantiate, e.g.,
            "mlos_bench.services.remote.azure.AzureVMService".
            Must be derived from the `Service` class.
        config : dict
            Free-format dictionary that contains the service configuration.
            It will be passed as a constructor parameter of the class
            specified by `class_name`.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            A parent service that can provide mixin functions.

        Returns
        -------
        svc : Service
            An instance of the `Service` class initialized with `config`.
        """
        assert issubclass(cls, Service)
        return instantiate_from_config(cls, class_name, config, global_config, parent)

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
        methods: dict[str, Callable] | list[Callable] | None = None,
    ):
        """
        Create a new service with a given config.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the service configuration.
            It is validated against the service JSON schema for this class
            (see `_validate_json_config`).
        global_config : dict
            Free-format dictionary of global parameters.
            In this base class it is only used for debug logging.
        parent : Service
            An optional parent service that can provide mixin functions.
            All of its exported methods are registered with this service, and
            it is used as the config loader if it supports config loading.
        methods : Union[dict[str, Callable], list[Callable], None]
            New methods to register with the service (registered after the
            parent's methods, so they override any with the same name).
        """
        self.config = config or {}
        self._validate_json_config(self.config)
        self._parent = parent
        self._service_methods: dict[str, Callable] = {}
        self._services: set[Service] = set()
        self._service_contexts: list[Service] = []
        self._in_context = False

        if parent:
            self.register(parent.export())
        if methods:
            self.register(methods)

        # Only assigned when the parent can load configs; otherwise it stays
        # unset here (subclasses may still assign it).
        self._config_loader_service: SupportsConfigLoading
        if parent and isinstance(parent, SupportsConfigLoading):
            self._config_loader_service = parent

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Service: %s Config:\n%s", self, json.dumps(self.config, indent=2))
            _LOG.debug("Service: %s Globals:\n%s", self, json.dumps(global_config or {}, indent=2))
            _LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None)

    @staticmethod
    def merge_methods(
        ext_methods: dict[str, Callable] | list[Callable] | None,
        local_methods: dict[str, Callable] | list[Callable],
    ) -> dict[str, Callable]:
        """
        Merge methods from the external caller with the local ones.

        This function is usually called by the derived class constructor just before
        invoking the constructor of the base class.

        Parameters
        ----------
        ext_methods : dict, list, or None
            Methods supplied by the caller; lists are keyed by each function's
            ``__name__``. These win over local methods with the same name.
        local_methods : dict or list
            The derived class's own methods. The argument is never mutated.

        Returns
        -------
        methods : dict
            A new name -> function dictionary.
        """
        if isinstance(local_methods, dict):
            local_methods = local_methods.copy()
        else:
            local_methods = {svc.__name__: svc for svc in local_methods}

        if not ext_methods:
            return local_methods

        if not isinstance(ext_methods, dict):
            ext_methods = {svc.__name__: svc for svc in ext_methods}

        local_methods.update(ext_methods)
        return local_methods

    def __enter__(self) -> Service:
        """
        Enter the Service mix-in context.

        Calls the _enter_context() method of all the Services registered under this one.
        If one of them raises, the Services already entered are exited again before
        the exception propagates, so that the Service can be entered again later.
        """
        if self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            # A Service may have no child Services at all, so check that every
            # registered Service was entered rather than that there are any.
            assert len(self._service_contexts) == len(self._services)
            assert all(svc._in_context for svc in self._services)
            return self
        entered: list[Service] = []
        try:
            for svc in self._services:
                entered.append(svc._enter_context())
        except BaseException as ex:
            # Roll back the Services already entered; otherwise they would stay
            # "in context" and trip their own assertions on the next attempt.
            for svc in reversed(entered):
                try:
                    svc._exit_context(type(ex), ex, ex.__traceback__)
                # pylint: disable=broad-exception-caught
                except Exception as rollback_ex:
                    _LOG.error(
                        "Exception while rolling back Service context '%s': %s",
                        svc,
                        rollback_ex,
                    )
            raise
        self._service_contexts = entered
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """
        Exit the Service mix-in context.

        Calls the _exit_context() method of all the Services registered under this one,
        in reverse order of entry. Every Service is given a chance to exit. If any of
        them raises, the error is logged, and the last such exception is re-raised
        after this Service's own context state has been reset.
        """
        if not self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            assert not self._service_contexts
            assert all(not svc._in_context for svc in self._services)
            return False
        ex_throw: Exception | None = None
        for svc in reversed(self._service_contexts):
            try:
                svc._exit_context(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting Service context '%s': %s", svc, ex)
                ex_throw = ex
        # Reset our own state *before* re-raising so a failed exit does not leave
        # this Service stuck in the "in context" state.
        self._service_contexts = []
        self._in_context = False
        if ex_throw is not None:
            raise ex_throw
        return False

    def _enter_context(self) -> Service:
        """
        Enters the context for this particular Service instance.

        Called by the base __enter__ method of the Service class so it can be used with
        mix-ins and overridden by subclasses.
        """
        assert not self._in_context
        self._in_context = True
        return self

    def _exit_context(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """
        Exits the context for this particular Service instance.

        Called by the base __exit__ method of the Service class so it can be used with
        mix-ins and overridden by subclasses.
        """
        # pylint: disable=unused-argument
        assert self._in_context
        self._in_context = False
        return False

    def _validate_json_config(self, config: dict) -> None:
        """Reconstructs a basic json config that this class might have been instantiated
        from in order to validate configs provided outside the file loading
        mechanism.
        """
        if self.__class__ == Service:
            # Skip over the case where instantiate a bare base Service class in
            # order to build up a mix-in.
            assert config == {}
            return
        json_config: dict = {
            "class": f"{self.__class__.__module__}.{self.__class__.__name__}",
        }
        if config:
            json_config["config"] = config
        ConfigSchema.SERVICE.validate(json_config)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}@{hex(id(self))}"

    def pprint(self) -> str:
        """Produce a human-readable string listing all public methods of the service."""
        return f"{self} ::\n" + "\n".join(
            f'    "{key}": {getattr(val, "__self__", "stand-alone")}'
            for (key, val) in self._service_methods.items()
        )

    @property
    def config_loader_service(self) -> SupportsConfigLoading:
        """
        Return a config loader service.

        Note: `__init__` only assigns it when the parent supports config loading;
        if neither that nor a subclass set it, accessing this raises AttributeError.

        Returns
        -------
        config_loader_service : SupportsConfigLoading
            A config loader service.
        """
        return self._config_loader_service

    def register(self, services: dict[str, Callable] | list[Callable]) -> None:
        """
        Register new mix-in services.

        Parameters
        ----------
        services : dict or list
            Either a dictionary of name -> function pairs, or a list of functions
            (registered under their ``__name__``). Entries override previously
            registered methods with the same name.
        """
        # All service loading must happen prior to entering a context.
        # Check this before mutating any state.
        assert not self._in_context
        assert not self._service_contexts

        if not isinstance(services, dict):
            services = {svc.__name__: svc for svc in services}

        self._service_methods.update(services)
        # Expose every registered method as an instance attribute.
        self.__dict__.update(self._service_methods)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Added methods to: %s", self.pprint())

        # In order to get a list of all child contexts, we need to look at only
        # the bound methods that were not overridden by another mixin.
        # Then we inspect the internally bound __self__ variable to discover
        # which Service instance that method belongs to.
        self._services = {
            # Unfortunately, by creating a set, we may destroy the ability to
            # preserve the context enter/exit order, but hopefully it doesn't
            # matter.
            svc_method.__self__
            for svc_method in self._service_methods.values()
            # Note: some methods are actually stand alone functions, so we need
            # to filter them out.
            if isinstance(svc_method, BoundMethod) and isinstance(svc_method.__self__, Service)
        }

    def export(self) -> dict[str, Callable]:
        """
        Return a dictionary of functions available in this service.

        Note: this is the service's internal dictionary, not a copy.

        Returns
        -------
        services : dict
            A dictionary of string -> function pairs.
        """
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Export methods from: %s", self.pprint())

        return self._service_methods