Coverage for mlos_bench/mlos_bench/storage/sql/storage.py: 100%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Saving and restoring the benchmark data in SQL database.""" 

6 

7import logging 

8from typing import Dict, Literal, Optional 

9 

10from sqlalchemy import URL, create_engine 

11 

12from mlos_bench.services.base_service import Service 

13from mlos_bench.storage.base_experiment_data import ExperimentData 

14from mlos_bench.storage.base_storage import Storage 

15from mlos_bench.storage.sql.experiment import Experiment 

16from mlos_bench.storage.sql.experiment_data import ExperimentSqlData 

17from mlos_bench.storage.sql.schema import DbSchema 

18from mlos_bench.tunables.tunable_groups import TunableGroups 

19 

# Module-level logger, named after this module via ``__name__``.
_LOG = logging.getLogger(__name__)

21 

22 

class SqlStorage(Storage):
    """An implementation of the :py:class:`~.Storage` interface using SQLAlchemy
    backend.
    """

    def __init__(
        self,
        config: dict,
        global_config: Optional[dict] = None,
        service: Optional[Service] = None,
    ):
        """Create the SQLAlchemy engine and, unless configured lazily, the DB schema.

        Parameters
        ----------
        config : dict
            Storage configuration, handed to the base class. The
            ``lazy_schema_create`` and ``log_sql`` entries are popped from
            ``self._config`` (as set up by the base class). Everything that
            remains is passed as keyword arguments to ``URL.create``.
        global_config : Optional[dict]
            Global configuration, forwarded to the base class.
        service : Optional[Service]
            Parent service, forwarded to the base class.
        """
        super().__init__(config, global_config, service)
        lazy_schema_create = self._config.pop("lazy_schema_create", False)
        self._log_sql = self._config.pop("log_sql", False)
        self._url = URL.create(**self._config)
        # The repr uses only the backend name and database. Other URL parts
        # (e.g. username/password) are left out, so logging `self` does not
        # print them.
        self._repr = f"{self._url.get_backend_name()}:{self._url.database}"
        _LOG.info("Connect to the database: %s", self)
        # `echo=True` makes SQLAlchemy log the SQL statements it emits.
        self._engine = create_engine(self._url, echo=self._log_sql)
        # Annotated but deliberately NOT assigned. The `_schema` property uses
        # `hasattr` on this name to decide whether the schema still needs creating.
        self._db_schema: DbSchema
        if not lazy_schema_create:
            # Touch the property for its side effect: create the schema eagerly.
            # This must not live inside an `assert`; `python -O` strips asserts
            # and would silently turn eager creation into lazy creation.
            _ = self._schema
        else:
            _LOG.info("Using lazy schema create for database: %s", self)

    @property
    def _schema(self) -> DbSchema:
        """Lazily create schema upon first access.

        Returns
        -------
        DbSchema
            The schema object returned by ``DbSchema(self._engine).create()``,
            cached on the instance after the first call.
        """
        if not hasattr(self, "_db_schema"):
            self._db_schema = DbSchema(self._engine).create()
            if _LOG.isEnabledFor(logging.DEBUG):
                # Use the cached attribute directly instead of re-entering this property.
                _LOG.debug("DDL statements:\n%s", self._db_schema)
        return self._db_schema

    def __repr__(self) -> str:
        """Return the ``<backend>:<database>`` string built in ``__init__``."""
        return self._repr

    def experiment(  # pylint: disable=too-many-arguments
        self,
        *,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        tunables: TunableGroups,
        opt_targets: Dict[str, Literal["min", "max"]],
    ) -> Storage.Experiment:
        """Create an SQL-backed :py:class:`Experiment` bound to this storage.

        All keyword arguments are forwarded unchanged to ``Experiment`` together
        with this storage's engine and schema. Accessing the schema here creates
        it if it was configured lazily and not yet created.
        """
        return Experiment(
            engine=self._engine,
            schema=self._schema,
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )

    @property
    def experiments(self) -> Dict[str, ExperimentData]:
        """Return all experiments stored in the database.

        Returns
        -------
        Dict[str, ExperimentData]
            One ``ExperimentSqlData`` per row of the experiment table, keyed by
            ``exp_id``. Insertion order follows ascending ``exp_id`` (the query's
            ``ORDER BY``).
        """
        # FIXME: this is somewhat expensive if only fetching a single Experiment.
        # May need to expand the API or data structures to lazily fetch data and/or cache it.
        with self._engine.connect() as conn:
            cur_exp = conn.execute(
                self._schema.experiment.select().order_by(
                    self._schema.experiment.c.exp_id.asc(),
                )
            )
            # Rows are fetched while the connection is still open.
            return {
                exp.exp_id: ExperimentSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    experiment_id=exp.exp_id,
                    description=exp.description,
                    root_env_config=exp.root_env_config,
                    git_repo=exp.git_repo,
                    git_commit=exp.git_commit,
                )
                for exp in cur_exp.fetchall()
            }