Coverage for mlos_bench/mlos_bench/tests/storage/trial_config_test.py: 100%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for saving and retrieving additional parameters of pending trials.""" 

6from datetime import datetime 

7from typing import Any, Dict 

8 

9import pytest 

10from pytz import UTC 

11 

12from mlos_bench.storage.base_storage import Storage 

13from mlos_bench.tunables.tunable_groups import TunableGroups 

14 

15 

def test_exp_trial_pending(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None:
    """Create one trial with extra parameters and verify that it is reported as pending
    with the expected configuration.
    """
    extra_params = {"location": "westus2", "num_repeats": 100}
    new_trial = exp_storage.new_trial(tunable_groups, config=extra_params)

    now = datetime.now(UTC)
    # Exactly one pending trial is expected; the unpacking fails otherwise.
    (pending,) = exp_storage.pending_trials(now, running=True)

    assert pending.trial_id == new_trial.trial_id
    assert pending.tunables == tunable_groups

    expected_config = {
        "location": "westus2",
        # The integer 100 passed in above is expected back as the string "100".
        "num_repeats": "100",
        "experiment_id": "Test-001",
        "trial_id": new_trial.trial_id,
    }
    assert pending.config() == expected_config

29 

30 

def test_exp_trial_configs(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None:
    """Create six trials from two distinct tunable assignments and verify that only two
    distinct tunable config ids are shared among them.
    """
    mwait_tunables = tunable_groups.copy().assign({"idle": "mwait"})
    mwait_trials = [
        exp_storage.new_trial(mwait_tunables),
        exp_storage.new_trial(mwait_tunables),
        exp_storage.new_trial(mwait_tunables.copy()),  # Equal values, separate instance
    ]
    mwait_config_id = mwait_trials[0].tunable_config_id
    assert mwait_config_id == mwait_trials[1].tunable_config_id
    assert mwait_config_id == mwait_trials[2].tunable_config_id

    halt_tunables = tunable_groups.copy().assign({"idle": "halt"})
    halt_trials = [
        exp_storage.new_trial(halt_tunables),
        exp_storage.new_trial(halt_tunables),
        exp_storage.new_trial(halt_tunables.copy()),  # Equal values, separate instance
    ]
    halt_config_id = halt_trials[0].tunable_config_id
    assert halt_config_id == halt_trials[1].tunable_config_id
    assert halt_config_id == halt_trials[2].tunable_config_id

    # Different tunable values must map to different config ids.
    assert mwait_config_id != halt_config_id

    now = datetime.now(UTC)
    pending_ids = []
    for pending in exp_storage.pending_trials(now, running=True):
        pending_ids.append(pending.tunable_config_id)

    # All six trials are pending, but they share only two config ids.
    distinct_ids = set(pending_ids)
    assert len(pending_ids) == 6
    assert len(distinct_ids) == 2
    assert distinct_ids == {mwait_config_id, halt_config_id}

62 

63 

def test_exp_trial_no_config(exp_no_tunables_storage: Storage.Experiment) -> None:
    """Create a trial from an empty tunable groups config and verify the pending trial."""
    no_params: dict = {}
    # The same empty dict serves as both the tunables config and the trial config.
    no_tunables = TunableGroups(config=no_params)
    new_trial = exp_no_tunables_storage.new_trial(no_tunables, config=no_params)

    now = datetime.now(UTC)
    # Exactly one pending trial is expected; the unpacking fails otherwise.
    (pending,) = exp_no_tunables_storage.pending_trials(now, running=True)

    assert pending.trial_id == new_trial.trial_id
    assert pending.tunables == no_tunables

    expected_config = {
        "experiment_id": "Test-003",
        "trial_id": new_trial.trial_id,
    }
    assert pending.config() == expected_config

76 

77 

@pytest.mark.parametrize(
    "bad_config",
    [
        {"obj": object()},
        {"callable": lambda x: x},
        {"nested": {"callable": lambda x: x}},
    ],
)
def test_exp_trial_non_serializable_config(
    exp_no_tunables_storage: Storage.Experiment,
    bad_config: Dict[str, Any],
) -> None:
    """Check that creating a trial with a non-serializable config raises ValueError."""
    no_tunables = TunableGroups(config={})
    # Each parametrized config holds a plain object or a lambda (top-level or nested).
    with pytest.raises(ValueError):
        exp_no_tunables_storage.new_trial(no_tunables, config=bad_config)