Coverage for mlos_bench/mlos_bench/tunables/tunable_groups.py: 99%

88 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""TunableGroups definition.""" 

6import copy 

7from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple, Union 

8 

9from mlos_bench.config.schemas import ConfigSchema 

10from mlos_bench.tunables.covariant_group import CovariantTunableGroup 

11from mlos_bench.tunables.tunable import Tunable, TunableValue 

12 

13 

class TunableGroups:
    """A collection of covariant groups of tunable parameters."""

    def __init__(self, config: Optional[dict] = None):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        config : dict
            Python dict of serialized representation of the covariant tunable groups.
            It is validated against ``ConfigSchema.TUNABLE_PARAMS``.
            An empty collection is created if omitted.
        """
        if config is None:
            config = {}
        ConfigSchema.TUNABLE_PARAMS.validate(config)
        # Index (Tunable id -> CovariantTunableGroup)
        self._index: Dict[str, CovariantTunableGroup] = {}
        # Covariant group name -> CovariantTunableGroup
        self._tunable_groups: Dict[str, CovariantTunableGroup] = {}
        for name, group_config in config.items():
            self._add_group(CovariantTunableGroup(name, group_config))

    def __bool__(self) -> bool:
        """True if the collection contains at least one tunable parameter."""
        return bool(self._index)

    def __len__(self) -> int:
        """Number of tunable parameters across all covariant groups."""
        return len(self._index)

    def __eq__(self, other: object) -> bool:
        """
        Check if two TunableGroups are equal.

        Parameters
        ----------
        other : TunableGroups
            A tunable groups object to compare to.

        Returns
        -------
        is_equal : bool
            True if two TunableGroups are equal.
        """
        if not isinstance(other, TunableGroups):
            return False
        return bool(self._tunable_groups == other._tunable_groups)

    def copy(self) -> "TunableGroups":
        """
        Deep copy of the TunableGroups object.

        Returns
        -------
        tunables : TunableGroups
            A new instance of the TunableGroups object
            that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    def _add_group(self, group: CovariantTunableGroup) -> None:
        """
        Add a CovariantTunableGroup to the current collection.

        Note: non-overlapping groups are expected to be added to the collection.
        The group is stored by reference (not copied).

        Parameters
        ----------
        group : CovariantTunableGroup

        Raises
        ------
        ValueError
            If any tunable of `group` is already present in the collection.
            In that case the collection is left unmodified.
        """
        assert (
            group.name not in self._tunable_groups
        ), f"Duplicate covariant tunable group name {group.name} in {self}"
        # Validate every tunable *before* mutating any state, so that a
        # duplicate does not leave the group half-registered in the index.
        for tunable in group.get_tunables():
            if tunable.name in self._index:
                raise ValueError(
                    f"Duplicate Tunable {tunable.name} from group {group.name} in {self}"
                )
        self._tunable_groups[group.name] = group
        for tunable in group.get_tunables():
            self._index[tunable.name] = group

    def merge(self, tunables: "TunableGroups") -> "TunableGroups":
        """
        Merge the two collections of covariant tunable groups.

        Unlike the dict `update` method, this method does not modify the
        original when overlapping keys are found.
        It is expected to be used to merge the tunable groups referenced by a
        standalone Environment config into a parent CompositeEnvironment,
        for instance.
        This allows self contained, potentially overlapping, but also
        overridable configs to be composed together.

        Groups not yet present are added by reference. A group whose name
        already exists must have the same defaults as the existing one.
        Its current values may differ, and the existing group is kept.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        ValueError
            If a same-named group differs in its defaults, or a new group
            contains a tunable that already exists in this collection.
        """
        # pylint: disable=protected-access
        # Check that covariant groups are unique, else throw an error.
        for group in tunables._tunable_groups.values():
            if group.name not in self._tunable_groups:
                self._add_group(group)
            else:
                # Check that there's no overlap in the tunables.
                # But allow for differing current values.
                if not self._tunable_groups[group.name].equals_defaults(group):
                    raise ValueError(
                        f"Overlapping covariant tunable group name {group.name} "
                        f"in {self._tunable_groups[group.name]} and {tunables}"
                    )
        return self

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the TunableGroups (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the TunableGroups.
        """
        return (
            "{ "
            + ", ".join(
                f"{group.name}::{tunable}"
                # Highest-cost groups first, ties broken by group name.
                for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name))
                for tunable in sorted(group.get_tunables())
            )
            + " }"
        )

    def __contains__(self, tunable: Union[str, Tunable]) -> bool:
        """Checks if the given name/tunable is in this tunable group."""
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return name in self._index

    def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue:
        """Get the current value of a single tunable parameter."""
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return self._index[name][name]

    def __setitem__(
        self,
        tunable: Union[str, Tunable],
        tunable_value: Union[TunableValue, Tunable],
    ) -> TunableValue:
        """Update the current value of a single tunable parameter."""
        # Use double index to make sure we set the is_updated flag of the group
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        value: TunableValue = (
            tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        )
        self._index[name][name] = value
        return self._index[name][name]

    def __iter__(self) -> Generator[Tuple[Tunable, CovariantTunableGroup], None, None]:
        """
        An iterator over all tunables in the group.

        Returns
        -------
        [(tunable, group), ...] : iter(Tunable, CovariantTunableGroup)
            An iterator over all tunables in all groups. Each element is a 2-tuple
            of an instance of the Tunable parameter and covariant group it belongs to.
        """
        return ((group.get_tunable(name), group) for (name, group) in self._index.items())

    def get_tunable(self, tunable: Union[str, Tunable]) -> Tuple[Tunable, CovariantTunableGroup]:
        """
        Access the entire Tunable (not just its value) and its covariant group. Throw
        KeyError if the tunable is not found.

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter.

        Returns
        -------
        (tunable, group) : (Tunable, CovariantTunableGroup)
            A 2-tuple of an instance of the Tunable parameter and covariant group it belongs to.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        group = self._index[name]
        return (group.get_tunable(name), group)

    def get_covariant_group_names(self) -> Iterable[str]:
        """
        Get the names of all covariance groups in the collection.

        Returns
        -------
        group_names : [str]
            IDs of the covariant tunable groups (a live view of the keys).
        """
        return self._tunable_groups.keys()

    def subgroup(self, group_names: Iterable[str]) -> "TunableGroups":
        """
        Select the covariance groups from the current set and create a new TunableGroups
        object that consists of those covariance groups.

        Note: The new TunableGroup will include *references* (not copies) to
        original ones, so each will get updated together.
        This is often desirable to support the use case of multiple related
        Environments (e.g. Local vs Remote) using the same set of tunables
        within a CompositeEnvironment.

        Parameters
        ----------
        group_names : list of str
            IDs of the covariant tunable groups.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Raises
        ------
        KeyError
            If any of `group_names` is not a group in this collection.
        """
        # pylint: disable=protected-access
        tunables = TunableGroups()
        for name in group_names:
            if name not in self._tunable_groups:
                raise KeyError(f"Unknown covariant group name '{name}' in tunable group {self}")
            tunables._add_group(self._tunable_groups[name])
        return tunables

    def get_param_values(
        self,
        group_names: Optional[Iterable[str]] = None,
        into_params: Optional[Dict[str, TunableValue]] = None,
    ) -> Dict[str, TunableValue]:
        """
        Get the current values of the tunables that belong to the specified covariance
        groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the covariant tunable groups.
            Select parameters from all groups if omitted (None).
        into_params : dict
            An optional dict to copy the parameters and their values into.
            It is updated in place and returned.

        Returns
        -------
        into_params : dict
            Flat dict of all parameters and their values from given covariance groups.
        """
        if group_names is None:
            group_names = self.get_covariant_group_names()
        if into_params is None:
            into_params = {}
        for name in group_names:
            into_params.update(self._tunable_groups[name].get_tunable_values_dict())
        return into_params

    def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool:
        """
        Check if any of the given covariant tunable groups has been updated.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Check all groups if omitted
            (None) or if an empty collection is given.

        Returns
        -------
        is_updated : bool
            True if any of the specified tunable groups has been updated, False otherwise.
        """
        # NOTE: `or` means an empty collection falls back to *all* groups.
        return any(
            self._tunable_groups[name].is_updated()
            for name in (group_names or self.get_covariant_group_names())
        )

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their
        defaults.

        Returns
        -------
        bool
            True if every covariant group reports its values are at defaults.
        """
        return all(group.is_defaults() for group in self._tunable_groups.values())

    def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Restore all tunable parameters to their default values.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Restore all groups if omitted
            (None) or if an empty collection is given.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        # NOTE: `or` means an empty collection falls back to *all* groups.
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].restore_defaults()
        return self

    def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Clear the update flag of given covariant groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Reset all groups if omitted
            (None) or if an empty collection is given.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        # NOTE: `or` means an empty collection falls back to *all* groups.
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].reset_is_updated()
        return self

    def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups":
        """
        In-place update the values of the tunables from the dictionary of (key, value)
        pairs.

        Parameters
        ----------
        param_values : Mapping[str, TunableValue]
            Dictionary mapping Tunable parameter names to new values.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for key, value in param_values.items():
            self[key] = value
        return self