Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py: 89% (158 statements)
coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""A collection of functions for interacting with SSH servers as file shares."""

import logging
import os
from abc import ABCMeta
from asyncio import Event as CoroEvent
from asyncio import Lock as CoroLock
from collections.abc import Callable, Coroutine
from threading import current_thread
from types import TracebackType
from typing import Any, Literal
from warnings import warn

import asyncssh
from asyncssh.connection import SSHClientConnection

from mlos_bench.event_loop_context import (
    CoroReturnType,
    EventLoopContext,
    FutureReturnType,
)
from mlos_bench.services.base_service import Service
from mlos_bench.util import nullable

_LOG = logging.getLogger(__name__)


class SshClient(asyncssh.SSHClient):
    """
    Wrapper around SSHClient to help provide connection caching and reconnect logic.

    Used by the SshService to try and maintain a single connection to hosts, handle
    reconnects if possible, and use that to run commands rather than reconnect for each
    command.
    """

    _CONNECTION_PENDING = "INIT"
    _CONNECTION_LOST = "LOST"

    def __init__(self, *args: tuple, **kwargs: dict):
        self._connection_id: str = SshClient._CONNECTION_PENDING
        self._connection: SSHClientConnection | None = None
        self._conn_event: CoroEvent = CoroEvent()
        super().__init__(*args, **kwargs)

    def __repr__(self) -> str:
        return self._connection_id

    @staticmethod
    def id_from_connection(connection: SSHClientConnection) -> str:
        """Gets a unique id repr for the connection."""
        # pylint: disable=protected-access
        return f"{connection._username}@{connection._host}:{connection._port}"

    @staticmethod
    def id_from_params(connect_params: dict) -> str:
        """Gets a unique id repr for the connection."""
        return (
            f"""{connect_params.get("username")}@{connect_params["host"]}"""
            f""":{connect_params.get("port")}"""
        )

    def connection_made(self, conn: SSHClientConnection) -> None:
        """
        Override hook provided by asyncssh.SSHClient.

        Changes the connection_id from _CONNECTION_PENDING to a unique id repr.
        """
        self._conn_event.clear()
        _LOG.debug(
            "%s: Connection made by %s: %s",
            current_thread().name,
            conn._options.env,  # pylint: disable=protected-access
            conn,
        )
        self._connection_id = SshClient.id_from_connection(conn)
        self._connection = conn
        self._conn_event.set()
        return super().connection_made(conn)

    def connection_lost(self, exc: Exception | None) -> None:
        self._conn_event.clear()
        _LOG.debug("%s: %s", current_thread().name, "connection_lost")
        if exc is None:
            _LOG.debug(
                "%s: gracefully disconnected ssh from %s: %s",
                current_thread().name,
                self._connection_id,
                exc,
            )
        else:
            _LOG.debug(
                "%s: ssh connection lost on %s: %s",
                current_thread().name,
                self._connection_id,
                exc,
            )
        self._connection_id = SshClient._CONNECTION_LOST
        self._connection = None
        self._conn_event.set()
        return super().connection_lost(exc)

    async def connection(self) -> SSHClientConnection | None:
        """Waits for the asyncssh.connection.SSHClientConnection to be established or
        lost and returns it (or None if the connection was lost).
        """
        _LOG.debug("%s: Waiting for connection to be available.", current_thread().name)
        await self._conn_event.wait()
        _LOG.debug("%s: Connection available for %s", current_thread().name, self._connection_id)
        return self._connection
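

# Illustrative sketch (not part of the original module): how asyncssh pairs a
# custom SSHClient subclass such as SshClient with a connection.
# asyncssh.create_connection() takes the client *factory* plus the connection
# parameters and returns both the connection and the client instance.
# The host, username, and command below are hypothetical placeholders.
async def _example_direct_connection() -> None:
    """Usage sketch only; assumes a reachable SSH host and valid credentials."""
    connection, client = await asyncssh.create_connection(
        SshClient,
        host="203.0.113.10",  # hypothetical host
        username="mlos",  # hypothetical user
        known_hosts=None,  # mirrors the SshService default of no host key checking
    )
    assert isinstance(client, SshClient)
    # connection_made() has fired by now, so the client reports a stable id.
    _LOG.debug("Connected: %s", SshClient.id_from_connection(connection))
    result = await connection.run("hostname", check=True)
    _LOG.debug("stdout: %s", result.stdout)
    connection.close()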


class SshClientCache:
    """
    Manages a cache of SshClient connections.

    Note: Only one cache per event loop thread is supported.
    See additional details in SshService comments.
    """

    def __init__(self) -> None:
        self._cache: dict[str, tuple[SSHClientConnection, SshClient]] = {}
        self._cache_lock = CoroLock()
        self._refcnt: int = 0

    def __str__(self) -> str:
        return str(self._cache)

    def __len__(self) -> int:
        return len(self._cache)

    def enter(self) -> None:
        """
        Manages the cache lifecycle with reference counting.

        To be used in the __enter__ method of a caller's context manager.
        """
        self._refcnt += 1

    def exit(self) -> None:
        """
        Manages the cache lifecycle with reference counting.

        To be used in the __exit__ method of a caller's context manager.
        """
        self._refcnt -= 1
        if self._refcnt <= 0:
            self.cleanup()
            if self._cache_lock.locked():
                warn(RuntimeWarning("SshClientCache lock was still held on exit."))
                self._cache_lock.release()

    async def get_client_connection(
        self,
        connect_params: dict,
    ) -> tuple[SSHClientConnection, SshClient]:
        """
        Gets a (possibly cached) client connection.

        Parameters
        ----------
        connect_params: dict
            Parameters to pass to asyncssh.create_connection.

        Returns
        -------
        tuple[asyncssh.connection.SSHClientConnection, SshClient]
            A tuple of (SSHClientConnection, SshClient).
        """
        _LOG.debug("%s: get_client_connection: %s", current_thread().name, connect_params)
        async with self._cache_lock:
            connection_id = SshClient.id_from_params(connect_params)
            client: None | SshClient | asyncssh.SSHClient
            _, client = self._cache.get(connection_id, (None, None))
            if client:
                _LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id)
                connection = await client.connection()
                if not connection:
                    _LOG.debug(
                        "%s: Removing stale client connection %s from cache.",
                        current_thread().name,
                        connection_id,
                    )
                    self._cache.pop(connection_id)
                    # Try to reconnect next.
                else:
                    _LOG.debug("%s: Using cached client %s", current_thread().name, connection_id)
            if connection_id not in self._cache:
                _LOG.debug(
                    "%s: Establishing client connection to %s",
                    current_thread().name,
                    connection_id,
                )
                connection, client = await asyncssh.create_connection(SshClient, **connect_params)
                assert isinstance(client, SshClient)
                self._cache[connection_id] = (connection, client)
                _LOG.debug("%s: Created connection to %s.", current_thread().name, connection_id)
            return self._cache[connection_id]

    def cleanup(self) -> None:
        """Closes all cached connections."""
        for connection, _ in self._cache.values():
            connection.close()
        self._cache = {}
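

# Illustrative sketch (not part of the original module): SshClientCache usage.
# enter()/exit() bracket the cache's lifetime via reference counting, and
# get_client_connection() is a coroutine that must run on the same event loop
# that owns the cached connections. The connect_params values are hypothetical.
async def _example_cached_connection(cache: SshClientCache) -> None:
    """Usage sketch only; assumes a reachable SSH host and valid credentials."""
    cache.enter()  # refcount the cache, as a caller's __enter__ would
    try:
        connect_params = {"host": "203.0.113.10", "username": "mlos", "known_hosts": None}
        connection, client = await cache.get_client_connection(connect_params)
        result = await connection.run("uptime", check=True)
        _LOG.debug("%s: %s", client, result.stdout)
        # A second call with the same params reuses the cached connection
        # (assuming it is still alive).
        connection_again, _ = await cache.get_client_connection(connect_params)
        assert connection_again is connection
    finally:
        cache.exit()  # drops the refcount; closes all connections when it reaches zero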


class SshService(Service, metaclass=ABCMeta):
    """Base class for SSH services."""

    # AsyncSSH requires an asyncio event loop to be running to work.
    # However, running that event loop blocks the main thread.
    # To avoid having to change our entire API to use async/await, all the way
    # up the stack, we run the event loop that runs any async code in a
    # background thread and submit async code to it using
    # asyncio.run_coroutine_threadsafe, interacting with Futures after that.
    # This is a bit of a hack, but it works for now.
    #
    # The event loop is created on demand and shared across all SshService
    # instances, hence we need to lock it when doing the creation/cleanup,
    # or later, during context enter and exit.
    #
    # We ran tests to ensure that multiple requests can still be executing
    # concurrently inside that event loop so there should be no practical
    # performance loss for our initial cases even with just a single background
    # thread running the event loop.
    #
    # Note: the tests were run to confirm that this works with two threads.
    # Using a larger thread pool requires a bit more work since asyncssh
    # requires that run() requests are submitted to the same event loop handler
    # that the connection was made on.
    # In that case, each background thread should get its own SshClientCache.

    # Maintain just one event loop thread for all SshService instances,
    # but only keep it running while they are within a context.
    _EVENT_LOOP_CONTEXT = EventLoopContext()
    _EVENT_LOOP_THREAD_SSH_CLIENT_CACHE = SshClientCache()

    _REQUEST_TIMEOUT: float | None = None  # seconds

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
        methods: dict[str, Callable] | list[Callable] | None = None,
    ):
        super().__init__(config, global_config, parent, methods)

        # Make sure that the values we allow overriding on a per-connection
        # basis are present in the config so merge_parameters can do its thing.
        self.config.setdefault("ssh_port", None)
        assert isinstance(self.config["ssh_port"], (int, type(None)))
        self.config.setdefault("ssh_username", None)
        assert isinstance(self.config["ssh_username"], (str, type(None)))
        self.config.setdefault("ssh_priv_key_path", None)
        assert isinstance(self.config["ssh_priv_key_path"], (str, type(None)))

        # None can be used to disable the request timeout.
        self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT)
        self._request_timeout = nullable(float, self._request_timeout)

        # Prep an initial connect_params.
        self._connect_params: dict = {
            # In general scripted commands shouldn't need a pty and having one
            # available can confuse some commands, though we may need to make
            # this configurable in the future.
            "request_pty": False,
            # By default disable known_hosts checking (since most VMs are expected
            # to be dynamically created).
            "known_hosts": None,
        }

        if "ssh_known_hosts_file" in self.config:
            self._connect_params["known_hosts"] = self.config.get("ssh_known_hosts_file", None)
            if isinstance(self._connect_params["known_hosts"], str):
                known_hosts_file = os.path.expanduser(self._connect_params["known_hosts"])
                if not os.path.exists(known_hosts_file):
                    raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist")
                self._connect_params["known_hosts"] = known_hosts_file
        if self._connect_params["known_hosts"] is None:
            _LOG.info("%s known_hosts checking is disabled per config.", self)

        if "ssh_keepalive_interval" in self.config:
            keepalive_interval = self.config.get("ssh_keepalive_interval")
            self._connect_params["keepalive_interval"] = nullable(int, keepalive_interval)

    def _enter_context(self) -> "SshService":
        # Start the background thread if it's not already running.
        assert not self._in_context
        SshService._EVENT_LOOP_CONTEXT.enter()
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.enter()
        super()._enter_context()
        return self

    def _exit_context(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        # Stop the background thread if it's not needed anymore and potentially
        # cleanup the cache as well.
        assert self._in_context
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.exit()
        SshService._EVENT_LOOP_CONTEXT.exit()
        return super()._exit_context(ex_type, ex_val, ex_tb)

    @classmethod
    def clear_client_cache(cls) -> None:
        """
        Clears the cache of client connections.

        Note: This may cause in-flight operations to fail.
        """
        cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup()

    def _run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType:
        """
        Runs the given coroutine in the background event loop thread.

        Parameters
        ----------
        coro : Coroutine[Any, Any, CoroReturnType]
            The coroutine to run.

        Returns
        -------
        Future[CoroReturnType]
            A future that will be completed when the coroutine completes.
        """
        assert self._in_context
        return self._EVENT_LOOP_CONTEXT.run_coroutine(coro)

    def _get_connect_params(self, params: dict) -> dict:
        """
        Produces a dict of connection parameters for asyncssh.create_connection.

        Parameters
        ----------
        params : dict
            Additional connection parameters specific to this host.

        Returns
        -------
        dict
            A dict of connection parameters for asyncssh.create_connection.
        """
        # Setup default connect_params dict for all SshClients we might need to create.

        # Note: None is an acceptable value for several of these, in which case
        # reasonable defaults or values from ~/.ssh/config will take effect.

        # Start with the base config params.
        connect_params = self._connect_params.copy()

        connect_params["host"] = params["ssh_hostname"]  # required

        if params.get("ssh_port"):
            connect_params["port"] = int(params.pop("ssh_port"))
        elif self.config["ssh_port"]:
            connect_params["port"] = int(self.config["ssh_port"])

        if "ssh_username" in params:
            connect_params["username"] = str(params.pop("ssh_username"))
        elif self.config["ssh_username"]:
            connect_params["username"] = str(self.config["ssh_username"])

        priv_key_file: str | None = params.get(
            "ssh_priv_key_path",
            self.config["ssh_priv_key_path"],
        )
        if priv_key_file:
            priv_key_file = os.path.expanduser(priv_key_file)
            if not os.path.exists(priv_key_file):
                raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist")
            connect_params["client_keys"] = [priv_key_file]

        return connect_params

    async def _get_client_connection(self, params: dict) -> tuple[SSHClientConnection, SshClient]:
        """
        Gets a (possibly cached) SshClient (connection) for the given connection params.

        Parameters
        ----------
        params : dict
            Optional override connection parameters.

        Returns
        -------
        tuple[SSHClientConnection, SshClient]
            The connection and client objects.
        """
        assert self._in_context
        return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(
            self._get_connect_params(params)
        )
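

# Hedged sketch (not part of the original module): how a concrete subclass might
# combine _get_client_connection() with _run_coroutine(). The async work is
# written as a coroutine, submitted to the shared background event loop, and the
# synchronous caller blocks on the returned future (assumed here to be a
# concurrent.futures.Future, per asyncio.run_coroutine_threadsafe). The class
# name, method names, and parameters are hypothetical.
class ExampleSshExecService(SshService):
    """Usage sketch only; not a real mlos_bench service."""

    async def _remote_hostname(self, params: dict) -> str:
        connection, _client = await self._get_client_connection(params)
        result = await connection.run("hostname", check=True)
        return str(result.stdout).strip()

    def remote_hostname(self, params: dict) -> str:
        """Blocking wrapper that runs the coroutine on the background event loop."""
        future = self._run_coroutine(self._remote_hostname(params))
        return future.result(timeout=self._request_timeout)


# Hypothetical usage: the service must be used within a context so that the
# background event loop thread and connection cache are running, e.g.:
#
#   with ExampleSshExecService(config={"ssh_username": "mlos"}) as svc:
#       print(svc.remote_hostname({"ssh_hostname": "203.0.113.10"}))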