Coverage for mlos_bench/mlos_bench/services/base_service.py: 94%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Base class for the service mix-ins.""" 

6 

7import json 

8import logging 

9from types import TracebackType 

10from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type, Union 

11 

12from mlos_bench.config.schemas import ConfigSchema 

13from mlos_bench.services.types.config_loader_type import SupportsConfigLoading 

14from mlos_bench.util import instantiate_from_config 

15 

16_LOG = logging.getLogger(__name__) 

17 

18 

class Service:
    """
    An abstract base of all Environment Services and used to build up mix-ins.

    Callables registered via :meth:`register` (or inherited from a parent
    Service via :meth:`export`) are stored in ``_service_methods`` and also
    copied into the instance ``__dict__``, so they can be called as if they
    were attributes of this Service. Entering this Service's context enters
    every distinct Service instance that owns one of those bound methods.
    """

    @classmethod
    def new(
        cls,
        class_name: str,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional["Service"] = None,
    ) -> "Service":
        """
        Factory method for a new service with a given config.

        Parameters
        ----------
        class_name: str
            FQN of a Python class to instantiate, e.g.,
            "mlos_bench.services.remote.azure.AzureVMService".
            Must be derived from the `Service` class.
        config : dict
            Free-format dictionary that contains the service configuration.
            It will be passed as a constructor parameter of the class
            specified by `class_name`.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            A parent service that can provide mixin functions.

        Returns
        -------
        svc : Service
            An instance of the `Service` class initialized with `config`.
        """
        assert issubclass(cls, Service)
        return instantiate_from_config(cls, class_name, config, global_config, parent)

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional["Service"] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new service with a given config.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the service configuration.
            It is validated against the Service JSON schema (a bare `Service`
            instance, used only to build up mix-ins, must get an empty config).
        global_config : dict
            Free-format dictionary of global parameters.
            The base class only uses it for debug logging.
        parent : Service
            An optional parent service that can provide mixin functions.
            All of the parent's exported methods are registered on this instance.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service. They are registered
            after the parent's methods, so they override any with the same name.
        """
        self.config = config or {}
        self._validate_json_config(self.config)
        self._parent = parent
        self._service_methods: Dict[str, Callable] = {}
        self._services: Set[Service] = set()
        self._service_contexts: List[Service] = []
        self._in_context = False

        if parent:
            self.register(parent.export())
        if methods:
            self.register(methods)

        # Declared (not assigned) here: the attribute only exists when the
        # parent implements SupportsConfigLoading (or a subclass assigns it).
        self._config_loader_service: SupportsConfigLoading
        if parent and isinstance(parent, SupportsConfigLoading):
            self._config_loader_service = parent

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Service: %s Config:\n%s", self, json.dumps(self.config, indent=2))
            _LOG.debug("Service: %s Globals:\n%s", self, json.dumps(global_config or {}, indent=2))
            _LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None)

    @staticmethod
    def merge_methods(
        ext_methods: Union[Dict[str, Callable], List[Callable], None],
        local_methods: Union[Dict[str, Callable], List[Callable]],
    ) -> Dict[str, Callable]:
        """
        Merge methods from the external caller with the local ones.

        This function is usually called by the derived class constructor just before
        invoking the constructor of the base class.

        Lists of callables are keyed by each callable's ``__name__``.
        On a name conflict, the entry from `ext_methods` wins.
        Neither input is mutated; a new dictionary is returned.
        """
        if isinstance(local_methods, dict):
            local_methods = local_methods.copy()
        else:
            local_methods = {svc.__name__: svc for svc in local_methods}

        if not ext_methods:
            return local_methods

        if not isinstance(ext_methods, dict):
            ext_methods = {svc.__name__: svc for svc in ext_methods}

        local_methods.update(ext_methods)
        return local_methods

    def __enter__(self) -> "Service":
        """
        Enter the Service mix-in context.

        Calls the _enter_context() method of all the Services registered under this one.
        Re-entrant: entering an already-entered Service is a no-op.
        If any child fails to enter, the children that were already entered
        are exited again before the error propagates, so that a later retry
        starts from a clean state.
        """
        if self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            assert self._service_contexts
            assert all(svc._in_context for svc in self._services)
            return self

        entered: List[Service] = []
        try:
            for svc in self._services:
                entered.append(svc._enter_context())
        except BaseException as ex:
            # Roll back the partially entered children; otherwise they would
            # remain "in context" and trip the asserts in _enter_context()
            # on the next attempt to enter this Service.
            for svc in reversed(entered):
                try:
                    svc._exit_context(type(ex), ex, ex.__traceback__)
                except Exception as rollback_ex:  # pylint: disable=broad-exception-caught
                    _LOG.error(
                        "Exception while rolling back Service context '%s': %s",
                        svc,
                        rollback_ex,
                    )
            raise

        self._service_contexts = entered
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exit the Service mix-in context.

        Calls the _exit_context() method of all the Services registered under this one,
        in the reverse order they were entered. Every child is exited even if
        some of them raise; such exceptions are logged and the last one caught
        is re-raised after this Service has been marked as out of context.
        Always returns False, so exceptions from the `with` body are not suppressed.
        """
        if not self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            assert not self._service_contexts
            assert all(not svc._in_context for svc in self._services)
            return False
        ex_throw: Optional[Exception] = None
        for svc in reversed(self._service_contexts):
            try:
                svc._exit_context(ex_type, ex_val, ex_tb)
            except Exception as ex:  # pylint: disable=broad-exception-caught
                _LOG.error("Exception while exiting Service context '%s': %s", svc, ex)
                ex_throw = ex
        # Reset our own state *before* re-raising: previously `_in_context`
        # stayed True with an empty `_service_contexts`, which made every
        # subsequent `__enter__` fail its re-entrancy assertion.
        self._service_contexts = []
        self._in_context = False
        if ex_throw is not None:
            raise ex_throw
        return False

    def _enter_context(self) -> "Service":
        """
        Enters the context for this particular Service instance.

        Called by the base __enter__ method of the Service class so it can be used with
        mix-ins and overridden by subclasses.
        The returned object is what the base __exit__ later calls _exit_context() on.
        """
        assert not self._in_context
        self._in_context = True
        return self

    def _exit_context(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exits the context for this particular Service instance.

        Called by the base __exit__ method of the Service class so it can be used with
        mix-ins and overridden by subclasses.
        The exception arguments are unused by the base implementation.
        """
        # pylint: disable=unused-argument
        assert self._in_context
        self._in_context = False
        return False

    def _validate_json_config(self, config: dict) -> None:
        """
        Reconstructs a basic json config that this class might have been instantiated
        from in order to validate configs provided outside the file loading mechanism.

        The reconstructed config has the form ``{"class": "<module>.<ClassName>",
        "config": {...}}`` (the "config" key is omitted when `config` is empty)
        and is checked against ``ConfigSchema.SERVICE``.
        """
        if type(self) is Service:
            # Skip over the case where instantiate a bare base Service class in
            # order to build up a mix-in.
            assert config == {}
            return
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config
        ConfigSchema.SERVICE.validate(json_config)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}@{hex(id(self))}"

    def pprint(self) -> str:
        """
        Produce a human-readable string listing all public methods of the service.

        Each line shows the method name and the object it is bound to, or
        "stand-alone" for plain functions that have no ``__self__``.
        """
        return f"{self} ::\n" + "\n".join(
            f'    "{key}": {getattr(val, "__self__", "stand-alone")}'
            for (key, val) in self._service_methods.items()
        )

    @property
    def config_loader_service(self) -> SupportsConfigLoading:
        """
        Return a config loader service.

        Returns
        -------
        config_loader_service : SupportsConfigLoading
            A config loader service.

        Raises
        ------
        AttributeError
            If no config loader was ever assigned (the constructor only assigns
            one when `parent` implements `SupportsConfigLoading`).
        """
        return self._config_loader_service

    def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None:
        """
        Register new mix-in services.

        Parameters
        ----------
        services : dict or list
            A dictionary of string -> function pairs, or a list of functions
            (keyed by their ``__name__``). Entries override previously
            registered methods with the same name.
        """
        # All service loading must happen prior to entering a context.
        # Check before mutating anything so that misuse leaves this instance unchanged.
        assert not self._in_context
        assert not self._service_contexts

        if not isinstance(services, dict):
            services = {svc.__name__: svc for svc in services}

        self._service_methods.update(services)
        # Expose every registered method as an instance attribute.
        self.__dict__.update(self._service_methods)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Added methods to: %s", self.pprint())

        # In order to get a list of all child contexts, we need to look at only
        # the bound methods that were not overridden by another mixin.
        # Then we inspect the internally bound __self__ variable to discover
        # which Service instance that method belongs to.
        # Stand-alone functions (no __self__, or bound to a non-Service object)
        # are filtered out.
        # Note: collecting into a set may not preserve the order in which the
        # Services were added, so the context enter/exit order is not guaranteed.
        self._services = {
            svc_method.__self__
            for svc_method in self._service_methods.values()
            if isinstance(getattr(svc_method, "__self__", None), Service)
        }

    def export(self) -> Dict[str, Callable]:
        """
        Return a dictionary of functions available in this service.

        Returns
        -------
        services : dict
            A dictionary of string -> function pairs.
            Note: this is the internal registry itself, not a copy.
        """
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Export methods from: %s", self.pprint())

        return self._service_methods