Coverage for mlos_bench/mlos_bench/services/base_service.py: 94%

104 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Base class for the service mix-ins.""" 

6 

7import json 

8import logging 

9from types import TracebackType 

10from typing import Any, Callable, Dict, List, Optional, Set, Type, Union 

11 

12from typing_extensions import Literal 

13 

14from mlos_bench.config.schemas import ConfigSchema 

15from mlos_bench.services.types.config_loader_type import SupportsConfigLoading 

16from mlos_bench.util import instantiate_from_config 

17 

# Module-scoped logger, named after this module so it can be tuned via logging config.
_LOG: logging.Logger = logging.getLogger(name=__name__)

19 

20 

class Service:
    """An abstract base of all Environment Services and used to build up mix-ins."""

    @classmethod
    def new(
        cls,
        class_name: str,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional["Service"] = None,
    ) -> "Service":
        """
        Factory method for a new service with a given config.

        Parameters
        ----------
        class_name: str
            FQN of a Python class to instantiate, e.g.,
            "mlos_bench.services.remote.azure.AzureVMService".
            Must be derived from the `Service` class.
        config : dict
            Free-format dictionary that contains the service configuration.
            It will be passed as a constructor parameter of the class
            specified by `class_name`.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            A parent service that can provide mixin functions.

        Returns
        -------
        svc : Service
            An instance of the `Service` class initialized with `config`.
        """
        assert issubclass(cls, Service)
        return instantiate_from_config(cls, class_name, config, global_config, parent)

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional["Service"] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new service with a given config.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the service configuration.
            It is validated against the service JSON schema and stored as
            `self.config` (an empty dict when not provided).
        global_config : dict
            Free-format dictionary of global parameters.
            The base class only uses it for debug logging.
        parent : Service
            An optional parent service that can provide mixin functions.
            Every method it exports is registered with this service first.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service. They are registered
            after the parent's methods, so they win on name clashes.
        """
        self.config = config or {}
        self._validate_json_config(self.config)
        self._parent = parent
        self._service_methods: Dict[str, Callable] = {}
        self._services: Set[Service] = set()
        self._service_contexts: List[Service] = []
        self._in_context = False

        if parent:
            self.register(parent.export())
        if methods:
            self.register(methods)

        # Declared but only assigned when the parent can load configs; otherwise
        # reading `config_loader_service` raises AttributeError unless a
        # subclass assigns `_config_loader_service` itself.
        self._config_loader_service: SupportsConfigLoading
        if parent and isinstance(parent, SupportsConfigLoading):
            self._config_loader_service = parent

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Service: %s Config:\n%s", self, json.dumps(self.config, indent=2))
            _LOG.debug("Service: %s Globals:\n%s", self, json.dumps(global_config or {}, indent=2))
            _LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None)

    @staticmethod
    def _to_method_dict(
        methods: Union[Dict[str, Callable], List[Callable]],
    ) -> Dict[str, Callable]:
        """Normalize a dict or list of callables into a fresh name -> callable dict."""
        if isinstance(methods, dict):
            # Always copy so callers can never alias the dict we mutate later.
            return dict(methods)
        return {method.__name__: method for method in methods}

    @staticmethod
    def merge_methods(
        ext_methods: Union[Dict[str, Callable], List[Callable], None],
        local_methods: Union[Dict[str, Callable], List[Callable]],
    ) -> Dict[str, Callable]:
        """
        Merge methods from the external caller with the local ones.

        This function is usually called by the derived class constructor just before
        invoking the constructor of the base class.

        Parameters
        ----------
        ext_methods : dict, list, or None
            Methods supplied by the caller; these override local ones with the same name.
        local_methods : dict or list
            Methods defined by the derived class. Never mutated.

        Returns
        -------
        methods : dict
            A new name -> callable dictionary (lists are keyed by `__name__`).
        """
        merged = Service._to_method_dict(local_methods)
        if ext_methods:
            merged.update(Service._to_method_dict(ext_methods))
        return merged

    def __enter__(self) -> "Service":
        """
        Enter the Service mix-in context.

        Calls the _enter_context() method of all the Services registered under this one.
        If any of them raises, the ones already entered are exited again (in reverse
        order) before the exception propagates, so no child is left half-entered.
        """
        if self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            # Note: a Service with no bound child services legitimately has no contexts.
            assert len(self._service_contexts) == len(self._services)
            assert all(svc._in_context for svc in self._services)
            return self

        entered: List[Service] = []
        try:
            for svc in self._services:
                entered.append(svc._enter_context())
        except BaseException as ex:
            # Roll back the children that were successfully entered so their
            # `_in_context` flags stay consistent with ours (not in context).
            for svc in reversed(entered):
                try:
                    svc._exit_context(type(ex), ex, ex.__traceback__)
                # pylint: disable=broad-exception-caught
                except Exception as cleanup_ex:
                    _LOG.error("Exception while exiting Service context '%s': %s", svc, cleanup_ex)
            raise

        self._service_contexts = entered
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exit the Service mix-in context.

        Calls the _exit_context() method of all the Services registered under this one,
        in reverse order of entry. Every child is given a chance to exit even if an
        earlier one raises; the last such exception is re-raised afterwards, once this
        service's own context state has been reset.

        Returns
        -------
        False
            Never suppresses the exception raised inside the `with` block.
        """
        if not self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            assert not self._service_contexts
            assert all(not svc._in_context for svc in self._services)
            return False
        ex_throw: Optional[Exception] = None
        for svc in reversed(self._service_contexts):
            try:
                svc._exit_context(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting Service context '%s': %s", svc, ex)
                ex_throw = ex
        # Reset state *before* re-raising so the service is not left marked as
        # "in context" with no child contexts to exit.
        self._service_contexts = []
        self._in_context = False
        if ex_throw is not None:
            raise ex_throw
        return False

    def _enter_context(self) -> "Service":
        """
        Enters the context for this particular Service instance.

        Called by the base __enter__ method of the Service class so it can be used with
        mix-ins and overridden by subclasses.
        """
        assert not self._in_context
        self._in_context = True
        return self

    def _exit_context(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exits the context for this particular Service instance.

        Called by the base __exit__ method of the Service class so it can be used with
        mix-ins and overridden by subclasses.
        """
        # pylint: disable=unused-argument
        assert self._in_context
        self._in_context = False
        return False

    def _validate_json_config(self, config: dict) -> None:
        """
        Validate `config` against the Service JSON schema.

        Reconstructs a basic json config that this class might have been
        instantiated from, in order to validate configs provided outside the
        file loading mechanism.
        """
        cls = type(self)
        if cls is Service:
            # Skip over the case where instantiate a bare base Service class in
            # order to build up a mix-in.
            assert config == {}
            return
        json_config: dict = {
            "class": f"{cls.__module__}.{cls.__name__}",
        }
        if config:
            json_config["config"] = config
        ConfigSchema.SERVICE.validate(json_config)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}@{hex(id(self))}"

    def pprint(self) -> str:
        """Produce a human-readable string listing all public methods of the service."""
        return f"{self} ::\n" + "\n".join(
            f' "{key}": {getattr(val, "__self__", "stand-alone")}'
            for (key, val) in self._service_methods.items()
        )

    @property
    def config_loader_service(self) -> SupportsConfigLoading:
        """
        Return a config loader service.

        Returns
        -------
        config_loader_service : SupportsConfigLoading
            A config loader service.

        Raises
        ------
        AttributeError
            If no config loader was ever assigned (e.g. the parent did not support
            config loading and the subclass did not set one).
        """
        return self._config_loader_service

    def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None:
        """
        Register new mix-in services.

        Parameters
        ----------
        services : dict or list
            A dictionary of string -> function pairs, or a list of functions
            (keyed by their `__name__`). Later registrations override earlier
            ones with the same name.
        """
        # All service loading must happen prior to entering a context.
        # Check before mutating anything so a rejected call leaves no partial state.
        assert not self._in_context
        assert not self._service_contexts

        self._service_methods.update(self._to_method_dict(services))
        # Expose every registered method as an instance attribute.
        self.__dict__.update(self._service_methods)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Added methods to: %s", self.pprint())

        # In order to get a list of all child contexts, we need to look at only
        # the bound methods that were not overridden by another mixin.
        # Then we inspect the internally bound __self__ variable to discover
        # which Service instance that method belongs to.
        # Unfortunately, by creating a set, we may destroy the ability to
        # preserve the context enter/exit order, but hopefully it doesn't matter.
        self._services = {
            svc_method.__self__
            for svc_method in self._service_methods.values()
            # Note: some methods are actually stand alone functions, so we need
            # to filter them out.
            if hasattr(svc_method, "__self__") and isinstance(svc_method.__self__, Service)
        }

    def export(self) -> Dict[str, Callable]:
        """
        Return a dictionary of functions available in this service.

        Returns
        -------
        services : dict
            A dictionary of string -> function pairs. This is the service's
            internal registry (not a copy).
        """
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Export methods from: %s", self.pprint())

        return self._service_methods