Coverage for mlos_bench/mlos_bench/environments/local/local_fileshare_env.py: 88%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Scheduler-side Environment to run scripts locally and upload/download data to the 

6shared storage. 

7""" 

8 

9import logging 

10from collections.abc import Generator, Iterable, Mapping 

11from datetime import datetime 

12from string import Template 

13from typing import Any 

14 

15from mlos_bench.environments.local.local_env import LocalEnv 

16from mlos_bench.environments.status import Status 

17from mlos_bench.services.base_service import Service 

18from mlos_bench.services.types.fileshare_type import SupportsFileShareOps 

19from mlos_bench.services.types.local_exec_type import SupportsLocalExec 

20from mlos_bench.tunables.tunable import TunableValue 

21from mlos_bench.tunables.tunable_groups import TunableGroups 

22 

# Module-level logger, named after this module's dotted import path.
_LOG = logging.getLogger(name=__name__)

24 

25 

class LocalFileShareEnv(LocalEnv):
    """Scheduler-side Environment that runs scripts locally and uploads/downloads data
    to the shared file storage.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: dict | None = None,
        tunables: TunableGroups | None = None,
        service: Service | None = None,
    ):
        """
        Create a new application environment with a given config.

        Parameters
        ----------
        name : str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the "tunable_params"
            and the "const_args" sections.
            `LocalFileShareEnv` must also have at least some of the following
            parameters: {setup, upload, run, download, teardown,
            dump_params_file, read_results_file}
            The "upload" and "download" entries, if present, are lists of
            {"from": "...", "to": "..."} dicts whose paths may contain $var
            placeholders.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service : Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.). It must support both local execution
            (`SupportsLocalExec`) and file share operations (`SupportsFileShareOps`).
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

        assert self._service is not None and isinstance(
            self._service, SupportsLocalExec
        ), "LocalFileShareEnv requires a service that supports local execution"
        self._local_exec_service: SupportsLocalExec = self._service

        assert self._service is not None and isinstance(
            self._service, SupportsFileShareOps
        ), "LocalFileShareEnv requires a service that supports file upload/download operations"
        self._file_share_service: SupportsFileShareOps = self._service

        # Parse the path templates once here; they are expanded with the current
        # parameter values every time files are actually transferred.
        self._upload = self._template_from_to("upload")
        self._download = self._template_from_to("download")

    def _template_from_to(self, config_key: str) -> list[tuple[Template, Template]]:
        """Convert a list of {"from": "...", "to": "..."} to a list of pairs of
        string.Template objects so that we can plug in self._params into it later.

        A missing `config_key` yields an empty list.
        """
        return [(Template(d["from"]), Template(d["to"])) for d in self.config.get(config_key, [])]

    @staticmethod
    def _expand(
        from_to: Iterable[tuple[Template, Template]],
        params: Mapping[str, TunableValue],
    ) -> Generator[tuple[str, str]]:
        """
        Substitute $var parameters in from/to path templates.

        Placeholders with no matching key in `params` are left as-is
        (`Template.safe_substitute` semantics) rather than raising.

        Return a generator of (str, str) pairs of paths.
        """
        return (
            (path_from.safe_substitute(params), path_to.safe_substitute(params))
            for (path_from, path_to) in from_to
        )

    def setup(self, tunables: TunableGroups, global_config: dict | None = None) -> bool:
        """
        Run setup scripts locally and upload the scripts and data to the shared storage.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable OS and application parameters along with their
            values. In a local environment these could be used to prepare a config
            file on the scheduler prior to transferring it to the remote environment,
            for instance.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if operation is successful, false otherwise.
        """
        self._is_ready = super().setup(tunables, global_config)
        if self._is_ready:
            assert self._temp_dir is not None
            params = self._get_env_params(restrict=False)
            # Let path templates refer to the local working directory as $PWD.
            params["PWD"] = self._temp_dir
            for path_from, path_to in self._expand(self._upload, params):
                # Relative local source paths are also looked up in the temp dir.
                self._file_share_service.upload(
                    self._params,
                    self._config_loader_service.resolve_path(
                        path_from,
                        extra_paths=[self._temp_dir],
                    ),
                    path_to,
                )
        return self._is_ready

    def _download_files(self, ignore_missing: bool = False) -> None:
        """
        Download files from the shared storage.

        Parameters
        ----------
        ignore_missing : bool
            If True, a file that cannot be found on the shared storage
            (`FileNotFoundError`) is logged as a warning and skipped, and the
            remaining files are still downloaded.
            If False (the default), that `FileNotFoundError` is logged and re-raised.
            Any other exception is always logged and re-raised.
        """
        assert self._temp_dir is not None
        params = self._get_env_params(restrict=False)
        # Let path templates refer to the local working directory as $PWD.
        params["PWD"] = self._temp_dir
        for path_from, path_to in self._expand(self._download, params):
            try:
                # Relative local destination paths are also looked up in the temp dir.
                self._file_share_service.download(
                    self._params,
                    path_from,
                    self._config_loader_service.resolve_path(
                        path_to,
                        extra_paths=[self._temp_dir],
                    ),
                )
            except FileNotFoundError:
                _LOG.warning("Cannot download: %s", path_from)
                if not ignore_missing:
                    raise
            except Exception:
                _LOG.exception("Cannot download %s to %s", path_from, path_to)
                raise

    def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]:
        """
        Download benchmark results from the shared storage and run post-processing
        scripts locally.

        A missing file on the shared storage makes this method raise
        `FileNotFoundError`.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
        """
        self._download_files()
        return super().run()

    def status(self) -> tuple[Status, datetime, list[tuple[datetime, str, Any]]]:
        """
        Download whatever files are currently available from the shared storage,
        then return the status as computed by the parent `LocalEnv.status()`.

        Missing files are tolerated here (logged and skipped), since results may
        not have been produced yet when the status is polled.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, list)
            The value returned by the parent class's `status()` unchanged.
        """
        self._download_files(ignore_missing=True)
        return super().status()