Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py: 89%

158 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

"""A collection of classes and functions for interacting with SSH servers (connection caching and a base class for SSH services)."""

6 

7import logging 

8import os 

9from abc import ABCMeta 

10from asyncio import Event as CoroEvent 

11from asyncio import Lock as CoroLock 

12from collections.abc import Callable, Coroutine 

13from threading import current_thread 

14from types import TracebackType 

15from typing import Any, Literal 

16from warnings import warn 

17 

18import asyncssh 

19from asyncssh.connection import SSHClientConnection 

20 

21from mlos_bench.event_loop_context import ( 

22 CoroReturnType, 

23 EventLoopContext, 

24 FutureReturnType, 

25) 

26from mlos_bench.services.base_service import Service 

27from mlos_bench.util import nullable 

28 

# Module-wide logger, named after this module so its output can be filtered per module.
_LOG: logging.Logger = logging.getLogger(name=__name__)

30 

31 

class SshClient(asyncssh.SSHClient):
    """
    Wrapper around SSHClient to help provide connection caching and reconnect logic.

    Used by the SshService to try and maintain a single connection to hosts, handle
    reconnects if possible, and use that to run commands rather than reconnect for each
    command.

    The asyncssh ``connection_made`` / ``connection_lost`` hooks record the current
    connection (or its absence) and then set an asyncio Event, which
    :py:meth:`connection` waits on.
    """

    # Connection id reported before connection_made() has been called.
    _CONNECTION_PENDING = "INIT"
    # Connection id reported after connection_lost() has been called.
    _CONNECTION_LOST = "LOST"

    def __init__(self, *args: Any, **kwargs: Any):
        """
        Initialize the client in the "pending" state.

        All positional and keyword arguments are forwarded unchanged to
        ``asyncssh.SSHClient.__init__``.
        """
        self._connection_id: str = SshClient._CONNECTION_PENDING
        self._connection: SSHClientConnection | None = None
        # Set once the connection has been either made or lost (see connection()).
        self._conn_event: CoroEvent = CoroEvent()
        super().__init__(*args, **kwargs)

    def __repr__(self) -> str:
        # The connection id doubles as a compact, human-readable repr.
        return self._connection_id

    @staticmethod
    def id_from_connection(connection: SSHClientConnection) -> str:
        """
        Gets a unique id repr for the connection.

        Parameters
        ----------
        connection : asyncssh.connection.SSHClientConnection
            An established connection.

        Returns
        -------
        str
            The id, formatted as ``username@host:port``.
        """
        # pylint: disable=protected-access
        return f"{connection._username}@{connection._host}:{connection._port}"

    @staticmethod
    def id_from_params(connect_params: dict) -> str:
        """
        Gets a unique id repr for the connection.

        Parameters
        ----------
        connect_params : dict
            Parameters for asyncssh.create_connection. ``host`` is required;
            a missing ``username`` or ``port`` is rendered as ``None``.

        Returns
        -------
        str
            The id, formatted as ``username@host:port``.
        """
        return (
            f"""{connect_params.get("username")}@{connect_params["host"]}"""
            f""":{connect_params.get("port")}"""
        )

    def connection_made(self, conn: SSHClientConnection) -> None:
        """
        Override hook provided by asyncssh.SSHClient.

        Changes the connection_id from _CONNECTION_PENDING to a unique id repr,
        records the connection, and wakes any coroutine waiting in
        :py:meth:`connection`.
        """
        self._conn_event.clear()
        _LOG.debug(
            "%s: Connection made by %s: %s",
            current_thread().name,
            conn._options.env,  # pylint: disable=protected-access
            conn,
        )
        self._connection_id = SshClient.id_from_connection(conn)
        self._connection = conn
        self._conn_event.set()
        return super().connection_made(conn)

    def connection_lost(self, exc: Exception | None) -> None:
        """
        Override hook provided by asyncssh.SSHClient.

        Marks the client as _CONNECTION_LOST, drops the stored connection, and
        wakes any coroutine waiting in :py:meth:`connection`, which will then
        receive ``None``.

        Parameters
        ----------
        exc : Exception | None
            ``None`` for a graceful disconnect, otherwise the error that caused it.
        """
        self._conn_event.clear()
        _LOG.debug("%s: %s", current_thread().name, "connection_lost")
        if exc is None:
            _LOG.debug(
                "%s: gracefully disconnected ssh from %s: %s",
                current_thread().name,
                self._connection_id,
                exc,
            )
        else:
            _LOG.debug(
                "%s: ssh connection lost on %s: %s",
                current_thread().name,
                self._connection_id,
                exc,
            )
        self._connection_id = SshClient._CONNECTION_LOST
        self._connection = None
        self._conn_event.set()
        return super().connection_lost(exc)

    async def connection(self) -> SSHClientConnection | None:
        """
        Waits for and returns the asyncssh.connection.SSHClientConnection to be
        established or lost.

        Returns
        -------
        asyncssh.connection.SSHClientConnection | None
            The live connection, or ``None`` if the connection was lost.
        """
        _LOG.debug("%s: Waiting for connection to be available.", current_thread().name)
        await self._conn_event.wait()
        _LOG.debug("%s: Connection available for %s", current_thread().name, self._connection_id)
        return self._connection

