Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py: 89%
157 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
"""A collection of functions for interacting with SSH servers as file shares."""
7import logging
8import os
9from abc import ABCMeta
10from asyncio import Event as CoroEvent
11from asyncio import Lock as CoroLock
12from threading import current_thread
13from types import TracebackType
14from typing import (
15 Any,
16 Callable,
17 Coroutine,
18 Dict,
19 List,
20 Literal,
21 Optional,
22 Tuple,
23 Type,
24 Union,
25)
26from warnings import warn
28import asyncssh
29from asyncssh.connection import SSHClientConnection
31from mlos_bench.event_loop_context import (
32 CoroReturnType,
33 EventLoopContext,
34 FutureReturnType,
35)
36from mlos_bench.services.base_service import Service
37from mlos_bench.util import nullable
39_LOG = logging.getLogger(__name__)
class SshClient(asyncssh.SSHClient):
    """
    Wrapper around SSHClient to help provide connection caching and reconnect logic.

    Used by the SshService to try and maintain a single connection to hosts, handle
    reconnects if possible, and use that to run commands rather than reconnect for each
    command.
    """

    # Sentinel connection ids: used before a connection is established and after
    # one is lost, respectively.
    _CONNECTION_PENDING = "INIT"
    _CONNECTION_LOST = "LOST"

    def __init__(self, *args: Any, **kwargs: Any):
        # Unique id repr of the current connection (see id_from_connection), or
        # one of the sentinel values above while no connection is live.
        self._connection_id: str = SshClient._CONNECTION_PENDING
        # The underlying connection; None until connection_made() fires and
        # again after connection_lost().
        self._connection: Optional[SSHClientConnection] = None
        # Set whenever the connection state has settled (made or lost) so that
        # connection() waiters get a definitive answer.
        self._conn_event: CoroEvent = CoroEvent()
        super().__init__(*args, **kwargs)

    def __repr__(self) -> str:
        return self._connection_id

    @staticmethod
    def id_from_connection(connection: SSHClientConnection) -> str:
        """Gets a unique id repr for the connection."""
        # pylint: disable=protected-access
        return f"{connection._username}@{connection._host}:{connection._port}"

    @staticmethod
    def id_from_params(connect_params: dict) -> str:
        """Gets a unique id repr for the connection."""
        # NOTE: presumably intended to match the format of id_from_connection so
        # cache lookups by params line up with ids from live connections — the
        # port/username here may be None when defaults are in effect; verify
        # against SshClientCache usage.
        return (
            f"{connect_params.get('username')}@{connect_params['host']}"
            f":{connect_params.get('port')}"
        )

    def connection_made(self, conn: SSHClientConnection) -> None:
        """
        Override hook provided by asyncssh.SSHClient.

        Changes the connection_id from _CONNECTION_PENDING to a unique id repr.
        """
        # Hold back connection() waiters until the new state is fully recorded.
        self._conn_event.clear()
        _LOG.debug(
            "%s: Connection made by %s: %s",
            current_thread().name,
            conn._options.env,  # pylint: disable=protected-access
            conn,
        )
        self._connection_id = SshClient.id_from_connection(conn)
        self._connection = conn
        self._conn_event.set()
        return super().connection_made(conn)

    def connection_lost(self, exc: Optional[Exception]) -> None:
        """
        Override hook provided by asyncssh.SSHClient.

        Records the connection as lost (exc is None on a graceful disconnect)
        and wakes any connection() waiters.
        """
        self._conn_event.clear()
        _LOG.debug("%s: %s", current_thread().name, "connection_lost")
        if exc is None:
            _LOG.debug(
                "%s: gracefully disconnected ssh from %s: %s",
                current_thread().name,
                self._connection_id,
                exc,
            )
        else:
            _LOG.debug(
                "%s: ssh connection lost on %s: %s",
                current_thread().name,
                self._connection_id,
                exc,
            )
        self._connection_id = SshClient._CONNECTION_LOST
        self._connection = None
        self._conn_event.set()
        return super().connection_lost(exc)

    async def connection(self) -> Optional[SSHClientConnection]:
        """Waits for and returns the SSHClientConnection to be established or lost."""
        # Returns None when the connection was lost rather than established.
        _LOG.debug("%s: Waiting for connection to be available.", current_thread().name)
        await self._conn_event.wait()
        _LOG.debug("%s: Connection available for %s", current_thread().name, self._connection_id)
        return self._connection
