Coverage for mlos_bench/mlos_bench/tests/__init__.py: 85%
60 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Tests for mlos_bench.
8Used to make mypy happy about multiple conftest.py modules.
9"""
10import filecmp
11import os
12import shutil
13import socket
14from datetime import tzinfo
15from logging import debug, warning
16from subprocess import run
17from typing import List, Optional
19import pytest
20import pytz
22from mlos_bench.util import get_class_from_name, nullable
# Time zone names exercised by the tests.
ZONE_NAMES: List[Optional[str]] = [
    # Explicit time zones.
    "UTC",
    "America/Chicago",
    "America/Los_Angeles",
    # None stands for the implicit local time zone.
    None,
]
# One tzinfo per entry of ZONE_NAMES, in the same order.
# NOTE: `nullable` presumably passes the None entry through unchanged
# (the list is typed Optional[tzinfo]).
ZONE_INFO: List[Optional[tzinfo]] = [nullable(pytz.timezone, tz_name) for tz_name in ZONE_NAMES]
# A decorator for tests that require docker.
# Use with @requires_docker above a test_...() function.


def _docker_supports_linux_platform() -> bool:
    """
    Check whether the local docker builder can target the linux platform.

    Runs ``docker builder inspect default`` (falling back to
    ``docker buildx inspect default``) and looks for a "Platform" line that
    mentions "linux".

    Returns
    -------
    bool
        True if the inspect command succeeded and reported a linux platform,
        False otherwise.
    """
    # Fixed command string (no external input); the shell is only used for `||`.
    cmd = run(
        "docker builder inspect default || docker buildx inspect default",
        shell=True,
        check=False,
        capture_output=True,
    )
    if cmd.returncode != 0:
        return False
    # Decode leniently so unexpected bytes in docker's output cannot make
    # importing the whole test package fail.
    stdout = cmd.stdout.decode(errors="replace")
    return any("Platform" in line and "linux" in line for line in stdout.splitlines())


DOCKER = shutil.which("docker")
if DOCKER and not _docker_supports_linux_platform():
    debug("Docker is available but missing support for targeting linux platform.")
    DOCKER = None
requires_docker = pytest.mark.skipif(
    not DOCKER,
    reason="Docker with Linux support is not available on this system.",
)
# Path to the ssh client executable, or None when it is not found on PATH.
SSH = shutil.which("ssh")

# Decorator that skips a test when no ssh client is available.
# Use with @requires_ssh above a test_...() function.
requires_ssh = pytest.mark.skipif(
    not SSH,
    reason="ssh is not available on this system.",
)
# Shared random seed for the tests. Using one fixed value everywhere avoids
# chasing race conditions and seeds intermingling across tests that run in
# non-deterministic parallel orders.
SEED = 42
# NOTE: the global numpy RNG is not seeded at import time; tests that rely on it
# can call `np.random.seed(SEED)` themselves.
def try_resolve_class_name(class_name: Optional[str]) -> Optional[str]:
    """
    Resolve a class name to the fully qualified name of the class it refers to.

    Parameters
    ----------
    class_name : Optional[str]
        Name of the class to look up with ``get_class_from_name``.

    Returns
    -------
    Optional[str]
        ``<module>.<ClassName>`` of the resolved class, or None when
        `class_name` is None or the lookup fails.
    """
    resolved: Optional[str] = None
    if class_name is not None:
        try:
            cls = get_class_from_name(class_name)
            resolved = ".".join((cls.__module__, cls.__name__))
        except (ValueError, AttributeError, ModuleNotFoundError, ImportError):
            resolved = None
    return resolved
def check_class_name(obj: object, expected_class_name: str) -> bool:
    """
    Check whether the fully qualified class name of `obj` matches a given name.

    Parameters
    ----------
    obj : object
        Object whose class (``obj.__class__``) is inspected.
    expected_class_name : str
        Class name, resolved with ``try_resolve_class_name`` before comparing.

    Returns
    -------
    bool
        True if ``<module>.<ClassName>`` of `obj` equals the resolved name;
        False otherwise, including when the name cannot be resolved.
    """
    klass = obj.__class__
    actual_name = ".".join((klass.__module__, klass.__name__))
    expected_name = try_resolve_class_name(expected_class_name)
    return actual_name == expected_name
def check_socket(host: str, port: int, timeout: float = 1.0) -> bool:
    """
    Test to see if a socket is open.

    Attempts a TCP connection to ``(host, port)`` and closes it immediately.
    Every address `host` resolves to (IPv4 or IPv6) is tried. Any connection
    or name-resolution failure is reported as False rather than raised.

    Parameters
    ----------
    host : str
        Host name or IP address to connect to.
    port : int
        TCP port to connect to.
    timeout : float
        Connection timeout in seconds, applied to each address attempted.

    Returns
    -------
    bool
        True if a connection could be established, False otherwise.
    """
    try:
        # create_connection resolves `host` via getaddrinfo, so it handles both
        # address families. socket.gaierror and socket.timeout are OSErrors too,
        # so an unresolvable host no longer escapes as an exception.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False
def resolve_host_name(host: str) -> Optional[str]:
    """
    Resolves the host name to an IP address.

    Parameters
    ----------
    host : str
        Host name (or IPv4 address literal) to resolve.

    Returns
    -------
    Optional[str]
        The IPv4 address as a string (``socket.gethostbyname`` only returns
        IPv4), or None if `host` cannot be resolved.
    """
    try:
        return socket.gethostbyname(host)
    except (socket.gaierror, UnicodeError):
        # gaierror: the lookup itself failed.
        # UnicodeError: the name is not a valid host name at all (e.g. a label
        # longer than 63 characters fails IDNA encoding before any lookup).
        return None
def are_dir_trees_equal(dir1: str, dir2: str) -> bool:
    """
    Compare two directories recursively.

    Files in each directory are assumed to be equal if their names and contents
    are equal. Entries that exist in both trees but with different types
    (e.g. a file in one and a directory in the other) make the trees unequal.

    Parameters
    ----------
    dir1 : str
        First directory path.
    dir2 : str
        Second directory path.

    Returns
    -------
    bool
        True if the directory trees are the same and there were no errors while
        accessing the directories or files, False otherwise.
    """
    # See Also: https://stackoverflow.com/a/6681395
    dirs_cmp = filecmp.dircmp(dir1, dir2)
    # common_funny holds names present in both trees whose types differ (or that
    # could not be stat'ed). Such names appear in neither common_files nor
    # common_dirs, so without this check they would be silently skipped.
    if (
        dirs_cmp.left_only
        or dirs_cmp.right_only
        or dirs_cmp.common_funny
        or dirs_cmp.funny_files
    ):
        warning(
            f"Found differences in dir trees {dir1}, {dir2}:\n"
            f"left only: {dirs_cmp.left_only}\n"
            f"right only: {dirs_cmp.right_only}\n"
            f"type mismatch or stat error: {dirs_cmp.common_funny}\n"
            f"uncomparable files: {dirs_cmp.funny_files}"
        )
        return False
    # Compare file contents byte-by-byte (shallow=False ignores os.stat signatures).
    (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False)
    if mismatch or errors:
        warning(f"Found differences in files:\n{mismatch}\n{errors}")
        return False
    # all() short-circuits on the first differing subdirectory.
    return all(
        are_dir_trees_equal(os.path.join(dir1, sub_dir), os.path.join(dir2, sub_dir))
        for sub_dir in dirs_cmp.common_dirs
    )