Coverage for mlos_bench/mlos_bench/storage/sql/storage.py: 100%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Saving and restoring the benchmark data in a SQL database.""" 

6 

7import logging 

8from typing import Dict, Literal, Optional 

9 

10from sqlalchemy import URL, create_engine 

11 

12from mlos_bench.services.base_service import Service 

13from mlos_bench.storage.base_experiment_data import ExperimentData 

14from mlos_bench.storage.base_storage import Storage 

15from mlos_bench.storage.sql.experiment import Experiment 

16from mlos_bench.storage.sql.experiment_data import ExperimentSqlData 

17from mlos_bench.storage.sql.schema import DbSchema 

18from mlos_bench.tunables.tunable_groups import TunableGroups 

19 

# Module-level logger, named after this module, used for all messages emitted below.
_LOG = logging.getLogger(__name__)

21 

22 

class SqlStorage(Storage):
    """An implementation of the Storage interface using SQLAlchemy backend."""

    def __init__(
        self,
        config: dict,
        global_config: Optional[dict] = None,
        service: Optional[Service] = None,
    ):
        """
        Set up the SQLAlchemy engine and (optionally) create the DB schema.

        Parameters
        ----------
        config : dict
            Storage configuration, forwarded to the base class. Afterwards the
            optional `lazy_schema_create` and `log_sql` entries (both default to
            False) are removed from `self._config`, and every remaining entry is
            passed as a keyword argument to `sqlalchemy.URL.create`.
        global_config : Optional[dict]
            Forwarded to the base class unchanged.
        service : Optional[Service]
            Forwarded to the base class unchanged.
        """
        super().__init__(config, global_config, service)
        lazy_schema_create = self._config.pop("lazy_schema_create", False)
        self._log_sql = self._config.pop("log_sql", False)
        self._url = URL.create(**self._config)
        # Short identifier (backend name and database only) returned by __repr__.
        self._repr = f"{self._url.get_backend_name()}:{self._url.database}"
        _LOG.info("Connect to the database: %s", self)
        self._engine = create_engine(self._url, echo=self._log_sql)
        # Declared but deliberately left unset: the `_schema` property treats
        # the absence of this attribute as "schema not created yet".
        self._db_schema: DbSchema
        if lazy_schema_create:
            _LOG.info("Using lazy schema create for database: %s", self)
        else:
            # Read the property purely for its side effect (schema creation).
            # This must not be an `assert`: assert statements are removed when
            # Python runs with -O, which would silently skip the eager creation.
            _ = self._schema

    @property
    def _schema(self) -> DbSchema:
        """Lazily create schema upon first access."""
        if not hasattr(self, "_db_schema"):
            self._db_schema = DbSchema(self._engine).create()
            if _LOG.isEnabledFor(logging.DEBUG):
                # Log the freshly created schema object directly instead of
                # re-entering this property.
                _LOG.debug("DDL statements:\n%s", self._db_schema)
        return self._db_schema

    def __repr__(self) -> str:
        """Return the `<backend name>:<database>` identifier built in `__init__`."""
        return self._repr

    def experiment(  # pylint: disable=too-many-arguments
        self,
        *,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        tunables: TunableGroups,
        opt_targets: Dict[str, Literal["min", "max"]],
    ) -> Storage.Experiment:
        """
        Build an SQL-backed `Experiment` bound to this storage's engine and schema.

        All keyword arguments are passed through to the `Experiment` constructor
        unchanged. Accessing the schema here creates it if that has not happened yet.

        Parameters
        ----------
        experiment_id : str
            Identifier of the experiment.
        trial_id : int
            Trial ID passed to the experiment.
        root_env_config : str
            Root environment configuration value passed to the experiment.
        description : str
            Description of the experiment.
        tunables : TunableGroups
            Tunable parameters passed to the experiment.
        opt_targets : Dict[str, Literal["min", "max"]]
            Mapping from optimization target name to its direction.

        Returns
        -------
        experiment : Storage.Experiment
            The newly constructed experiment object.
        """
        return Experiment(
            engine=self._engine,
            schema=self._schema,
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )

    @property
    def experiments(self) -> Dict[str, ExperimentData]:
        """
        Load the data accessors for every experiment row in the database.

        Returns
        -------
        experiments : Dict[str, ExperimentData]
            `ExperimentSqlData` objects keyed by experiment ID, inserted in
            ascending order of the experiment ID.
        """
        # FIXME: this is somewhat expensive if only fetching a single Experiment.
        # May need to expand the API or data structures to lazily fetch data and/or cache it.
        schema = self._schema  # Resolve the (possibly lazy) schema once, not per row.
        with self._engine.connect() as conn:
            cur_exp = conn.execute(
                schema.experiment.select().order_by(
                    schema.experiment.c.exp_id.asc(),
                )
            )
            return {
                exp.exp_id: ExperimentSqlData(
                    engine=self._engine,
                    schema=schema,
                    experiment_id=exp.exp_id,
                    description=exp.description,
                    root_env_config=exp.root_env_config,
                    git_repo=exp.git_repo,
                    git_commit=exp.git_commit,
                )
                for exp in cur_exp.fetchall()
            }