class SshClientCache:
    """
    Reference-counted cache mapping connection ids to live SshClient connections.

    Note: Only one per event loop thread supported.
    See additional details in SshService comments.
    """

    def __init__(self) -> None:
        # Maps connection id -> (connection, client) pairs.
        self._cache: Dict[str, Tuple[SSHClientConnection, SshClient]] = {}
        # Serializes cache lookups/inserts among coroutines on the event loop.
        self._cache_lock = CoroLock()
        # Number of outstanding enter() calls keeping the cache alive.
        self._refcnt: int = 0

    def __str__(self) -> str:
        return str(self._cache)

    def __len__(self) -> int:
        return len(self._cache)

    def enter(self) -> None:
        """
        Manages the cache lifecycle with reference counting.

        To be used in the __enter__ method of a caller's context manager.
        """
        self._refcnt += 1

    def exit(self) -> None:
        """
        Manages the cache lifecycle with reference counting.

        To be used in the __exit__ method of a caller's context manager.
        """
        self._refcnt -= 1
        if self._refcnt > 0:
            # Other contexts are still using the cache; keep it intact.
            return
        self.cleanup()
        if self._cache_lock.locked():
            # A coroutine never released the lock; warn and force-release it.
            warn(RuntimeWarning("SshClientCache lock was still held on exit."))
            self._cache_lock.release()

    async def get_client_connection(
        self,
        connect_params: dict,
    ) -> Tuple[SSHClientConnection, SshClient]:
        """
        Gets a (possibly cached) client connection.

        Parameters
        ----------
        connect_params: dict
            Parameters to pass to asyncssh.create_connection.

        Returns
        -------
        Tuple[SSHClientConnection, SshClient]
            A tuple of (SSHClientConnection, SshClient).
        """
        _LOG.debug("%s: get_client_connection: %s", current_thread().name, connect_params)
        async with self._cache_lock:
            conn_id = SshClient.id_from_params(connect_params)
            cached_client: Union[None, SshClient, asyncssh.SSHClient]
            _, cached_client = self._cache.get(conn_id, (None, None))
            if cached_client:
                _LOG.debug("%s: Checking cached client %s", current_thread().name, conn_id)
                if await cached_client.connection():
                    _LOG.debug("%s: Using cached client %s", current_thread().name, conn_id)
                else:
                    _LOG.debug(
                        "%s: Removing stale client connection %s from cache.",
                        current_thread().name,
                        conn_id,
                    )
                    # Evict the dead entry and fall through to reconnect below.
                    self._cache.pop(conn_id)
            if conn_id not in self._cache:
                _LOG.debug(
                    "%s: Establishing client connection to %s",
                    current_thread().name,
                    conn_id,
                )
                new_conn, new_client = await asyncssh.create_connection(SshClient, **connect_params)
                assert isinstance(new_client, SshClient)
                self._cache[conn_id] = (new_conn, new_client)
                _LOG.debug("%s: Created connection to %s.", current_thread().name, conn_id)
            return self._cache[conn_id]

    def cleanup(self) -> None:
        """Closes all cached connections."""
        for conn, _client in self._cache.values():
            conn.close()
        self._cache = {}
