Coverage for mlos_bench/mlos_bench/tests/__init__.py: 83%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Tests for mlos_bench. 

7 

8Used to make mypy happy about multiple conftest.py modules. 

9""" 

10import filecmp 

11import os 

12import shutil 

13import socket 

14from datetime import tzinfo 

15from logging import debug, warning 

16from subprocess import run 

17 

18import pytest 

19import pytz 

20 

21from mlos_bench.util import get_class_from_name, nullable 

22 

# Time zones exercised by the tests: explicit IANA zone names, plus None standing
# for the implicit local time zone.
ZONE_NAMES: list[str | None] = [
    # Explicit time zones.
    "UTC",
    "America/Chicago",
    "America/Los_Angeles",
    # Implicit local time zone.
    None,
]
# tzinfo objects matching ZONE_NAMES entry-by-entry. `nullable` is used so the None
# entry is not passed to pytz.timezone (presumably it is kept as None -- see
# mlos_bench.util.nullable).
ZONE_INFO: list[tzinfo | None] = [nullable(pytz.timezone, tz_name) for tz_name in ZONE_NAMES]

32 

# Default values for the built-in environment variables: every one defaults to None.
BUILT_IN_ENV_VAR_DEFAULTS = dict.fromkeys(
    (
        "experiment_id",
        "trial_id",
        "trial_runner_id",
    )
)

38 

def _docker_supports_linux() -> bool:
    """
    Check whether the local docker builder reports support for a linux platform.

    Runs ``docker builder inspect default`` (falling back to
    ``docker buildx inspect default``) and looks for an output line that mentions
    both "Platform" and "linux".

    Returns
    -------
    bool
        True if the inspect command exited successfully and reported a linux
        platform, False otherwise.
    """
    # The command is a fixed string (no external input); the shell is needed only
    # for the `||` fallback between the two docker subcommands.
    proc = run(
        "docker builder inspect default || docker buildx inspect default",
        shell=True,
        check=False,
        capture_output=True,
    )
    if proc.returncode != 0:
        return False
    # Decode leniently: this runs at import time, so a stray non-UTF-8 byte in the
    # docker output must not crash test collection.
    output = proc.stdout.decode(errors="replace")
    return any("Platform" in line and "linux" in line for line in output.splitlines())


# A decorator for tests that require docker.
# Use with @requires_docker above a test_...() function.
# DOCKER is the path to the docker executable, or None when docker is missing or
# cannot target the linux platform.
DOCKER = shutil.which("docker")
if DOCKER and not _docker_supports_linux():
    debug("Docker is available but missing support for targeting linux platform.")
    DOCKER = None
requires_docker = pytest.mark.skipif(
    not DOCKER,
    reason="Docker with Linux support is not available on this system.",
)

59 

# Path to the ssh client executable as found on PATH, or None if it is absent.
SSH = shutil.which("ssh")
# Decorator for tests that need ssh; place @requires_ssh above a test_...() function
# to skip it when no ssh client is installed.
requires_ssh = pytest.mark.skipif(
    not SSH,
    reason="ssh is not available on this system.",
)

64 

# Shared random seed for the tests. Using one fixed, common value avoids having to
# track down race conditions and seeds intermingling across tests that run in a
# non-deterministic, parallel order.
SEED: int = 42
# NOTE(review): numpy's global RNG is not seeded at import time; a
# `np.random.seed(SEED)` call was previously left here commented out.

71 

72 

73def try_resolve_class_name(class_name: str | None) -> str | None: 

74 """Gets the full class name from the given name or None on error.""" 

75 if class_name is None: 

76 return None 

77 try: 

78 the_class = get_class_from_name(class_name) 

79 return the_class.__module__ + "." + the_class.__name__ 

80 except (ValueError, AttributeError, ModuleNotFoundError, ImportError): 

81 return None 

82 

83 

def check_class_name(obj: object, expected_class_name: str) -> bool:
    """
    Tell whether ``obj``'s class matches ``expected_class_name``.

    Parameters
    ----------
    obj : object
        Object whose ``__class__`` is inspected.
    expected_class_name : str
        Class name, resolved via ``try_resolve_class_name``.

    Returns
    -------
    bool
        True if ``obj``'s ``module.ClassName`` equals the resolved name; False
        otherwise, including when the name cannot be resolved.
    """
    obj_cls = obj.__class__
    actual_name = ".".join((obj_cls.__module__, obj_cls.__name__))
    return actual_name == try_resolve_class_name(expected_class_name)

88 

89 

def check_socket(host: str, port: int, timeout: float = 1.0) -> bool:
    """
    Test to see if a socket is open.

    Parameters
    ----------
    host : str
        Host name or IPv4 address to connect to (an AF_INET socket is used).
    port : int
        TCP port to connect to.
    timeout : float
        Connection timeout in seconds.

    Returns
    -------
    bool
        True if a TCP connection to ``(host, port)`` could be established, False
        otherwise, including when ``host`` cannot be resolved.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)  # seconds
        try:
            # connect_ex() reports connection failures as an errno, but problems
            # such as "host not found" (socket.gaierror) are still raised.
            result = sock.connect_ex((host, port))
        except OSError:
            return False
        return result == 0

108 

109 

110def resolve_host_name(host: str) -> str | None: 

111 """ 

112 Resolves the host name to an IP address. 

113 

114 Parameters 

115 ---------- 

116 host : str 

117 

118 Returns 

119 ------- 

120 str 

121 """ 

122 try: 

123 return socket.gethostbyname(host) 

124 except socket.gaierror: 

125 return None 

126 

127 

def are_dir_trees_equal(dir1: str, dir2: str) -> bool:
    """
    Compare two directories recursively.

    Files in each directory are assumed to be equal if their names and contents are
    equal. Entries ignored by ``filecmp.dircmp`` by default are not compared.

    Parameters
    ----------
    dir1 : str
        First directory path.
    dir2 : str
        Second directory path.

    Returns
    -------
    bool
        True if the directory trees are the same and there were no errors while
        accessing the directories or files, False otherwise.
    """
    # See Also: https://stackoverflow.com/a/6681395
    dirs_cmp = filecmp.dircmp(dir1, dir2)
    # common_funny lists names present on both sides whose types differ (e.g. a file
    # in one tree and a directory in the other) or that could not be stat'ed. Such
    # names appear in neither common_files nor common_dirs, so without this check
    # they would be silently skipped and the trees reported equal.
    if (
        dirs_cmp.left_only
        or dirs_cmp.right_only
        or dirs_cmp.funny_files
        or dirs_cmp.common_funny
    ):
        warning(
            f"Found differences in dir trees {dir1}, {dir2}:\n"
            f"left_only={dirs_cmp.left_only}\n"
            f"right_only={dirs_cmp.right_only}\n"
            f"funny_files={dirs_cmp.funny_files}\n"
            f"common_funny={dirs_cmp.common_funny}"
        )
        return False
    # shallow=False forces a byte-by-byte content comparison instead of os.stat()
    # signature matching.
    (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False)
    if mismatch or errors:
        warning(f"Found differences in files:\n{mismatch}\n{errors}")
        return False
    for common_dir in dirs_cmp.common_dirs:
        new_dir1 = os.path.join(dir1, common_dir)
        new_dir2 = os.path.join(dir2, common_dir)
        if not are_dir_trees_equal(new_dir1, new_dir2):
            return False
    return True