Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py: 89%

157 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

"""A collection of classes for interacting with SSH servers: a reconnect-aware client wrapper, a connection cache, and the SshService base class."""

6 

7import logging 

8import os 

9from abc import ABCMeta 

10from asyncio import Event as CoroEvent 

11from asyncio import Lock as CoroLock 

12from threading import current_thread 

13from types import TracebackType 

14from typing import ( 

15 Any, 

16 Callable, 

17 Coroutine, 

18 Dict, 

19 List, 

20 Literal, 

21 Optional, 

22 Tuple, 

23 Type, 

24 Union, 

25) 

26from warnings import warn 

27 

28import asyncssh 

29from asyncssh.connection import SSHClientConnection 

30 

31from mlos_bench.event_loop_context import ( 

32 CoroReturnType, 

33 EventLoopContext, 

34 FutureReturnType, 

35) 

36from mlos_bench.services.base_service import Service 

37from mlos_bench.util import nullable 

38 

# Module-level logger shared by SshClient, SshClientCache, and SshService below.
_LOG = logging.getLogger(__name__)

40 

41 

class SshClient(asyncssh.SSHClient):
    """
    Connection-tracking wrapper around asyncssh.SSHClient.

    SshService (through SshClientCache) keeps one of these per remote host so that
    commands reuse a single long-lived connection instead of reconnecting every time.
    Connection state is followed through the asyncssh ``connection_made`` and
    ``connection_lost`` callbacks, and callers can await :meth:`connection` to get
    the live connection (or None once it has been lost).
    """

    # Placeholder ids reported by __repr__ before a connection exists and after it is gone.
    _CONNECTION_PENDING = "INIT"
    _CONNECTION_LOST = "LOST"

    def __init__(self, *args: tuple, **kwargs: dict):
        self._connection_id: str = SshClient._CONNECTION_PENDING
        self._connection: Optional[SSHClientConnection] = None
        # Set once the connection state has settled (made or lost); awaited by connection().
        self._conn_event: CoroEvent = CoroEvent()
        super().__init__(*args, **kwargs)

    def __repr__(self) -> str:
        return self._connection_id

    @staticmethod
    def id_from_connection(connection: SSHClientConnection) -> str:
        """Builds the ``user@host:port`` id string for an established connection."""
        # pylint: disable=protected-access
        user = connection._username
        host = connection._host
        port = connection._port
        return f"{user}@{host}:{port}"

    @staticmethod
    def id_from_params(connect_params: dict) -> str:
        """
        Builds the ``user@host:port`` id string from asyncssh connect parameters.

        "host" is required (KeyError otherwise); a missing "username" or "port"
        shows up as ``None`` in the id.
        """
        user = connect_params.get("username")
        host = connect_params["host"]
        port = connect_params.get("port")
        return f"{user}@{host}:{port}"

    def connection_made(self, conn: SSHClientConnection) -> None:
        """
        asyncssh.SSHClient hook called when the connection is established.

        Swaps the ``_CONNECTION_PENDING`` placeholder id for the connection's
        ``user@host:port`` id, stores the connection, and wakes up waiters.
        """
        self._conn_event.clear()
        thread_name = current_thread().name
        _LOG.debug(
            "%s: Connection made by %s: %s",
            thread_name,
            conn._options.env,  # pylint: disable=protected-access
            conn,
        )
        self._connection_id = SshClient.id_from_connection(conn)
        self._connection = conn
        self._conn_event.set()
        return super().connection_made(conn)

    def connection_lost(self, exc: Optional[Exception]) -> None:
        """
        asyncssh.SSHClient hook called when the connection closes.

        ``exc`` is None for a graceful disconnect. Either way the id becomes
        ``_CONNECTION_LOST``, the stored connection is dropped, and waiters are
        woken so they observe ``None``.
        """
        self._conn_event.clear()
        thread_name = current_thread().name
        _LOG.debug("%s: %s", thread_name, "connection_lost")
        message = (
            "%s: gracefully disconnected ssh from %s: %s"
            if exc is None
            else "%s: ssh connection lost on %s: %s"
        )
        _LOG.debug(message, thread_name, self._connection_id, exc)
        self._connection_id = SshClient._CONNECTION_LOST
        self._connection = None
        self._conn_event.set()
        return super().connection_lost(exc)

    async def connection(self) -> Optional[SSHClientConnection]:
        """
        Waits until the connection state has settled, then returns it.

        Returns the live SSHClientConnection, or None if it has been lost.
        """
        _LOG.debug("%s: Waiting for connection to be available.", current_thread().name)
        await self._conn_event.wait()
        _LOG.debug("%s: Connection available for %s", current_thread().name, self._connection_id)
        return self._connection

