Coverage for mlos_bench/mlos_bench/storage/sql/storage.py: 100%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Saving and restoring the benchmark data in SQL database.""" 

6 

7import logging 

8from typing import Literal 

9 

10from sqlalchemy import URL, create_engine 

11 

12from mlos_bench.services.base_service import Service 

13from mlos_bench.storage.base_experiment_data import ExperimentData 

14from mlos_bench.storage.base_storage import Storage 

15from mlos_bench.storage.sql.experiment import Experiment 

16from mlos_bench.storage.sql.experiment_data import ExperimentSqlData 

17from mlos_bench.storage.sql.schema import DbSchema 

18from mlos_bench.tunables.tunable_groups import TunableGroups 

19 

# Module-level logger, named after this module.
_LOG = logging.getLogger(name=__name__)

21 

22 

class SqlStorage(Storage):
    """An implementation of the :py:class:`~.Storage` interface using SQLAlchemy
    backend.
    """

    def __init__(
        self,
        config: dict,
        global_config: dict | None = None,
        service: Service | None = None,
    ):
        """
        Connect to the SQL database and (unless lazy) create/update its schema.

        Parameters
        ----------
        config : dict
            Storage configuration. The optional ``lazy_schema_create`` and
            ``log_sql`` keys are removed from ``self._config`` here; every
            remaining key is passed to ``URL.create()`` to build the DB URL.
        global_config : dict | None
            Global parameters; forwarded to the base class.
        service : Service | None
            Parent service; forwarded to the base class.
        """
        super().__init__(config, global_config, service)
        # Strip the non-URL options first so only URL components remain.
        lazy_schema_create = self._config.pop("lazy_schema_create", False)
        self._log_sql = self._config.pop("log_sql", False)
        self._url = URL.create(**self._config)
        self._repr = f"{self._url.get_backend_name()}:{self._url.database}"
        _LOG.info("Connect to the database: %s", self)
        # `echo=True` makes SQLAlchemy log the emitted SQL statements.
        self._engine = create_engine(self._url, echo=self._log_sql)
        self._db_schema = DbSchema(self._engine)
        self._schema_created = False
        self._schema_updated = False
        if not lazy_schema_create:
            # `update_schema()` goes through the `_schema` property, which calls
            # `DbSchema.create()` before `update()` runs. Do not rely on
            # `assert self._schema` for that side effect: asserts are stripped
            # under `python -O`.
            self.update_schema()
        else:
            _LOG.info("Using lazy schema create for database: %s", self)

    @property
    def _schema(self) -> DbSchema:
        """Return the DB schema, calling ``DbSchema.create()`` on first access only."""
        if not self._schema_created:
            self._db_schema.create()
            self._schema_created = True
            if _LOG.isEnabledFor(logging.DEBUG):
                _LOG.debug("DDL statements:\n%s", self._db_schema)
        return self._db_schema

    def update_schema(self) -> None:
        """
        Update the database schema.

        Runs ``DbSchema.update()`` at most once per instance, creating the
        schema first if it has not been created yet.
        """
        if not self._schema_updated:
            self._schema.update()
            self._schema_updated = True

    def __repr__(self) -> str:
        """Return the ``<backend>:<database>`` string computed in ``__init__``."""
        return self._repr

    def experiment(  # pylint: disable=too-many-arguments
        self,
        *,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        tunables: TunableGroups,
        opt_targets: dict[str, Literal["min", "max"]],
    ) -> Storage.Experiment:
        """
        Create an SQL-backed ``Experiment`` storage object bound to this database.

        Accessing ``self._schema`` here creates the schema if lazy creation
        was requested and it has not been created yet. See
        ``Storage.experiment`` for the meaning of the parameters.
        """
        return Experiment(
            engine=self._engine,
            schema=self._schema,
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )

    @property
    def experiments(self) -> dict[str, ExperimentData]:
        """
        Retrieve all experiments stored in the database.

        Returns
        -------
        dict[str, ExperimentData]
            Mapping of experiment ID to an ``ExperimentSqlData`` object,
            in ascending experiment ID order.
        """
        # FIXME: this is somewhat expensive if only fetching a single Experiment.
        # May need to expand the API or data structures to lazily fetch data and/or cache it.
        with self._engine.connect() as conn:
            # Read the (lazily created) schema once instead of once per row.
            schema = self._schema
            cur_exp = conn.execute(
                schema.experiment.select().order_by(
                    schema.experiment.c.exp_id.asc(),
                )
            )
            return {
                exp.exp_id: ExperimentSqlData(
                    engine=self._engine,
                    schema=schema,
                    experiment_id=exp.exp_id,
                    description=exp.description,
                    root_env_config=exp.root_env_config,
                    git_repo=exp.git_repo,
                    git_commit=exp.git_commit,
                )
                for exp in cur_exp.fetchall()
            }