Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py: 89%
157 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""A collection of functions for interacting with SSH servers as file shares."""
7import logging
8import os
9from abc import ABCMeta
10from asyncio import Event as CoroEvent
11from asyncio import Lock as CoroLock
12from threading import current_thread
13from types import TracebackType
14from typing import (
15 Any,
16 Callable,
17 Coroutine,
18 Dict,
19 List,
20 Literal,
21 Optional,
22 Tuple,
23 Type,
24 Union,
25)
26from warnings import warn
28import asyncssh
29from asyncssh.connection import SSHClientConnection
31from mlos_bench.event_loop_context import (
32 CoroReturnType,
33 EventLoopContext,
34 FutureReturnType,
35)
36from mlos_bench.services.base_service import Service
37from mlos_bench.util import nullable
# Module-level logger, named after this module via __name__.
_LOG = logging.getLogger(__name__)
class SshClient(asyncssh.SSHClient):
    """
    Wrapper around SSHClient to help provide connection caching and reconnect logic.

    Used by the SshService to try and maintain a single connection to hosts, handle
    reconnects if possible, and use that to run commands rather than reconnect for each
    command.
    """

    # Connection id reported before connection_made() has been called.
    _CONNECTION_PENDING = "INIT"
    # Connection id reported after connection_lost() has been called.
    _CONNECTION_LOST = "LOST"

    def __init__(self, *args: Any, **kwargs: Any):
        # Annotations on *args/**kwargs describe each individual argument, so
        # they are Any (not tuple/dict); everything is forwarded unchanged to
        # asyncssh.SSHClient.
        self._connection_id: str = SshClient._CONNECTION_PENDING
        self._connection: Optional[SSHClientConnection] = None
        # Set whenever the connection state has settled (made or lost) so that
        # coroutines awaiting connection() can proceed.
        self._conn_event: CoroEvent = CoroEvent()
        super().__init__(*args, **kwargs)

    def __repr__(self) -> str:
        return self._connection_id

    @staticmethod
    def id_from_connection(connection: SSHClientConnection) -> str:
        """Gets a unique id repr (``user@host:port``) for the connection."""
        # pylint: disable=protected-access
        return f"{connection._username}@{connection._host}:{connection._port}"

    @staticmethod
    def id_from_params(connect_params: dict) -> str:
        """
        Gets a unique id repr (``user@host:port``) from connection parameters.

        "host" is required; missing "username" or "port" entries are rendered
        as ``None`` in the id.
        """
        return (
            f"""{connect_params.get("username")}@{connect_params["host"]}"""
            f""":{connect_params.get("port")}"""
        )

    def connection_made(self, conn: SSHClientConnection) -> None:
        """
        Override hook provided by asyncssh.SSHClient.

        Changes the connection_id from _CONNECTION_PENDING to a unique id repr,
        records the connection, and signals any waiters in connection().
        """
        self._conn_event.clear()
        _LOG.debug(
            "%s: Connection made by %s: %s",
            current_thread().name,
            conn._options.env,  # pylint: disable=protected-access
            conn,
        )
        self._connection_id = SshClient.id_from_connection(conn)
        self._connection = conn
        self._conn_event.set()
        return super().connection_made(conn)

    def connection_lost(self, exc: Optional[Exception]) -> None:
        """
        Override hook provided by asyncssh.SSHClient.

        Marks the client as _CONNECTION_LOST, drops the connection reference,
        and signals waiters so that connection() returns None.

        Parameters
        ----------
        exc : Optional[Exception]
            None for a graceful disconnect, otherwise the reason for the loss.
        """
        self._conn_event.clear()
        _LOG.debug("%s: %s", current_thread().name, "connection_lost")
        if exc is None:
            _LOG.debug(
                "%s: gracefully disconnected ssh from %s: %s",
                current_thread().name,
                self._connection_id,
                exc,
            )
        else:
            _LOG.debug(
                "%s: ssh connection lost on %s: %s",
                current_thread().name,
                self._connection_id,
                exc,
            )
        self._connection_id = SshClient._CONNECTION_LOST
        self._connection = None
        self._conn_event.set()
        return super().connection_lost(exc)

    async def connection(self) -> Optional[SSHClientConnection]:
        """
        Waits for the connection to be established or lost.

        Returns
        -------
        Optional[asyncssh.connection.SSHClientConnection]
            The live connection, or None if it has been lost.
        """
        _LOG.debug("%s: Waiting for connection to be available.", current_thread().name)
        await self._conn_event.wait()
        _LOG.debug("%s: Connection available for %s", current_thread().name, self._connection_id)
        return self._connection
class SshClientCache:
    """
    Manages a cache of SshClient connections.

    Note: Only one per event loop thread supported.
    See additional details in SshService comments.
    """

    def __init__(self) -> None:
        # Maps a connection id (see SshClient.id_from_params) to its
        # (connection, client) pair.
        self._cache: Dict[str, Tuple[SSHClientConnection, SshClient]] = {}
        # Serializes cache lookups/creation inside get_client_connection().
        self._cache_lock = CoroLock()
        # Number of enter() calls not yet matched by an exit().
        self._refcnt: int = 0

    def __str__(self) -> str:
        return str(self._cache)

    def __len__(self) -> int:
        return len(self._cache)

    def enter(self) -> None:
        """
        Manages the cache lifecycle with reference counting.

        To be used in the __enter__ method of a caller's context manager.
        """
        self._refcnt += 1

    def exit(self) -> None:
        """
        Manages the cache lifecycle with reference counting.

        To be used in the __exit__ method of a caller's context manager.
        When the last reference exits, all cached connections are closed.
        The count is clamped at zero so an unbalanced exit() cannot make a
        later, still-active context trigger an early cleanup.
        """
        self._refcnt = max(self._refcnt - 1, 0)
        if self._refcnt == 0:
            self.cleanup()
            if self._cache_lock.locked():
                warn(RuntimeWarning("SshClientCache lock was still held on exit."))
                # Only release when held: releasing an unlocked asyncio.Lock raises.
                self._cache_lock.release()

    async def get_client_connection(
        self,
        connect_params: dict,
    ) -> Tuple[SSHClientConnection, SshClient]:
        """
        Gets a (possibly cached) client connection.

        A cached client whose connection has been lost is evicted and a new
        connection is created in its place.

        Parameters
        ----------
        connect_params: dict
            Parameters to pass to asyncssh.create_connection.

        Returns
        -------
        Tuple[asyncssh.connection.SSHClientConnection, SshClient]
            A tuple of (SSHClientConnection, SshClient).
        """
        _LOG.debug("%s: get_client_connection: %s", current_thread().name, connect_params)
        async with self._cache_lock:
            connection_id = SshClient.id_from_params(connect_params)
            client: Union[None, SshClient, asyncssh.SSHClient]
            _, client = self._cache.get(connection_id, (None, None))
            if client:
                _LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id)
                connection = await client.connection()
                if not connection:
                    _LOG.debug(
                        "%s: Removing stale client connection %s from cache.",
                        current_thread().name,
                        connection_id,
                    )
                    self._cache.pop(connection_id)
                    # Try to reconnect next.
                else:
                    _LOG.debug("%s: Using cached client %s", current_thread().name, connection_id)
            if connection_id not in self._cache:
                _LOG.debug(
                    "%s: Establishing client connection to %s",
                    current_thread().name,
                    connection_id,
                )
                connection, client = await asyncssh.create_connection(SshClient, **connect_params)
                assert isinstance(client, SshClient)
                self._cache[connection_id] = (connection, client)
                _LOG.debug("%s: Created connection to %s.", current_thread().name, connection_id)
            return self._cache[connection_id]

    def cleanup(self) -> None:
        """
        Closes all cached connections and empties the cache.

        Note: close() is requested on each connection but not awaited here.
        """
        for connection, _ in self._cache.values():
            connection.close()
        self._cache = {}
class SshService(Service, metaclass=ABCMeta):
    """Base class for SSH services."""

    # AsyncSSH requires an asyncio event loop to be running to work.
    # However, running that event loop blocks the main thread.
    # To avoid having to change our entire API to use async/await, all the way
    # up the stack, we run the event loop that runs any async code in a
    # background thread and submit async code to it using
    # asyncio.run_coroutine_threadsafe, interacting with Futures after that.
    # This is a bit of a hack, but it works for now.
    #
    # The event loop is created on demand and shared across all SshService
    # instances, hence we need to lock it when doing the creation/cleanup,
    # or later, during context enter and exit.
    #
    # We ran tests to ensure that multiple requests can still be executing
    # concurrently inside that event loop so there should be no practical
    # performance loss for our initial cases even with just a single background
    # thread running the event loop.
    #
    # Note: the tests were run to confirm that this works with two threads.
    # Using a larger thread pool requires a bit more work since asyncssh
    # requires that run() requests are submitted to the same event loop handler
    # that the connection was made on.
    # In that case, each background thread should get its own SshClientCache.

    # Maintain just one event loop thread for all SshService instances.
    # But only keep it running while they are within a context.
    _EVENT_LOOP_CONTEXT = EventLoopContext()
    _EVENT_LOOP_THREAD_SSH_CLIENT_CACHE = SshClientCache()

    _REQUEST_TIMEOUT: Optional[float] = None  # seconds

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional[Service] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        super().__init__(config, global_config, parent, methods)

        # Make sure that the values we allow overriding on a per-connection
        # basis are present in the config so merge_parameters can do its thing.
        self.config.setdefault("ssh_port", None)
        assert isinstance(self.config["ssh_port"], (int, type(None)))
        self.config.setdefault("ssh_username", None)
        assert isinstance(self.config["ssh_username"], (str, type(None)))
        self.config.setdefault("ssh_priv_key_path", None)
        assert isinstance(self.config["ssh_priv_key_path"], (str, type(None)))

        # None can be used to disable the request timeout.
        self._request_timeout = nullable(
            float,
            self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT),
        )

        # Prep an initial connect_params.
        self._connect_params: dict = {
            # In general scripted commands shouldn't need a pty and having one
            # available can confuse some commands, though we may need to make
            # this configurable in the future.
            "request_pty": False,
            # By default disable known_hosts checking (since most VMs are expected
            # to be dynamically created).
            "known_hosts": None,
        }

        if "ssh_known_hosts_file" in self.config:
            self._connect_params["known_hosts"] = self.config.get("ssh_known_hosts_file", None)
            if isinstance(self._connect_params["known_hosts"], str):
                known_hosts_file = os.path.expanduser(self._connect_params["known_hosts"])
                if not os.path.exists(known_hosts_file):
                    raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist")
                self._connect_params["known_hosts"] = known_hosts_file
        if self._connect_params["known_hosts"] is None:
            _LOG.info("%s known_hosts checking is disabled per config.", self)

        if "ssh_keepalive_interval" in self.config:
            keepalive_interval = self.config.get("ssh_keepalive_interval")
            self._connect_params["keepalive_interval"] = nullable(int, keepalive_interval)

    def _enter_context(self) -> "SshService":
        # Start the background thread if it's not already running.
        assert not self._in_context
        SshService._EVENT_LOOP_CONTEXT.enter()
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.enter()
        super()._enter_context()
        return self

    def _exit_context(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        # Stop the background thread if it's not needed anymore and potentially
        # cleanup the cache as well.
        assert self._in_context
        SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.exit()
        SshService._EVENT_LOOP_CONTEXT.exit()
        return super()._exit_context(ex_type, ex_val, ex_tb)

    @classmethod
    def clear_client_cache(cls) -> None:
        """
        Clears the cache of client connections.

        Note: This may cause in flight operations to fail.
        """
        cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup()

    def _run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType:
        """
        Runs the given coroutine in the background event loop thread.

        Parameters
        ----------
        coro : Coroutine[Any, Any, CoroReturnType]
            The coroutine to run.

        Returns
        -------
        Future[CoroReturnType]
            A future that will be completed when the coroutine completes.
        """
        assert self._in_context
        return self._EVENT_LOOP_CONTEXT.run_coroutine(coro)

    def _get_connect_params(self, params: dict) -> dict:
        """
        Produces a dict of connection parameters for asyncssh.create_connection.

        Per-host values in ``params`` ("ssh_port", "ssh_username",
        "ssh_priv_key_path") take precedence over the service config; missing
        or empty (e.g. None) per-host values fall back to the service config.
        ``params`` itself is left unmodified so it can be reused.

        Parameters
        ----------
        params : dict
            Additional connection parameters specific to this host.
            Must contain "ssh_hostname".

        Returns
        -------
        dict
            A dict of connection parameters for asyncssh.create_connection.

        Raises
        ------
        KeyError
            If "ssh_hostname" is missing from ``params``.
        ValueError
            If the resolved private key file does not exist.
        """
        # Note: None is an acceptable value for several of these, in which case
        # reasonable defaults or values from ~/.ssh/config will take effect.

        # Start with the base config params shared by all SshClients.
        connect_params = self._connect_params.copy()

        connect_params["host"] = params["ssh_hostname"]  # required

        # Read (rather than pop) the overrides so the caller's dict is not
        # mutated; a falsy per-host value (e.g. None from merge_parameters
        # defaults) falls back to the service config.
        port = params.get("ssh_port") or self.config["ssh_port"]
        if port:
            connect_params["port"] = int(port)

        username = params.get("ssh_username") or self.config["ssh_username"]
        if username:
            connect_params["username"] = str(username)

        priv_key_file: Optional[str] = (
            params.get("ssh_priv_key_path") or self.config["ssh_priv_key_path"]
        )
        if priv_key_file:
            priv_key_file = os.path.expanduser(priv_key_file)
            if not os.path.exists(priv_key_file):
                raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist")
            connect_params["client_keys"] = [priv_key_file]

        return connect_params

    async def _get_client_connection(self, params: dict) -> Tuple[SSHClientConnection, SshClient]:
        """
        Gets a (possibly cached) SshClient (connection) for the given connection params.

        Parameters
        ----------
        params : dict
            Per-host connection parameters ("ssh_hostname" is required; other
            ssh_* keys optionally override the service config).

        Returns
        -------
        Tuple[SSHClientConnection, SshClient]
            The connection and client objects.
        """
        assert self._in_context
        return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(
            self._get_connect_params(params)
        )