123 

124 

class SshClientCache:
    """
    Reference-counted cache of SshClient connections keyed by ``user@host:port``.

    Note: only one cache per event loop thread is supported, because asyncssh
    requests must be submitted on the event loop the connection was made on.
    See additional details in SshService comments.
    """

    def __init__(self) -> None:
        # connection id -> (connection, client)
        self._cache: Dict[str, Tuple[SSHClientConnection, SshClient]] = {}
        # Serializes lookups/creations in get_client_connection.
        self._cache_lock = CoroLock()
        # Number of enter() calls not yet balanced by exit().
        self._refcnt: int = 0

    def __str__(self) -> str:
        return str(self._cache)

    def __len__(self) -> int:
        return len(self._cache)

    def enter(self) -> None:
        """
        Registers one more user of the cache (reference counting).

        Meant to be called from the __enter__ of a caller's context manager.
        """
        self._refcnt += 1

    def exit(self) -> None:
        """
        Unregisters one user; the last one out closes all cached connections.

        Meant to be called from the __exit__ of a caller's context manager.
        If the cache lock is still held at that point, a RuntimeWarning is
        issued and the lock is force-released.
        """
        self._refcnt -= 1
        if self._refcnt > 0:
            # Other users remain; keep the cached connections alive.
            return
        self.cleanup()
        if self._cache_lock.locked():
            warn(RuntimeWarning("SshClientCache lock was still held on exit."))
            self._cache_lock.release()

    async def get_client_connection(
        self,
        connect_params: dict,
    ) -> Tuple[SSHClientConnection, SshClient]:
        """
        Returns a cached client connection, (re)connecting when necessary.

        A cached entry whose connection has been lost is evicted and replaced
        by a fresh one created via asyncssh.create_connection.

        Parameters
        ----------
        connect_params: dict
            Parameters to pass to asyncssh.create_connection.

        Returns
        -------
        Tuple[SSHClientConnection, SshClient]
            The (SSHClientConnection, SshClient) pair for these parameters.
        """
        _LOG.debug("%s: get_client_connection: %s", current_thread().name, connect_params)
        async with self._cache_lock:
            connection_id = SshClient.id_from_params(connect_params)
            cached_client: Union[None, SshClient, asyncssh.SSHClient]
            _, cached_client = self._cache.get(connection_id, (None, None))

            if cached_client:
                _LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id)
                if await cached_client.connection():
                    _LOG.debug("%s: Using cached client %s", current_thread().name, connection_id)
                else:
                    _LOG.debug(
                        "%s: Removing stale client connection %s from cache.",
                        current_thread().name,
                        connection_id,
                    )
                    # Evict it so a new connection is made just below.
                    self._cache.pop(connection_id)

            if connection_id not in self._cache:
                _LOG.debug(
                    "%s: Establishing client connection to %s",
                    current_thread().name,
                    connection_id,
                )
                new_connection, new_client = await asyncssh.create_connection(
                    SshClient, **connect_params
                )
                assert isinstance(new_client, SshClient)
                self._cache[connection_id] = (new_connection, new_client)
                _LOG.debug("%s: Created connection to %s.", current_thread().name, connection_id)

            return self._cache[connection_id]

    def cleanup(self) -> None:
        """Closes every cached connection and empties the cache."""
        for cached_connection, _unused_client in self._cache.values():
            cached_connection.close()
        self._cache = {}

