Coverage for mlos_bench/mlos_bench/services/base_service.py: 94%

107 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Base class for the service mix-ins.""" 

6 

7from __future__ import annotations 

8 

9import json 

10import logging 

11from collections.abc import Callable 

12from contextlib import AbstractContextManager as ContextManager 

13from types import TracebackType 

14from typing import Any, Literal 

15 

16from mlos_bench.config.schemas import ConfigSchema 

17from mlos_bench.services.types.bound_method import BoundMethod 

18from mlos_bench.services.types.config_loader_type import SupportsConfigLoading 

19from mlos_bench.util import instantiate_from_config 

20 

# Module-level logger, named after this module; used by Service for debug and error output.
_LOG = logging.getLogger(__name__)

22 

23 

class Service(ContextManager):
    """
    An abstract base of all Environment Services and used to build up mix-ins.

    A Service keeps a registry of named callables (``_service_methods``) that are
    also exposed as attributes of the instance. It is a re-entrant context manager:
    entering it enters the context of every Service whose bound methods are
    registered with it, and exiting it exits those contexts in reverse order.
    """

    @classmethod
    def new(
        cls,
        class_name: str,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
    ) -> Service:
        """
        Factory method for a new service with a given config.

        Parameters
        ----------
        class_name: str
            FQN of a Python class to instantiate, e.g.,
            "mlos_bench.services.remote.azure.AzureVMService".
            Must be derived from the `Service` class.
        config : dict
            Free-format dictionary that contains the service configuration.
            It will be passed as a constructor parameter of the class
            specified by `class_name`.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            A parent service that can provide mixin functions.

        Returns
        -------
        svc : Service
            An instance of the `Service` class initialized with `config`.
        """
        assert issubclass(cls, Service)
        return instantiate_from_config(cls, class_name, config, global_config, parent)

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
        methods: dict[str, Callable] | list[Callable] | None = None,
    ):
        """
        Create a new service with a given config.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the service configuration.
            An empty dictionary is used when it is omitted.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            An optional parent service that can provide mixin functions.
        methods : Union[dict[str, Callable], list[Callable], None]
            New methods to register with the service.
        """
        self.config = config or {}
        self._validate_json_config(self.config)
        self._parent = parent
        self._service_methods: dict[str, Callable] = {}
        self._services: set[Service] = set()
        self._service_contexts: list[Service] = []
        self._in_context = False

        # Register the parent's methods first so that the methods supplied by
        # the caller take precedence on a name clash.
        if parent:
            self.register(parent.export())
        if methods:
            self.register(methods)

        # Declared here, but only assigned when the parent can load configs.
        self._config_loader_service: SupportsConfigLoading
        if parent and isinstance(parent, SupportsConfigLoading):
            self._config_loader_service = parent

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Service: %s Config:\n%s", self, json.dumps(self.config, indent=2))
            _LOG.debug("Service: %s Globals:\n%s", self, json.dumps(global_config or {}, indent=2))
            _LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None)

    @staticmethod
    def merge_methods(
        ext_methods: dict[str, Callable] | list[Callable] | None,
        local_methods: dict[str, Callable] | list[Callable],
    ) -> dict[str, Callable]:
        """
        Merge methods from the external caller with the local ones.

        This function is usually called by the derived class constructor just before
        invoking the constructor of the base class.

        Parameters
        ----------
        ext_methods : Union[dict[str, Callable], list[Callable], None]
            Methods supplied by the external caller. Callables given as a list
            are keyed by their `__name__`.
        local_methods : Union[dict[str, Callable], list[Callable]]
            Methods of the derived class. Callables given as a list are keyed
            by their `__name__`.

        Returns
        -------
        methods : dict[str, Callable]
            A new dictionary with both sets of methods; on a name clash the
            external method replaces the local one.
        """
        # Always work on a fresh dict so that the caller's argument is not mutated.
        if isinstance(local_methods, dict):
            local_methods = local_methods.copy()
        else:
            local_methods = {svc.__name__: svc for svc in local_methods}

        if not ext_methods:
            return local_methods

        if not isinstance(ext_methods, dict):
            ext_methods = {svc.__name__: svc for svc in ext_methods}

        local_methods.update(ext_methods)
        return local_methods

    def __enter__(self) -> Service:
        """
        Enter the Service mix-in context.

        Calls the _enter_context() method of all the Services registered under this one.
        """
        if self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            assert self._service_contexts
            assert all(svc._in_context for svc in self._services)
            return self
        self._service_contexts = [svc._enter_context() for svc in self._services]
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """
        Exit the Service mix-in context.

        Calls the _exit_context() method of all the Services registered under this one,
        in the reverse of the order in which they were entered. If any of them raises,
        the remaining ones are still exited and the last exception caught is re-raised
        once the state of this Service has been reset.
        """
        if not self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            assert not self._service_contexts
            assert all(not svc._in_context for svc in self._services)
            return False
        ex_throw = None
        try:
            for svc in reversed(self._service_contexts):
                try:
                    svc._exit_context(ex_type, ex_val, ex_tb)
                # pylint: disable=broad-exception-caught
                except Exception as ex:
                    _LOG.error("Exception while exiting Service context '%s': %s", svc, ex)
                    ex_throw = ex
        finally:
            # Reset the state even when an exit handler failed; otherwise this
            # Service would stay marked as in-context with no child contexts
            # left and could neither be entered nor exited again.
            self._service_contexts = []
            self._in_context = False
        if ex_throw:
            raise ex_throw
        return False

    def _enter_context(self) -> Service:
        """
        Enters the context for this particular Service instance.

        Called by the base __enter__ method of the Service class so it can be used with
        mix-ins and overridden by subclasses.
        """
        assert not self._in_context
        self._in_context = True
        return self

    def _exit_context(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """
        Exits the context for this particular Service instance.

        Called by the base __exit__ method of the Service class so it can be used with
        mix-ins and overridden by subclasses.
        """
        # pylint: disable=unused-argument
        assert self._in_context
        self._in_context = False
        return False

    def _validate_json_config(self, config: dict) -> None:
        """Reconstructs a basic json config that this class might have been instantiated
        from in order to validate configs provided outside the file loading
        mechanism.
        """
        if self.__class__ is Service:
            # Skip over the case where we instantiate a bare base Service class
            # in order to build up a mix-in.
            assert config == {}
            return
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        # The "config" key is only included when there is something to validate.
        if config:
            json_config["config"] = config
        ConfigSchema.SERVICE.validate(json_config)

    def __repr__(self) -> str:
        """Return the class name followed by the hex id of this instance."""
        return f"{self.__class__.__name__}@{hex(id(self))}"

    def pprint(self) -> str:
        """Produce a human-readable string listing all public methods of the service."""
        # Bound methods are shown with the object they are bound to; plain
        # functions have no __self__ and are listed as "stand-alone".
        return f"{self} ::\n" + "\n".join(
            f' "{key}": {getattr(val, "__self__", "stand-alone")}'
            for (key, val) in self._service_methods.items()
        )

    @property
    def config_loader_service(self) -> SupportsConfigLoading:
        """
        Return a config loader service.

        The underlying attribute is only assigned by the constructor when the parent
        service is a `SupportsConfigLoading` instance; otherwise accessing this
        property raises an `AttributeError`.

        Returns
        -------
        config_loader_service : SupportsConfigLoading
            A config loader service.
        """
        return self._config_loader_service

    def register(self, services: dict[str, Callable] | list[Callable]) -> None:
        """
        Register new mix-in services.

        The callables are stored under their names and are also exposed as
        attributes of this instance. A callable registered under an existing
        name replaces the previous one.

        Parameters
        ----------
        services : dict or list
            A dictionary of string -> function pairs, or a list of callables
            that are keyed by their `__name__`.
        """
        if not isinstance(services, dict):
            services = {svc.__name__: svc for svc in services}

        self._service_methods.update(services)
        self.__dict__.update(self._service_methods)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Added methods to: %s", self.pprint())

        # In order to get a list of all child contexts, we need to look at only
        # the bound methods that were not overridden by another mixin.
        # Then we inspect the internally bound __self__ variable to discover
        # which Service instance that method belongs to.

        # All service loading must happen prior to entering a context.
        assert not self._in_context
        assert not self._service_contexts
        self._services = {
            # Enumerate the Services that are bound to this instance in the
            # order they were added.
            # Unfortunately, by creating a set, we may destroy the ability to
            # preserve the context enter/exit order, but hopefully it doesn't
            # matter.
            svc_method.__self__
            for _, svc_method in self._service_methods.items()
            # Note: some methods are actually stand alone functions, so we need
            # to filter them out.
            if isinstance(svc_method, BoundMethod) and isinstance(svc_method.__self__, Service)
        }

    def export(self) -> dict[str, Callable]:
        """
        Return a dictionary of functions available in this service.

        Returns
        -------
        services : dict
            A dictionary of string -> function pairs. This is the internal
            registry itself, not a copy.
        """
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Export methods from: %s", self.pprint())

        return self._service_methods