Coverage for mlos_bench/mlos_bench/tunables/tunable.py: 96%

299 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tunable parameter definition.""" 

6import collections 

7import copy 

8import logging 

9from typing import ( 

10 Any, 

11 Dict, 

12 Iterable, 

13 List, 

14 Literal, 

15 Optional, 

16 Sequence, 

17 Tuple, 

18 Type, 

19 TypedDict, 

20 Union, 

21) 

22 

23import numpy as np 

24 

25from mlos_bench.util import nullable 

26 

# Module-level logger, named after this module.
_LOG = logging.getLogger(__name__)

# NOTE: Each documenting string literal below is placed directly *after* the
# assignment it describes, which is where attribute docstrings are recognized
# by documentation tools. They are plain expression statements at runtime.

TunableValue = Union[int, float, Optional[str]]
"""A tunable parameter value type alias."""

TunableValueType = Union[Type[int], Type[float], Type[str]]
"""Tunable value type."""

TunableValueTypeTuple = (int, float, str, type(None))
"""
Tunable value type tuple.

For checking with isinstance()
"""

TunableValueTypeName = Literal["int", "float", "categorical"]
"""The string name of a tunable value type."""

TunableValuesDict = Dict[str, TunableValue]
"""Tunable values dictionary type."""

DistributionName = Literal["uniform", "normal", "beta"]
"""Tunable value distribution type."""

44 

45 

class DistributionDict(TypedDict, total=False):
    """
    A typed dict for tunable parameters' distributions.

    Declared with ``total=False``, so the type checker treats every key as optional.
    """

    # Name of the distribution: one of "uniform", "normal", or "beta".
    type: DistributionName
    # Distribution parameters keyed by name; may be None or omitted.
    params: Optional[Dict[str, float]]

51 

52 

class TunableDict(TypedDict, total=False):
    """
    A typed dict for tunable parameters.

    Mostly used for mypy type checking.

    These are the types expected to be received from the json config.

    Declared with ``total=False``, so the type checker treats every key as optional;
    note, however, that ``Tunable.__init__`` indexes ``type`` and ``default`` directly.
    """

    # Name of the tunable's type: "int", "float", or "categorical".
    type: TunableValueTypeName
    # Optional free-text description of the tunable.
    description: Optional[str]
    # Default value of the tunable.
    default: TunableValue
    # Categories of a categorical tunable (each a string or None).
    values: Optional[List[Optional[str]]]
    # Two-element [min, max] range of a numerical tunable.
    range: Optional[Union[Sequence[int], Sequence[float]]]
    # Number of quantization bins of a numerical tunable.
    quantization_bins: Optional[int]
    # Log-scale flag of a numerical tunable.
    log: Optional[bool]
    # Distribution (name and parameters) of a numerical tunable.
    distribution: Optional[DistributionDict]
    # Special values of a numerical tunable.
    special: Optional[Union[List[int], List[float]]]
    # Weights for the entries of `values`.
    values_weights: Optional[List[float]]
    # Weights for the entries of `special`.
    special_weights: Optional[List[float]]
    # Weight of the range of a numerical tunable.
    range_weight: Optional[float]
    # Free-form metadata about the tunable.
    meta: Dict[str, Any]

75 

76 

class Tunable:  # pylint: disable=too-many-instance-attributes,too-many-public-methods
    """
    A tunable parameter definition and its current value.

    A tunable is either categorical (a fixed list of string-or-None values) or
    numerical (an int or float range, optionally with special values, weights,
    quantization, log scale, and a sampling distribution). The definition is
    validated on construction and every assignment to `value` is validated too.
    """

    # Maps tunable types to their corresponding Python types by name.
    _DTYPE: Dict[TunableValueTypeName, TunableValueType] = {
        "int": int,
        "float": float,
        "categorical": str,
    }

    def __init__(self, name: str, config: TunableDict):
        """
        Create an instance of a new tunable parameter.

        Parameters
        ----------
        name : str
            Human-readable identifier of the tunable parameter.
            Must be a string that does not contain "!".
        config : dict
            Python dict that represents a Tunable (e.g., deserialized from JSON).
            The "type" and "default" keys are required.

        Raises
        ------
        ValueError
            If the name, the type, or any other part of the definition is invalid,
            or if the default value is not valid for the tunable.
        KeyError
            If "type" or "default" is missing from the config.
        AssertionError
            If the "range" is given but does not have exactly two elements.
        """
        if not isinstance(name, str) or "!" in name:  # TODO: Use a regex here and in JSON schema
            raise ValueError(f"Invalid name of the tunable: {name}")
        self._name = name
        self._type: TunableValueTypeName = config["type"]  # required
        if self._type not in self._DTYPE:
            raise ValueError(f"Invalid parameter type: {self._type}")
        self._description = config.get("description")
        self._default = config["default"]
        # Coerce the default to the tunable's Python type; None is kept as is.
        self._default = self.dtype(self._default) if self._default is not None else self._default
        self._values = config.get("values")
        if self._values:
            # Categories are always stored as strings (or None).
            self._values = [str(v) if v is not None else v for v in self._values]
        self._meta: Dict[str, Any] = config.get("meta", {})
        self._range: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None
        self._quantization_bins: Optional[int] = config.get("quantization_bins")
        self._log: Optional[bool] = config.get("log")
        self._distribution: Optional[DistributionName] = None
        self._distribution_params: Dict[str, float] = {}
        distr = config.get("distribution")
        if distr:
            self._distribution = distr["type"]  # required
            self._distribution_params = distr.get("params") or {}
        config_range = config.get("range")
        if config_range is not None:
            assert len(config_range) == 2, f"Invalid range: {config_range}"
            config_range = (config_range[0], config_range[1])
            self._range = config_range
        self._special: Union[List[int], List[float]] = config.get("special") or []
        # A single weights list serves both the categories and the special values.
        self._weights: List[float] = (
            config.get("values_weights") or config.get("special_weights") or []
        )
        self._range_weight: Optional[float] = config.get("range_weight")
        self._current_value = None
        self._sanity_check()
        # Assign through the property setter so that the default is validated as well.
        self.value = self._default

    def _sanity_check(self) -> None:
        """
        Check if the status of the Tunable is valid, and throw ValueError if it is
        not.

        Raises
        ------
        ValueError
            If the definition is inconsistent or the default value is not valid.
        """
        if self.is_categorical:
            self._sanity_check_categorical()
        elif self.is_numerical:
            self._sanity_check_numerical()
        else:
            raise ValueError(f"Invalid parameter type for tunable {self}: {self._type}")
        if not self.is_valid(self.default):
            raise ValueError(f"Invalid default value for tunable {self}: {self.default}")

    def _sanity_check_categorical(self) -> None:
        """
        Check if the status of the categorical Tunable is valid, and throw ValueError
        if it is not.

        A categorical tunable must have a non-empty list of unique values and must not
        use any of the numerical-only settings (range, special values, range weight,
        log, quantization, distribution). Weights, if given, must be non-negative and
        there must be one per value.
        """
        # pylint: disable=too-complex
        assert self.is_categorical
        if not (self._values and isinstance(self._values, collections.abc.Iterable)):
            raise ValueError(f"Must specify values for the categorical type tunable {self}")
        if self._range is not None:
            raise ValueError(f"Range must be None for the categorical type tunable {self}")
        if len(set(self._values)) != len(self._values):
            raise ValueError(f"Values must be unique for the categorical type tunable {self}")
        if self._special:
            raise ValueError(f"Categorical tunable cannot have special values: {self}")
        if self._range_weight is not None:
            raise ValueError(f"Categorical tunable cannot have range_weight: {self}")
        if self._log is not None:
            raise ValueError(f"Categorical tunable cannot have log parameter: {self}")
        if self._quantization_bins is not None:
            raise ValueError(f"Categorical tunable cannot have quantization parameter: {self}")
        if self._distribution is not None:
            raise ValueError(f"Categorical parameters do not support `distribution`: {self}")
        if self._weights:
            if len(self._weights) != len(self._values):
                raise ValueError(f"Must specify weights for all values: {self}")
            if any(w < 0 for w in self._weights):
                raise ValueError(f"All weights must be non-negative: {self}")

    def _sanity_check_numerical(self) -> None:
        """
        Check if the status of the numerical Tunable is valid, and throw ValueError
        if it is not.

        A numerical tunable must have no categorical values and a strictly increasing
        two-element range. Quantization bins, if given, must be greater than 1.
        Weights and the range weight must be either both given or both omitted.
        """
        # pylint: disable=too-complex,too-many-branches
        assert self.is_numerical
        if self._values is not None:
            raise ValueError(f"Values must be None for the numerical type tunable {self}")
        if not self._range or len(self._range) != 2 or self._range[0] >= self._range[1]:
            raise ValueError(f"Invalid range for tunable {self}: {self._range}")
        if self._quantization_bins is not None and self._quantization_bins <= 1:
            raise ValueError(f"Number of quantization bins is <= 1: {self}")
        if self._distribution is not None and self._distribution not in {
            "uniform",
            "normal",
            "beta",
        }:
            raise ValueError(f"Invalid distribution: {self}")
        if self._distribution_params and self._distribution is None:
            raise ValueError(f"Must specify the distribution: {self}")
        if self._weights:
            if self._range_weight is None:
                raise ValueError(f"Must specify weight for the range: {self}")
            if len(self._weights) != len(self._special):
                # This message needs the f-prefix to actually interpolate the tunable.
                raise ValueError(f"Must specify weights for all special values {self}")
            if any(w < 0 for w in self._weights + [self._range_weight]):
                raise ValueError(f"All weights must be non-negative: {self}")
        elif self._range_weight is not None:
            raise ValueError(f"Must specify both weights and range_weight or none: {self}")

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the Tunable (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the Tunable.
        """
        # TODO? Add weights, specials, quantization, distribution?
        if self.is_categorical:
            return (
                f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}"
            )
        return f"{self._name}[{self._type}]({self._range}:{self._default})={self._current_value}"

    def __eq__(self, other: object) -> bool:
        """
        Check if two Tunable objects are equal.

        Parameters
        ----------
        other : Tunable
            A tunable object to compare to.

        Returns
        -------
        is_equal : bool
            True if the Tunables correspond to the same parameter and have the same value and type.
            False if `other` is not a Tunable.
            NOTE: ranges and special values are not currently considered in the comparison.
        """
        if not isinstance(other, Tunable):
            return False
        return bool(
            self._name == other._name
            and self._type == other._type
            and self._current_value == other._current_value
        )

    def __lt__(self, other: object) -> bool:  # pylint: disable=too-many-return-statements
        """
        Compare the two Tunable objects. We mostly need this to create a canonical list
        of tunable objects when hashing a TunableGroup.

        Tunables are ordered by name, then by type name, then by current value
        (numerically for numerical tunables, as strings for categorical ones, with
        a None category value ordered first).

        Parameters
        ----------
        other : Tunable
            A tunable object to compare to.

        Returns
        -------
        is_less : bool
            True if the current Tunable is less than the other one, False otherwise
            (including when `other` is not a Tunable).
        """
        if not isinstance(other, Tunable):
            return False
        if self._name < other._name:
            return True
        if self._name == other._name and self._type < other._type:
            return True
        if self._name == other._name and self._type == other._type:
            if self.is_numerical:
                assert self._current_value is not None
                assert other._current_value is not None
                return bool(float(self._current_value) < float(other._current_value))
            # else: categorical
            if self._current_value is None:
                return True
            if other._current_value is None:
                return False
            return bool(str(self._current_value) < str(other._current_value))
        return False

    def copy(self) -> "Tunable":
        """
        Deep copy of the Tunable object.

        Returns
        -------
        tunable : Tunable
            A new Tunable object that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    @property
    def default(self) -> TunableValue:
        """Get the default value of the tunable."""
        return self._default

    def is_default(self) -> bool:
        """
        Checks whether the currently assigned value of the tunable is at its
        default.

        Returns
        -------
        is_default : bool
            True if the current value equals the default value, False otherwise.
        """
        return self._default == self._current_value

    @property
    def value(self) -> TunableValue:
        """Get the current value of the tunable."""
        return self._current_value

    @value.setter
    def value(self, value: TunableValue) -> TunableValue:
        """
        Set the current value of the tunable.

        The value is first coerced to the tunable's Python type and then validated.

        Raises
        ------
        ValueError
            If assigning a float to an int tunable would lose precision, or if the
            coerced value is not valid for the tunable.
        Exception
            Whatever the type conversion raises is logged and re-raised as is.
        """
        # We need this coercion for the values produced by some optimizers
        # (e.g., scikit-optimize) and for data restored from certain storage
        # systems (where values can be strings).
        try:
            if self.is_categorical and value is None:
                coerced_value = None
            else:
                assert value is not None
                coerced_value = self.dtype(value)
        except Exception:
            _LOG.error(
                "Impossible conversion: %s %s <- %s %s",
                self._type,
                self._name,
                type(value),
                value,
            )
            raise

        # Reject floats with a fractional part for int tunables instead of truncating.
        if self._type == "int" and isinstance(value, float) and value != coerced_value:
            _LOG.error(
                "Loss of precision: %s %s <- %s %s",
                self._type,
                self._name,
                type(value),
                value,
            )
            raise ValueError(f"Loss of precision: {self._name}={value}")

        if not self.is_valid(coerced_value):
            _LOG.error(
                "Invalid assignment: %s %s <- %s %s",
                self._type,
                self._name,
                type(value),
                value,
            )
            raise ValueError(f"Invalid value for the tunable: {self._name}={value}")

        self._current_value = coerced_value
        return self._current_value

    def update(self, value: TunableValue) -> bool:
        """
        Assign the value to the tunable. Return True if it is a new value, False
        otherwise.

        Parameters
        ----------
        value : Union[int, float, str]
            Value to assign.

        Returns
        -------
        is_updated : bool
            True if the new value is different from the previous one, False otherwise.
        """
        prev_value = self._current_value
        self.value = value
        return prev_value != self._current_value

    def is_valid(self, value: TunableValue) -> bool:
        """
        Check if the value can be assigned to the tunable.

        Parameters
        ----------
        value : Union[int, float, str]
            Value to validate.

        Returns
        -------
        is_valid : bool
            True if the value is valid, False otherwise.
            For a numerical tunable, a value is valid if it is in the range
            or is one of the special values.

        Raises
        ------
        ValueError
            If the value is not a number for a numerical tunable, or if the tunable
            is neither a categorical one with values nor a numerical one with a range.
        """
        if self.is_categorical and self._values:
            return value in self._values
        elif self.is_numerical and self._range:
            if isinstance(value, (int, float)):
                return self.in_range(value) or value in self._special
            else:
                raise ValueError(f"Invalid value type for tunable {self}: {value}={type(value)}")
        else:
            raise ValueError(f"Invalid parameter type: {self._type}")

    def in_range(self, value: Union[int, float, str, None]) -> bool:
        """
        Check if the value is within the range of the tunable.

        Do *NOT* check for special values. Return False if the tunable or value is
        categorical or None.

        Both ends of the range are inclusive.
        """
        return (
            isinstance(value, (float, int))
            and self.is_numerical
            and self._range is not None
            and bool(self._range[0] <= value <= self._range[1])
        )

    @property
    def category(self) -> Optional[str]:
        """
        Get the current value of the tunable as a string.

        Raises
        ------
        ValueError
            If the tunable is not categorical.
        """
        if self.is_categorical:
            return nullable(str, self._current_value)
        else:
            raise ValueError("Cannot get categorical values for a numerical tunable.")

    @category.setter
    def category(self, new_value: Optional[str]) -> Optional[str]:
        """Set the current value of the tunable."""
        assert self.is_categorical
        assert isinstance(new_value, (str, type(None)))
        self.value = new_value
        return self.value

    @property
    def numerical_value(self) -> Union[int, float]:
        """
        Get the current value of the tunable as a number.

        Raises
        ------
        ValueError
            If the tunable is categorical.
        """
        assert self._current_value is not None
        if self._type == "int":
            return int(self._current_value)
        elif self._type == "float":
            return float(self._current_value)
        else:
            raise ValueError("Cannot get numerical value for a categorical tunable.")

    @numerical_value.setter
    def numerical_value(self, new_value: Union[int, float]) -> Union[int, float]:
        """Set the current numerical value of the tunable."""
        # We need this coercion for the values produced by some optimizers
        # (e.g., scikit-optimize) and for data restored from certain storage
        # systems (where values can be strings).
        assert self.is_numerical
        self.value = new_value
        return self.value

    @property
    def name(self) -> str:
        """Get the name / string ID of the tunable."""
        return self._name

    @property
    def special(self) -> Union[List[int], List[float]]:
        """
        Get the special values of the tunable. Return an empty list if there are none.

        Returns
        -------
        special : [int] | [float]
            A list of special values of the tunable. Can be empty.
        """
        return self._special

    @property
    def is_special(self) -> bool:
        """
        Check if the current value of the tunable is special.

        Returns
        -------
        is_special : bool
            True if the current value of the tunable is special, False otherwise.
        """
        return self.value in self._special

    @property
    def weights(self) -> Optional[List[float]]:
        """
        Get the weights of the categories or special values of the tunable.

        Returns
        -------
        weights : [float]
            A list of weights. It is empty if no weights were specified.
        """
        return self._weights

    @property
    def range_weight(self) -> Optional[float]:
        """
        Get weight of the range of the numeric tunable.

        Only available for numerical tunables that have both special values and
        weights; this is enforced with assertions.

        Returns
        -------
        weight : float
            Weight of the range.
        """
        assert self.is_numerical
        assert self._special
        assert self._weights
        return self._range_weight

    @property
    def type(self) -> TunableValueTypeName:
        """
        Get the data type of the tunable.

        Returns
        -------
        type : str
            Data type of the tunable - one of {'int', 'float', 'categorical'}.
        """
        return self._type

    @property
    def dtype(self) -> TunableValueType:
        """
        Get the actual Python data type of the tunable.

        This is useful for bulk conversions of the input data.

        Returns
        -------
        dtype : type
            Data type of the tunable - one of {int, float, str}.
        """
        return self._DTYPE[self._type]

    @property
    def is_categorical(self) -> bool:
        """
        Check if the tunable is categorical.

        Returns
        -------
        is_categorical : bool
            True if the tunable is categorical, False otherwise.
        """
        return self._type == "categorical"

    @property
    def is_numerical(self) -> bool:
        """
        Check if the tunable is an integer or float.

        Returns
        -------
        is_numerical : bool
            True if the tunable is an integer or float, False otherwise.
        """
        return self._type in {"int", "float"}

    @property
    def range(self) -> Union[Tuple[int, int], Tuple[float, float]]:
        """
        Get the range of the tunable.

        Only available for numerical tunables; this is enforced with assertions.

        Returns
        -------
        range : (number, number)
            A 2-tuple of numbers that represents the range of the tunable.
            Numbers can be int or float, depending on the type of the tunable.
        """
        assert self.is_numerical
        assert self._range is not None
        return self._range

    @property
    def span(self) -> Union[int, float]:
        """
        Gets the span of the range.

        Note: this does not take quantization into account.

        Returns
        -------
        Union[int, float]
            (max - min) for numerical tunables.
        """
        num_range = self.range
        return num_range[1] - num_range[0]

    @property
    def quantization_bins(self) -> Optional[int]:
        """
        Get the number of quantization bins, if specified.

        Returns
        -------
        quantization_bins : int | None
            Number of quantization bins, or None.
            Always None for categorical tunables.
        """
        if self.is_categorical:
            return None
        return self._quantization_bins

    @property
    def quantized_values(self) -> Optional[Union[Iterable[int], Iterable[float]]]:
        """
        Get a sequence of quantized values for this tunable.

        Only available for numerical tunables (reading the range asserts that).

        Returns
        -------
        Optional[Union[Iterable[int], Iterable[float]]]
            If the Tunable is quantizable, returns a sequence of those elements,
            else None (e.g., for unquantized float type tunables).
            For float tunables this is a one-shot generator; for int tunables
            it is a `range` object (all integers in the range when there is no
            quantization).
        """
        num_range = self.range
        if self.type == "float":
            if not self.quantization_bins:
                return None
            # Be sure to return python types instead of numpy types.
            return (
                float(x)
                for x in np.linspace(
                    start=num_range[0],
                    stop=num_range[1],
                    num=self.quantization_bins,
                    endpoint=True,
                )
            )
        assert self.type == "int", f"Unhandled tunable type: {self}"
        # NOTE(review): the step below is truncated to an int, so the number of values
        # produced can differ from `quantization_bins`, and a span smaller than
        # (quantization_bins - 1) gives a zero step, which `range()` rejects.
        return range(
            int(num_range[0]),
            int(num_range[1]) + 1,
            int(self.span / (self.quantization_bins - 1)) if self.quantization_bins else 1,
        )

    @property
    def cardinality(self) -> Optional[int]:
        """
        Gets the cardinality of elements in this tunable, or else None. (i.e., when the
        tunable is continuous float and not quantized).

        If the tunable has quantization set, this returns the number of quantization
        bins.

        Returns
        -------
        cardinality : int
            Either the number of points in the tunable or else None.
        """
        if self.is_categorical:
            return len(self.categories)
        if self.quantization_bins:
            return self.quantization_bins
        if self.type == "int":
            return int(self.span) + 1
        return None

    @property
    def is_log(self) -> Optional[bool]:
        """
        Check if numeric tunable is log scale.

        Only available for numerical tunables; this is enforced with an assertion.

        Returns
        -------
        log : bool
            True if numeric tunable is log scale, False if linear.
            None if the log setting was not specified.
        """
        assert self.is_numerical
        return self._log

    @property
    def distribution(self) -> Optional[DistributionName]:
        """
        Get the name of the distribution (uniform, normal, or beta) if specified.

        Returns
        -------
        distribution : str
            Name of the distribution (uniform, normal, or beta) or None.
        """
        return self._distribution

    @property
    def distribution_params(self) -> Dict[str, float]:
        """
        Get the parameters of the distribution.

        Only available when a distribution is specified; this is enforced with
        an assertion.

        Returns
        -------
        distribution_params : Dict[str, float]
            Parameters of the distribution. Empty if no parameters were given.
        """
        assert self._distribution is not None
        return self._distribution_params

    @property
    def categories(self) -> List[Optional[str]]:
        """
        Get the list of all possible values of a categorical tunable.

        Only available for categorical tunables; this is enforced with assertions.

        Returns
        -------
        values : List[str]
            List of all possible values of a categorical tunable.
        """
        assert self.is_categorical
        assert self._values is not None
        return self._values

    @property
    def values(self) -> Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]:
        """
        Gets the categories or quantized values for this tunable.

        Returns
        -------
        Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]
            Categories or quantized values.
        """
        if self.is_categorical:
            return self.categories
        assert self.is_numerical
        return self.quantized_values

    @property
    def meta(self) -> Dict[str, Any]:
        """
        Get the tunable's metadata.

        This is a free-form dictionary that can be used to store any additional
        information about the tunable (e.g., the unit information).
        """
        return self._meta