Coverage for mlos_bench/mlos_bench/tests/__init__.py: 85%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Tests for mlos_bench. 

7 

8Used to make mypy happy about multiple conftest.py modules. 

9""" 

import filecmp
import os
import shutil
import socket
from datetime import tzinfo
from logging import debug, warning
from subprocess import TimeoutExpired, run
from typing import List, Optional

import pytest
import pytz

from mlos_bench.util import get_class_from_name, nullable

23 

# Time zones to exercise in tests: a few explicit IANA zone names, plus ``None``
# standing for the implicit local time zone.
ZONE_NAMES: List[Optional[str]] = [
    # Explicit time zones.
    "UTC",
    "America/Chicago",
    "America/Los_Angeles",
    # Implicit local time zone.
    None,
]
# One tzinfo (or None) per ZONE_NAMES entry, in the same order. ``nullable`` presumably
# passes the ``None`` entry through instead of calling pytz.timezone on it -- confirm
# against mlos_bench.util.
ZONE_INFO: List[Optional[tzinfo]] = [nullable(pytz.timezone, tz_name) for tz_name in ZONE_NAMES]

35 

36 

# A decorator for tests that require docker.
# Use with @requires_docker above a test_...() function.

# Upper bound (in seconds) on the docker CLI probe below. This module is imported
# during test collection, so an unresponsive docker daemon must not hang it forever.
_DOCKER_PROBE_TIMEOUT_SEC = 60


def _docker_supports_linux_platform() -> bool:
    """
    Check whether the local docker builder can target the linux platform.

    Runs ``docker builder inspect default`` (falling back to
    ``docker buildx inspect default``) and looks for an output line that
    mentions both ``Platform`` and ``linux``.

    Returns
    -------
    bool
        True if the probe exited successfully and reported linux support,
        False otherwise (including when the probe exceeds
        ``_DOCKER_PROBE_TIMEOUT_SEC``).
    """
    try:
        probe = run(
            "docker builder inspect default || docker buildx inspect default",
            shell=True,
            check=False,
            capture_output=True,
            timeout=_DOCKER_PROBE_TIMEOUT_SEC,
        )
    except TimeoutExpired:
        debug("Timed out while probing docker for linux platform support.")
        return False
    if probe.returncode != 0:
        return False
    # Lenient decoding: odd bytes in the CLI output should not abort test collection.
    stdout = probe.stdout.decode(errors="replace")
    return any("Platform" in line and "linux" in line for line in stdout.splitlines())


DOCKER = shutil.which("docker")
if DOCKER and not _docker_supports_linux_platform():
    debug("Docker is available but missing support for targeting linux platform.")
    DOCKER = None
requires_docker = pytest.mark.skipif(
    not DOCKER,
    reason="Docker with Linux support is not available on this system.",
)

57 

# A decorator for tests that require ssh: put ``@requires_ssh`` above a
# ``test_...()`` function to skip it when no ``ssh`` executable is found.
SSH: Optional[str] = shutil.which("ssh")
requires_ssh = pytest.mark.skipif(
    not SSH,
    reason="ssh is not available on this system.",
)

62 

# Shared random seed. Using one fixed value everywhere avoids having to track down
# race conditions and seeds intermingling across tests that run in
# non-deterministic parallel orders.
SEED: int = 42

# NOTE: numpy's global RNG is not seeded at import time; a test that depends on it
# can call ``np.random.seed(SEED)`` itself.

69 

70 

def try_resolve_class_name(class_name: Optional[str]) -> Optional[str]:
    """
    Resolve ``class_name`` to the fully qualified name of the class it refers to.

    Returns ``"<module>.<ClassName>"`` of the loaded class, or None when
    ``class_name`` is None or the class cannot be loaded.
    """
    resolved_name: Optional[str] = None
    if class_name is not None:
        try:
            cls = get_class_from_name(class_name)
            resolved_name = ".".join((cls.__module__, cls.__name__))
        except (ValueError, AttributeError, ImportError):
            # Unknown module or class. ImportError also covers ModuleNotFoundError.
            resolved_name = None
    return resolved_name

80 

81 

def check_class_name(obj: object, expected_class_name: str) -> bool:
    """
    Check whether ``obj``'s class is exactly the class named by ``expected_class_name``.

    The object's fully qualified class name (``module.ClassName``) is compared with
    the result of ``try_resolve_class_name(expected_class_name)``. A name that
    cannot be resolved never matches, and subclasses do not match either.
    """
    cls = obj.__class__
    actual_name = ".".join((cls.__module__, cls.__name__))
    expected_name = try_resolve_class_name(expected_class_name)
    return actual_name == expected_name

86 

87 

def check_socket(host: str, port: int, timeout: float = 1.0) -> bool:
    """
    Test to see if a socket is open.

    Attempts an IPv4 TCP connection to ``host:port``.

    Parameters
    ----------
    host : str
        Host name or IPv4 address to connect to.
    port : int
        TCP port to probe.
    timeout : float
        Connection timeout, in seconds.

    Returns
    -------
    bool
        True if the connection succeeded, False otherwise (including when
        ``host`` cannot be resolved).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)  # seconds
        try:
            # connect_ex() reports connection failures through its return code,
            # but name-resolution failures still raise socket.gaierror.
            return sock.connect_ex((host, port)) == 0
        except socket.gaierror:
            return False

106 

107 

def resolve_host_name(host: str) -> Optional[str]:
    """
    Resolve a host name to an address string via ``socket.gethostbyname``.

    Parameters
    ----------
    host : str
        Host name (or address literal) to resolve.

    Returns
    -------
    Optional[str]
        The resolved address, or None if resolution fails with ``socket.gaierror``.
    """
    address: Optional[str]
    try:
        address = socket.gethostbyname(host)
    except socket.gaierror:
        address = None
    return address

124 

125 

def are_dir_trees_equal(dir1: str, dir2: str) -> bool:
    """
    Compare two directories recursively.

    Files in each directory are assumed to be equal if their names and contents
    are equal. Names in ``filecmp.dircmp``'s default ignore list (e.g. ``.git``,
    ``__pycache__``) are not compared.

    Parameters
    ----------
    dir1 : str
        First directory path.
    dir2 : str
        Second directory path.

    Returns
    -------
    bool
        True if the directory trees are the same and there were no errors while
        accessing the directories or files, False otherwise.
    """
    # See Also: https://stackoverflow.com/a/6681395
    dirs_cmp = filecmp.dircmp(dir1, dir2)
    # common_funny holds names present on both sides but with different types
    # (e.g. a file vs. a directory) or that could not be stat'ed. Such names appear
    # in neither common_files nor common_dirs, so they must be checked explicitly.
    if dirs_cmp.left_only or dirs_cmp.right_only or dirs_cmp.funny_files or dirs_cmp.common_funny:
        warning(
            f"Found differences in dir trees {dir1}, {dir2}:\n"
            f"left_only: {dirs_cmp.left_only}\n"
            f"right_only: {dirs_cmp.right_only}\n"
            f"funny_files: {dirs_cmp.funny_files}\n"
            f"common_funny: {dirs_cmp.common_funny}"
        )
        return False
    # shallow=False: compare file contents, not just os.stat() signatures.
    (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False)
    if mismatch or errors:
        warning(f"Found differences in files in {dir1}, {dir2}:\n{mismatch}\n{errors}")
        return False
    # Recurse into shared subdirectories, stopping at the first difference.
    return all(
        are_dir_trees_equal(os.path.join(dir1, sub_dir), os.path.join(dir2, sub_dir))
        for sub_dir in dirs_cmp.common_dirs
    )