Coverage for mlos_bench/mlos_bench/storage/sql/schema.py: 95%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""DB schema definition.""" 

6 

import logging
from datetime import datetime, timezone
from typing import Any, List

from sqlalchemy import (
    Column,
    DateTime,
    Dialect,
    Engine,
    Float,
    ForeignKeyConstraint,
    Integer,
    MetaData,
    PrimaryKeyConstraint,
    Sequence,
    String,
    Table,
    UniqueConstraint,
    create_mock_engine,
)

26 

27_LOG = logging.getLogger(__name__) 

28 

29 

class _DDL:
    """
    Executor callback that records DDL statements instead of running them.

    An instance is handed to a SQLAlchemy mock engine as its `executor` in
    `DbSchema.__repr__()`. Every statement it is called with is compiled for
    the given dialect and kept as text in `statements`.
    """

    def __init__(self, dialect: Dialect):
        # Compiled statements, in the order they were received.
        self.statements: List[str] = []
        self._dialect = dialect

    def __call__(self, sql: Any, *_args: Any, **_kwargs: Any) -> None:
        # Extra positional/keyword arguments passed by the engine are ignored.
        compiled = sql.compile(dialect=self._dialect)
        self.statements.append(str(compiled))

    def __repr__(self) -> str:
        # Statements are separated by ";\n" and the script ends with ";".
        # An empty script yields an empty string (no dangling semicolon).
        script = ";\n".join(self.statements)
        if not script:
            return ""
        return f"{script};"

47 

48 

class DbSchema:
    """
    A class to define and create the DB schema.

    Every table is declared in `__init__` on a single `MetaData` object and is
    exposed as an instance attribute: `experiment`, `objectives`, `config`,
    `trial`, `config_param`, `trial_param`, `trial_status`, `trial_result`,
    and `trial_telemetry`.
    """

    # This class is internal to SqlStorage and is mostly a struct
    # for all DB tables, so it's ok to disable the warnings.
    # pylint: disable=too-many-instance-attributes

    # Common string column sizes.
    _ID_LEN = 512
    _PARAM_VALUE_LEN = 1024
    _METRIC_VALUE_LEN = 255
    _STATUS_LEN = 16

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current time as a timezone-aware UTC `datetime`."""
        return datetime.now(timezone.utc)

    def __init__(self, engine: Engine):
        """
        Declare the SQLAlchemy schema for the database.

        Only the table definitions are built here; nothing is sent to the
        database until `create()` is called.

        Parameters
        ----------
        engine : sqlalchemy.Engine
            The engine the schema is declared for. Its dialect name selects
            the `config_id` column definition, and the engine itself is kept
            for `create()` and `__repr__()`.
        """
        _LOG.info("Create the DB schema for: %s", engine)
        self._engine = engine
        # TODO: bind for automatic schema updates? (#649)
        self._meta = MetaData()

        self.experiment = Table(
            "experiment",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("description", String(1024)),
            Column("root_env_config", String(1024), nullable=False),
            Column("git_repo", String(1024), nullable=False),
            Column("git_commit", String(40), nullable=False),
            PrimaryKeyConstraint("exp_id"),
        )

        self.objectives = Table(
            "objectives",
            self._meta,
            # Declared explicitly, like every other `exp_id` column in this schema.
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("optimization_target", String(self._ID_LEN), nullable=False),
            Column("optimization_direction", String(4), nullable=False),
            # TODO: Note: weight is not fully supported yet as currently
            # multi-objective is expected to explore each objective equally.
            # Will need to adjust the insert and return values to support this
            # eventually.
            Column("weight", Float, nullable=True),
            PrimaryKeyConstraint("exp_id", "optimization_target"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
        )

        # A workaround for SQLAlchemy issue with autoincrement in DuckDB:
        if engine.dialect.name == "duckdb":
            seq_config_id = Sequence("seq_config_id")
            col_config_id = Column(
                "config_id",
                Integer,
                seq_config_id,
                server_default=seq_config_id.next_value(),
                nullable=False,
                primary_key=True,
            )
        else:
            col_config_id = Column(
                "config_id",
                Integer,
                nullable=False,
                primary_key=True,
                autoincrement=True,
            )

        self.config = Table(
            "config",
            self._meta,
            col_config_id,
            Column("config_hash", String(64), nullable=False, unique=True),
        )

        self.trial = Table(
            "trial",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("config_id", Integer, nullable=False),
            Column("ts_start", DateTime, nullable=False),
            Column("ts_end", DateTime),
            # Should match the text IDs of `mlos_bench.environments.Status` enum:
            Column("status", String(self._STATUS_LEN), nullable=False),
            PrimaryKeyConstraint("exp_id", "trial_id"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )

        # Values of the tunable parameters of the experiment,
        # fixed for a particular trial config.
        self.config_param = Table(
            "config_param",
            self._meta,
            Column("config_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),
            PrimaryKeyConstraint("config_id", "param_id"),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )

        # Values of additional non-tunable parameters of the trial,
        # e.g., scheduled execution time, VM name / location, number of repeats, etc.
        self.trial_param = Table(
            "trial_param",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),
            PrimaryKeyConstraint("exp_id", "trial_id", "param_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        self.trial_status = Table(
            "trial_status",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # A callable default: rows inserted without an explicit `ts` get the
            # current UTC time. (A literal "now" string is not a valid DateTime
            # bind value on every dialect.)
            Column(
                "ts",
                DateTime(timezone=True),
                nullable=False,
                default=self._utc_now,
            ),
            Column("status", String(self._STATUS_LEN), nullable=False),
            UniqueConstraint("exp_id", "trial_id", "ts"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        self.trial_result = Table(
            "trial_result",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),
            PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        self.trial_telemetry = Table(
            "trial_telemetry",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # Same callable UTC default as `trial_status.ts`.
            Column(
                "ts",
                DateTime(timezone=True),
                nullable=False,
                default=self._utc_now,
            ),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),
            UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        _LOG.debug("Schema: %s", self._meta)

    def create(self) -> "DbSchema":
        """
        Create the DB schema.

        Returns
        -------
        schema : DbSchema
            This object, so that calls can be chained.
        """
        _LOG.info("Create the DB schema")
        self._meta.create_all(self._engine)
        return self

    def __repr__(self) -> str:
        """
        Produce a string with all SQL statements required to create the schema from
        scratch in current SQL dialect.

        That is, return a collection of CREATE TABLE statements and such.
        NOTE: this method is quite heavy! Every call replays `create_all()`
        against a mock engine and compiles each statement again.

        Returns
        -------
        sql : str
            A multi-line string with SQL statements to create the DB schema from scratch.
        """
        ddl = _DDL(self._engine.dialect)
        # The mock engine does not execute anything; it passes each DDL
        # statement to `ddl`, which records its compiled text.
        mock_engine = create_mock_engine(self._engine.url, executor=ddl)
        self._meta.create_all(mock_engine, checkfirst=False)
        return str(ddl)