115 

116 

class SshClientCache:
    """
    Manages a cache of SshClient connections.

    Note: Only one per event loop thread supported.
    See additional details in SshService comments.

    Entries are keyed by :py:meth:`SshClient.id_from_params` (``username@host:port``).
    The cache's lifetime is managed by reference counting through
    :py:meth:`enter` / :py:meth:`exit`.
    """

    def __init__(self) -> None:
        # connection id -> (connection, client)
        self._cache: dict[str, tuple[SSHClientConnection, SshClient]] = {}
        # Serializes lookup-or-create so concurrent coroutines don't open duplicate
        # connections to the same host.
        self._cache_lock = CoroLock()
        # Number of callers currently inside an enter()/exit() pair.
        self._refcnt: int = 0

    def __str__(self) -> str:
        return str(self._cache)

    def __len__(self) -> int:
        return len(self._cache)

    def enter(self) -> None:
        """
        Manages the cache lifecycle with reference counting.

        To be used in the __enter__ method of a caller's context manager.
        """
        self._refcnt += 1

    def exit(self) -> None:
        """
        Manages the cache lifecycle with reference counting.

        To be used in the __exit__ method of a caller's context manager.
        When the last reference exits, all cached connections are closed and a
        still-held cache lock is released (with a RuntimeWarning).
        """
        self._refcnt -= 1
        if self._refcnt <= 0:
            # Clamp the count: an unbalanced extra exit() would otherwise leave it
            # negative, so a later enter() would not keep the cache alive and the
            # cache could be torn down while a caller is still inside its context.
            self._refcnt = 0
            self.cleanup()
            if self._cache_lock.locked():
                warn(RuntimeWarning("SshClientCache lock was still held on exit."))
                self._cache_lock.release()

    async def get_client_connection(
        self,
        connect_params: dict,
    ) -> tuple[SSHClientConnection, SshClient]:
        """
        Gets a (possibly cached) client connection.

        A cached entry whose connection has been lost is evicted and replaced
        by a freshly created connection.

        Parameters
        ----------
        connect_params: dict
            Parameters to pass to asyncssh.create_connection.

        Returns
        -------
        tuple[asyncssh.connection.SSHClientConnection, SshClient]
            A tuple of (SSHClientConnection, SshClient).
        """
        _LOG.debug("%s: get_client_connection: %s", current_thread().name, connect_params)
        async with self._cache_lock:
            connection_id = SshClient.id_from_params(connect_params)
            client: None | SshClient | asyncssh.SSHClient
            _, client = self._cache.get(connection_id, (None, None))
            if client:
                _LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id)
                # Returns None if the cached client's connection was lost.
                connection = await client.connection()
                if not connection:
                    _LOG.debug(
                        "%s: Removing stale client connection %s from cache.",
                        current_thread().name,
                        connection_id,
                    )
                    self._cache.pop(connection_id)
                    # Try to reconnect next.
                else:
                    _LOG.debug("%s: Using cached client %s", current_thread().name, connection_id)
            if connection_id not in self._cache:
                _LOG.debug(
                    "%s: Establishing client connection to %s",
                    current_thread().name,
                    connection_id,
                )
                connection, client = await asyncssh.create_connection(SshClient, **connect_params)
                assert isinstance(client, SshClient)
                self._cache[connection_id] = (connection, client)
                _LOG.debug("%s: Created connection to %s.", current_thread().name, connection_id)
            return self._cache[connection_id]

    def cleanup(self) -> None:
        """Closes all cached connections and empties the cache."""
        for connection, _ in self._cache.values():
            connection.close()
        self._cache = {}