class SshService(Service, metaclass=ABCMeta):
    """Base class for SSH services."""

    # AsyncSSH requires an asyncio event loop to be running to work.
    # However, running that event loop blocks the main thread.
    # To avoid having to change our entire API to use async/await, all the way
    # up the stack, we run the event loop that runs any async code in a
    # background thread and submit async code to it using
    # asyncio.run_coroutine_threadsafe, interacting with Futures after that.
    # This is a bit of a hack, but it works for now.
    #
    # The event loop is created on demand and shared across all SshService
    # instances, hence we need to lock it when doing the creation/cleanup,
    # or later, during context enter and exit.
    #
    # We ran tests to ensure that multiple requests can still be executing
    # concurrently inside that event loop so there should be no practical
    # performance loss for our initial cases even with just single background
    # thread running the event loop.
    #
    # Note: the tests were run to confirm that this works with two threads.
    # Using a larger thread pool requires a bit more work since asyncssh
    # requires that run() requests are submitted to the same event loop handler
    # that the connection was made on.
    # In that case, each background thread should get its own SshClientCache.

    # Maintain one just one event loop thread for all SshService instances.
    # But only keep it running while they are within a context.
    _EVENT_LOOP_CONTEXT = EventLoopContext()
    _EVENT_LOOP_THREAD_SSH_CLIENT_CACHE = SshClientCache()

    # Default request timeout (seconds); None disables the timeout.
    # Subclasses may override; per-instance config "ssh_request_timeout" wins.
    _REQUEST_TIMEOUT: Optional[float] = None  # seconds

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional[Service] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new instance of an SSH-based Service.

        Parameters
        ----------
        config : Optional[Dict[str, Any]]
            Service config; recognized keys include ssh_port, ssh_username,
            ssh_priv_key_path, ssh_request_timeout, ssh_known_hosts_file,
            and ssh_keepalive_interval.
        global_config : Optional[Dict[str, Any]]
            Passed through to Service.__init__.
        parent : Optional[Service]
            Passed through to Service.__init__.
        methods : Union[Dict[str, Callable], List[Callable], None]
            Passed through to Service.__init__.
        """
        super().__init__(config, global_config, parent, methods)

        # Make sure that the value we allow overriding on a per-connection
        # basis are present in the config so merge_parameters can do its thing.
        self.config.setdefault("ssh_port", None)
        assert isinstance(self.config["ssh_port"], (int, type(None)))
        self.config.setdefault("ssh_username", None)
        assert isinstance(self.config["ssh_username"], (str, type(None)))
        self.config.setdefault("ssh_priv_key_path", None)
        assert isinstance(self.config["ssh_priv_key_path"], (str, type(None)))

        # None can be used to disable the request timeout.
        self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT)
        self._request_timeout = nullable(float, self._request_timeout)

        # Prep an initial connect_params.
        self._connect_params: dict = {
            # In general scripted commands shouldn't need a pty and having one
            # available can confuse some commands, though we may need to make
            # this configurable in the future.
            "request_pty": False,
            # By default disable known_hosts checking (since most VMs expected to be
            # dynamically created).
            "known_hosts": None,
        }
        if "ssh_known_hosts_file" in self.config:
            self._connect_params["known_hosts"] = self.config.get("ssh_known_hosts_file", None)
            if isinstance(self._connect_params["known_hosts"], str):
                known_hosts_file = os.path.expanduser(self._connect_params["known_hosts"])
                if not os.path.exists(known_hosts_file):
                    raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist")
                self._connect_params["known_hosts"] = known_hosts_file
        if self._connect_params["known_hosts"] is None:
            _LOG.info("%s known_hosts checking is disabled per config.", self)

        if "ssh_keepalive_interval" in self.config:
            # An explicit None value disables keepalives.
            keepalive_internal = self.config.get("ssh_keepalive_interval")
            self._connect_params["keepalive_interval"] = nullable(int, keepalive_internal)

    def _enter_context(self) -> "SshService":
        """Starts the shared event loop thread and bumps the cache refcount."""
        # Start the background thread if it's not already running.
        assert not self._in_context
        SshService._EVENT_LOOP_CONTEXT.enter()
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.enter()
        super()._enter_context()
        return self

    def _exit_context(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """Releases the cache and event loop refs taken in _enter_context."""
        # Stop the background thread if it's not needed anymore and potentially
        # cleanup the cache as well.
        assert self._in_context
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.exit()
        SshService._EVENT_LOOP_CONTEXT.exit()
        return super()._exit_context(ex_type, ex_val, ex_tb)

    @classmethod
    def clear_client_cache(cls) -> None:
        """
        Clears the cache of client connections.

        Note: This may cause in flight operations to fail.
        """
        cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup()

    def _run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType:
        """
        Runs the given coroutine in the background event loop thread.

        Parameters
        ----------
        coro : Coroutine[Any, Any, CoroReturnType]
            The coroutine to run.

        Returns
        -------
        Future[CoroReturnType]
            A future that will be completed when the coroutine completes.
        """
        assert self._in_context
        return self._EVENT_LOOP_CONTEXT.run_coroutine(coro)

    def _get_connect_params(self, params: dict) -> dict:
        """
        Produces a dict of connection parameters for asyncssh.create_connection.

        Parameters
        ----------
        params : dict
            Additional connection parameters specific to this host.
            Note: recognized ssh_port/ssh_username keys are popped from it.

        Returns
        -------
        dict
            A dict of connection parameters for asyncssh.create_connection.
        """
        # Setup default connect_params dict for all SshClients we might need to create.

        # Note: None is an acceptable value for several of these, in which case
        # reasonable defaults or values from ~/.ssh/config will take effect.

        # Start with the base config params.
        connect_params = self._connect_params.copy()

        connect_params["host"] = params["ssh_hostname"]  # required

        if params.get("ssh_port"):
            connect_params["port"] = int(params.pop("ssh_port"))
        elif self.config["ssh_port"]:
            connect_params["port"] = int(self.config["ssh_port"])

        if "ssh_username" in params:
            # NOTE(review): an explicit ssh_username=None in params would become
            # the string "None" here — confirm callers always supply a real
            # username when the key is present.
            connect_params["username"] = str(params.pop("ssh_username"))
        elif self.config["ssh_username"]:
            connect_params["username"] = str(self.config["ssh_username"])

        # NOTE: unlike ssh_port/ssh_username above, the key path is read with
        # get() rather than popped from params.
        priv_key_file: Optional[str] = params.get(
            "ssh_priv_key_path",
            self.config["ssh_priv_key_path"],
        )
        if priv_key_file:
            priv_key_file = os.path.expanduser(priv_key_file)
            if not os.path.exists(priv_key_file):
                raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist")
            connect_params["client_keys"] = [priv_key_file]

        return connect_params

    async def _get_client_connection(self, params: dict) -> Tuple[SSHClientConnection, SshClient]:
        """
        Gets a (possibly cached) SshClient (connection) for the given connection params.

        Parameters
        ----------
        params : dict
            Optional override connection parameters.

        Returns
        -------
        Tuple[SSHClientConnection, SshClient]
            The connection and client objects.
        """
        assert self._in_context
        return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(
            self._get_connect_params(params)
        )