Coverage for mlos_bench/mlos_bench/tests/__init__.py: 83%
60 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Tests for mlos_bench.
8Used to make mypy happy about multiple conftest.py modules.
9"""
10import filecmp
11import os
12import shutil
13import socket
14from datetime import tzinfo
15from logging import debug, warning
16from subprocess import run
18import pytest
19import pytz
21from mlos_bench.util import get_class_from_name, nullable
# Time zone names used by the tests; None stands for the implicit local zone.
ZONE_NAMES = [
    # Explicit time zones.
    "UTC",
    "America/Chicago",
    "America/Los_Angeles",
    # Implicit local time zone.
    None,
]
# One entry per ZONE_NAMES item, in the same order, built via pytz.timezone.
# NOTE(review): presumably nullable() passes None through untouched - confirm
# against mlos_bench.util.nullable.
ZONE_INFO: list[tzinfo | None] = [
    nullable(pytz.timezone, tz_name) for tz_name in ZONE_NAMES
]
# Default values keyed by variable name; every entry defaults to None.
BUILT_IN_ENV_VAR_DEFAULTS = {
    var_name: None
    for var_name in (
        "experiment_id",
        "trial_id",
        "trial_runner_id",
    )
}
# A decorator for tests that require docker.
# Use with @requires_docker above a test_...() function.
DOCKER = shutil.which("docker")
if DOCKER:
    # Finding the binary is not enough: the default builder must also report
    # a linux target platform. The classic builder is queried first and
    # buildx is the fallback.
    cmd = run(
        "docker builder inspect default || docker buildx inspect default",
        capture_output=True,
        check=False,
        shell=True,
    )
    stdout = cmd.stdout.decode()
    if not (
        cmd.returncode == 0
        and any("Platform" in line and "linux" in line for line in stdout.splitlines())
    ):
        debug("Docker is available but missing support for targeting linux platform.")
        # Treat docker as unavailable so the dependent tests get skipped.
        DOCKER = None
requires_docker = pytest.mark.skipif(
    not DOCKER,
    reason="Docker with Linux support is not available on this system.",
)
# A decorator for tests that require ssh.
# Use with @requires_ssh above a test_...() function.
SSH = shutil.which("ssh")
requires_ssh = pytest.mark.skipif(
    not SSH,
    reason="ssh is not available on this system.",
)
# A single seed shared by all tests. Tests may run in parallel in a
# non-deterministic order, so one common value avoids having to track down
# race conditions and seeds intermingling across tests.
SEED = 42

# import numpy as np
# np.random.seed(SEED)
73def try_resolve_class_name(class_name: str | None) -> str | None:
74 """Gets the full class name from the given name or None on error."""
75 if class_name is None:
76 return None
77 try:
78 the_class = get_class_from_name(class_name)
79 return the_class.__module__ + "." + the_class.__name__
80 except (ValueError, AttributeError, ModuleNotFoundError, ImportError):
81 return None
def check_class_name(obj: object, expected_class_name: str) -> bool:
    """
    Check whether an object's fully-qualified class name matches a given name.

    Parameters
    ----------
    obj : object
        Object whose class is inspected.
    expected_class_name : str
        Class name to resolve with ``try_resolve_class_name`` and compare to.

    Returns
    -------
    bool
        True if ``module.ClassName`` of ``obj`` equals the resolved name.
    """
    obj_class = obj.__class__
    actual_name = ".".join((obj_class.__module__, obj_class.__name__))
    expected_name = try_resolve_class_name(expected_class_name)
    return actual_name == expected_name
def check_socket(host: str, port: int, timeout: float = 1.0) -> bool:
    """
    Check whether a TCP connection to the given host and port can be opened.

    Parameters
    ----------
    host : str
        Host name or IPv4 address to connect to.
    port : int
        TCP port to connect to.
    timeout : float
        Socket timeout in seconds.

    Returns
    -------
    bool
        True if the connection attempt succeeded, False otherwise.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.settimeout(timeout)  # seconds
        # connect_ex() reports failures as a non-zero error code.
        return probe.connect_ex((host, port)) == 0
110def resolve_host_name(host: str) -> str | None:
111 """
112 Resolves the host name to an IP address.
114 Parameters
115 ----------
116 host : str
118 Returns
119 -------
120 str
121 """
122 try:
123 return socket.gethostbyname(host)
124 except socket.gaierror:
125 return None
def are_dir_trees_equal(dir1: str, dir2: str) -> bool:
    """
    Compare two directories recursively.

    Files in each directory are assumed to be equal if their names and
    contents are equal. The comparison uses ``filecmp.dircmp`` with its
    default ignore and hide lists.

    Parameters
    ----------
    dir1 : str
        First directory path.
    dir2 : str
        Second directory path.

    Returns
    -------
    bool
        True if both trees contain the same entries with the same contents,
        False if any entry is missing on one side, differs in kind or
        content, or could not be compared.
    """
    # See Also: https://stackoverflow.com/a/6681395
    dirs_cmp = filecmp.dircmp(dir1, dir2)
    # common_funny holds names present on both sides whose kind differs
    # (e.g. file vs. directory) or that could not be stat'ed. They appear in
    # neither common_files nor common_dirs, so they must be checked here.
    if (
        dirs_cmp.left_only
        or dirs_cmp.right_only
        or dirs_cmp.common_funny
        or dirs_cmp.funny_files
    ):
        warning(
            f"Found differences in dir trees {dir1}, {dir2}:\n"
            f"left only: {dirs_cmp.left_only}\n"
            f"right only: {dirs_cmp.right_only}\n"
            f"type mismatch: {dirs_cmp.common_funny}\n"
            f"not comparable: {dirs_cmp.funny_files}"
        )
        return False
    # dircmp only compares shallowly (os.stat signatures); compare contents.
    (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False)
    if mismatch or errors:
        warning(f"Found differences in files:\n{mismatch}\n{errors}")
        return False
    # Recurse into sub-directories; all() stops at the first differing one.
    return all(
        are_dir_trees_equal(os.path.join(dir1, common_dir), os.path.join(dir2, common_dir))
        for common_dir in dirs_cmp.common_dirs
    )