209 

210 

class SshService(Service, metaclass=ABCMeta):
    """Base class for SSH services."""

    # AsyncSSH requires an asyncio event loop to be running to work.
    # However, running that event loop blocks the main thread.
    # To avoid having to change our entire API to use async/await, all the way
    # up the stack, we run the event loop that runs any async code in a
    # background thread and submit async code to it using
    # asyncio.run_coroutine_threadsafe, interacting with Futures after that.
    # This is a bit of a hack, but it works for now.
    #
    # The event loop is created on demand and shared across all SshService
    # instances, hence we need to lock it when doing the creation/cleanup,
    # or later, during context enter and exit.
    #
    # We ran tests to ensure that multiple requests can still be executing
    # concurrently inside that event loop so there should be no practical
    # performance loss for our initial cases even with just single background
    # thread running the event loop.
    #
    # Note: the tests were run to confirm that this works with two threads.
    # Using a larger thread pool requires a bit more work since asyncssh
    # requires that run() requests are submitted to the same event loop handler
    # that the connection was made on.
    # In that case, each background thread should get its own SshClientCache.

    # Maintain just one event loop thread for all SshService instances.
    # But only keep it running while they are within a context.
    _EVENT_LOOP_CONTEXT = EventLoopContext()
    _EVENT_LOOP_THREAD_SSH_CLIENT_CACHE = SshClientCache()

    _REQUEST_TIMEOUT: float | None = None  # seconds

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
        methods: dict[str, Callable] | list[Callable] | None = None,
    ):
        """
        Create a new SSH service.

        Parameters
        ----------
        config : dict[str, Any] | None
            Service configuration. SSH-related keys read here are ``ssh_port``,
            ``ssh_username``, ``ssh_priv_key_path``, ``ssh_request_timeout``,
            ``ssh_known_hosts_file`` and ``ssh_keepalive_interval``.
        global_config : dict[str, Any] | None
            Passed through to Service.
        parent : Service | None
            Passed through to Service.
        methods : dict[str, Callable] | list[Callable] | None
            Passed through to Service.

        Raises
        ------
        ValueError
            If ``ssh_known_hosts_file`` names a file that does not exist.
        """
        super().__init__(config, global_config, parent, methods)

        # Make sure that the value we allow overriding on a per-connection
        # basis are present in the config so merge_parameters can do its thing.
        self.config.setdefault("ssh_port", None)
        assert isinstance(self.config["ssh_port"], (int, type(None)))
        self.config.setdefault("ssh_username", None)
        assert isinstance(self.config["ssh_username"], (str, type(None)))
        self.config.setdefault("ssh_priv_key_path", None)
        assert isinstance(self.config["ssh_priv_key_path"], (str, type(None)))

        # None can be used to disable the request timeout.
        self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT)
        self._request_timeout = nullable(float, self._request_timeout)

        # Prep an initial connect_params.
        self._connect_params: dict = {
            # In general scripted commands shouldn't need a pty and having one
            # available can confuse some commands, though we may need to make
            # this configurable in the future.
            "request_pty": False,
            # By default disable known_hosts checking (since most VMs expected to be
            # dynamically created).
            "known_hosts": None,
        }

        if "ssh_known_hosts_file" in self.config:
            known_hosts = self.config["ssh_known_hosts_file"]
            if isinstance(known_hosts, str):
                known_hosts = os.path.expanduser(known_hosts)
                if not os.path.exists(known_hosts):
                    raise ValueError(f"ssh_known_hosts_file {known_hosts} does not exist")
            self._connect_params["known_hosts"] = known_hosts
        if self._connect_params["known_hosts"] is None:
            _LOG.info("%s known_hosts checking is disabled per config.", self)

        if "ssh_keepalive_interval" in self.config:
            keepalive_interval = self.config.get("ssh_keepalive_interval")
            self._connect_params["keepalive_interval"] = nullable(int, keepalive_interval)

    def _enter_context(self) -> "SshService":
        # Start the background thread if it's not already running.
        assert not self._in_context
        SshService._EVENT_LOOP_CONTEXT.enter()
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.enter()
        super()._enter_context()
        return self

    def _exit_context(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        # Stop the background thread if it's not needed anymore and potentially
        # cleanup the cache as well.
        # The cache is released before the event loop context, i.e. in the
        # reverse order of _enter_context.
        assert self._in_context
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.exit()
        SshService._EVENT_LOOP_CONTEXT.exit()
        return super()._exit_context(ex_type, ex_val, ex_tb)

    @classmethod
    def clear_client_cache(cls) -> None:
        """
        Clears the cache of client connections.

        Note: This may cause in flight operations to fail.
        """
        cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup()

    def _run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType:
        """
        Runs the given coroutine in the background event loop thread.

        Parameters
        ----------
        coro : Coroutine[Any, Any, CoroReturnType]
            The coroutine to run.

        Returns
        -------
        Future[CoroReturnType]
            A future that will be completed when the coroutine completes.
        """
        assert self._in_context
        # Use the same shared event loop context that _enter_context entered.
        return SshService._EVENT_LOOP_CONTEXT.run_coroutine(coro)

    def _get_connect_params(self, params: dict) -> dict:
        """
        Produces a dict of connection parameters for asyncssh.create_connection.

        Per-host values in ``params`` take precedence over the service config.
        ``params`` is only read, never modified.

        Parameters
        ----------
        params : dict
            Additional connection parameters specific to this host.
            ``ssh_hostname`` is required.

        Returns
        -------
        dict
            A dict of connection parameters for asyncssh.create_connection.

        Raises
        ------
        KeyError
            If ``ssh_hostname`` is missing from ``params``.
        ValueError
            If the resolved ``ssh_priv_key_path`` does not exist.
        """
        # Setup default connect_params dict for all SshClients we might need to create.

        # Note: None is an acceptable value for several of these, in which case
        # reasonable defaults or values from ~/.ssh/config will take effect.

        # Start with the base config params.
        connect_params = self._connect_params.copy()

        connect_params["host"] = params["ssh_hostname"]  # required

        # Read (rather than pop) the per-host overrides so the caller's dict is
        # left intact and repeated calls with the same params give the same result.
        # Falsy (e.g. None) per-host values fall back to the service config.
        port = params.get("ssh_port") or self.config["ssh_port"]
        if port:
            connect_params["port"] = int(port)

        username = params.get("ssh_username") or self.config["ssh_username"]
        if username:
            connect_params["username"] = str(username)

        priv_key_file: str | None = params.get(
            "ssh_priv_key_path",
            self.config["ssh_priv_key_path"],
        )
        if priv_key_file:
            priv_key_file = os.path.expanduser(priv_key_file)
            if not os.path.exists(priv_key_file):
                raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist")
            connect_params["client_keys"] = [priv_key_file]

        return connect_params

    async def _get_client_connection(self, params: dict) -> tuple[SSHClientConnection, SshClient]:
        """
        Gets a (possibly cached) SshClient (connection) for the given connection params.

        Parameters
        ----------
        params : dict
            Optional override connection parameters.

        Returns
        -------
        tuple[SSHClientConnection, SshClient]
            The connection and client objects.
        """
        assert self._in_context
        return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(
            self._get_connect_params(params)
        )