217 

218 

class SshService(Service, metaclass=ABCMeta):
    """Base class for SSH services."""

    # AsyncSSH requires an asyncio event loop to be running to work.
    # However, running that event loop blocks the main thread.
    # To avoid having to change our entire API to use async/await, all the way
    # up the stack, we run the event loop that runs any async code in a
    # background thread and submit async code to it using
    # asyncio.run_coroutine_threadsafe, interacting with Futures after that.
    # This is a bit of a hack, but it works for now.
    #
    # The event loop is created on demand and shared across all SshService
    # instances, hence we need to lock it when doing the creation/cleanup,
    # or later, during context enter and exit.
    #
    # We ran tests to ensure that multiple requests can still be executing
    # concurrently inside that event loop so there should be no practical
    # performance loss for our initial cases even with just a single background
    # thread running the event loop.
    #
    # Note: the tests were run to confirm that this works with two threads.
    # Using a larger thread pool requires a bit more work since asyncssh
    # requires that run() requests are submitted to the same event loop handler
    # that the connection was made on.
    # In that case, each background thread should get its own SshClientCache.

    # Maintain just one event loop thread for all SshService instances,
    # but only keep it running while they are within a context.
    _EVENT_LOOP_CONTEXT = EventLoopContext()
    _EVENT_LOOP_THREAD_SSH_CLIENT_CACHE = SshClientCache()

    # Default request timeout in seconds; None disables the timeout.
    # Overridable per service via the "ssh_request_timeout" config key.
    _REQUEST_TIMEOUT: Optional[float] = None  # seconds

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional[Service] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new SSH service.

        Parameters
        ----------
        config : dict
            Service configuration. SSH-specific keys read here:
            "ssh_port", "ssh_username", "ssh_priv_key_path" (defaults applied),
            and optionally "ssh_request_timeout", "ssh_known_hosts_file",
            "ssh_keepalive_interval".
        global_config : dict
            Passed through to the base Service.
        parent : Service
            Passed through to the base Service.
        methods : Union[Dict[str, Callable], List[Callable], None]
            Passed through to the base Service.

        Raises
        ------
        ValueError
            If "ssh_known_hosts_file" names a file that does not exist.
        """
        super().__init__(config, global_config, parent, methods)

        # Make sure that the value we allow overriding on a per-connection
        # basis are present in the config so merge_parameters can do its thing.
        self.config.setdefault("ssh_port", None)
        assert isinstance(self.config["ssh_port"], (int, type(None)))
        self.config.setdefault("ssh_username", None)
        assert isinstance(self.config["ssh_username"], (str, type(None)))
        self.config.setdefault("ssh_priv_key_path", None)
        assert isinstance(self.config["ssh_priv_key_path"], (str, type(None)))

        # None can be used to disable the request timeout.
        self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT)
        self._request_timeout = nullable(float, self._request_timeout)

        # Prep an initial connect_params.
        self._connect_params: dict = {
            # In general scripted commands shouldn't need a pty and having one
            # available can confuse some commands, though we may need to make
            # this configurable in the future.
            "request_pty": False,
            # By default disable known_hosts checking (since most VMs expected to be
            # dynamically created).
            "known_hosts": None,
        }

        if "ssh_known_hosts_file" in self.config:
            self._connect_params["known_hosts"] = self.config.get("ssh_known_hosts_file", None)
            if isinstance(self._connect_params["known_hosts"], str):
                known_hosts_file = os.path.expanduser(self._connect_params["known_hosts"])
                if not os.path.exists(known_hosts_file):
                    raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist")
                self._connect_params["known_hosts"] = known_hosts_file
        if self._connect_params["known_hosts"] is None:
            _LOG.info("%s known_hosts checking is disabled per config.", self)

        if "ssh_keepalive_interval" in self.config:
            keepalive_interval = self.config.get("ssh_keepalive_interval")
            self._connect_params["keepalive_interval"] = nullable(int, keepalive_interval)

    def _enter_context(self) -> "SshService":
        """
        Enters the shared event loop context and the shared client cache, then
        the base Service context.

        Both shared objects are reference counted across SshService instances.
        """
        # Start the background thread if it's not already running.
        assert not self._in_context
        SshService._EVENT_LOOP_CONTEXT.enter()
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.enter()
        super()._enter_context()
        return self

    def _exit_context(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Leaves the shared client cache and event loop context, then the base
        Service context.

        The cache is left first so that, when this is the last user, its cached
        connections are closed before the event loop context is exited.
        """
        # Stop the background thread if it's not needed anymore and potentially
        # cleanup the cache as well.
        assert self._in_context
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.exit()
        SshService._EVENT_LOOP_CONTEXT.exit()
        return super()._exit_context(ex_type, ex_val, ex_tb)

    @classmethod
    def clear_client_cache(cls) -> None:
        """
        Clears the cache of client connections.

        Note: This may cause in flight operations to fail.
        """
        cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup()

    def _run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType:
        """
        Runs the given coroutine in the background event loop thread.

        Must be called while the service is within its context.

        Parameters
        ----------
        coro : Coroutine[Any, Any, CoroReturnType]
            The coroutine to run.

        Returns
        -------
        Future[CoroReturnType]
            A future that will be completed when the coroutine completes.
        """
        assert self._in_context
        return self._EVENT_LOOP_CONTEXT.run_coroutine(coro)

    def _get_connect_params(self, params: dict) -> dict:
        """
        Produces a dict of connection parameters for asyncssh.create_connection.

        Per-host values in ``params`` take precedence over the service config;
        a missing or empty (e.g. None) per-host value falls back to the config.
        ``params`` itself is not modified.

        Parameters
        ----------
        params : dict
            Additional connection parameters specific to this host.
            Must contain "ssh_hostname".

        Returns
        -------
        dict
            A dict of connection parameters for asyncssh.create_connection.

        Raises
        ------
        KeyError
            If "ssh_hostname" is missing from ``params``.
        ValueError
            If the resolved private key file does not exist.
        """
        # Note: None is an acceptable value for several of these, in which case
        # reasonable defaults or values from ~/.ssh/config will take effect.

        # Start with the base config params.
        connect_params = self._connect_params.copy()

        connect_params["host"] = params["ssh_hostname"]  # required

        # Read overrides with .get() (not .pop()) so the caller's dict is left intact.
        port = params.get("ssh_port") or self.config["ssh_port"]
        if port:
            connect_params["port"] = int(port)

        # Check the value, not key presence: ssh_username defaults to None
        # (see __init__), and str(None) would yield the bogus username "None".
        username = params.get("ssh_username") or self.config["ssh_username"]
        if username:
            connect_params["username"] = str(username)

        priv_key_file: Optional[str] = params.get(
            "ssh_priv_key_path",
            self.config["ssh_priv_key_path"],
        )
        if priv_key_file:
            priv_key_file = os.path.expanduser(priv_key_file)
            if not os.path.exists(priv_key_file):
                raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist")
            connect_params["client_keys"] = [priv_key_file]

        return connect_params

    async def _get_client_connection(self, params: dict) -> Tuple[SSHClientConnection, SshClient]:
        """
        Gets a (possibly cached) SshClient (connection) for the given connection params.

        Must be awaited on the shared background event loop while the service is
        within its context.

        Parameters
        ----------
        params : dict
            Optional override connection parameters.

        Returns
        -------
        Tuple[SSHClientConnection, SshClient]
            The connection and client objects.
        """
        assert self._in_context
        return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(
            self._get_connect_params(params)
        )