Coverage for mlos_bench/mlos_bench/tests/config/schemas/__init__.py: 94%
77 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Common tests for config schemas and their validation and test cases."""
7import os
8from copy import deepcopy
9from dataclasses import dataclass
10from typing import Any, Dict, Set
12import json5
13import jsonschema
14import pytest
16from mlos_bench.config.schemas.config_schemas import ConfigSchema
17from mlos_bench.tests.config import locate_config_examples
# Implemented as a dataclass (rather than a bare tuple/dict) to keep pylint happy.
@dataclass
class SchemaTestType:
    """Describes one category of schema test case and the subcategories allowed in it."""

    # Category name, e.g., "good" or "bad".
    test_case_type: str
    # Names of the subcategory directories allowed beneath this category.
    test_case_subtypes: Set[str]

    def __hash__(self) -> int:
        # Identify the category by its name alone; instances that compare equal
        # (all fields equal) necessarily share that name, so this stays consistent
        # with the generated __eq__.
        type_name = self.test_case_type
        return hash(type_name)
# Registry of the schema test case categories we expect to find, keyed by category
# name ("good"/"bad"), each carrying the subcategory directories allowed beneath it.
_SCHEMA_TEST_TYPES = {
    registered.test_case_type: registered
    for registered in [
        SchemaTestType(test_case_type="good", test_case_subtypes={"full", "partial"}),
        SchemaTestType(test_case_type="bad", test_case_subtypes={"invalid", "unhandled"}),
    ]
}
@dataclass
class SchemaTestCaseInfo:
    """A single loaded schema test case: its parsed config plus where it came from."""

    # Parsed contents of the test case file.
    config: Dict[str, Any]
    # Path of the file the config was loaded from.
    test_case_file: str
    # Category of the test case, e.g., "good" or "bad".
    test_case_type: str
    # Subcategory of the test case, e.g., "full" or "invalid".
    test_case_subtype: str

    def __hash__(self) -> int:
        # The config dict is unhashable, so identify a test case by its source file.
        source_path = self.test_case_file
        return hash(source_path)
def check_schema_dir_layout(test_cases_root: str) -> None:
    """
    Verifies the test case directory layout, so that no extra configs or test cases
    are silently skipped.

    Parameters
    ----------
    test_cases_root : str
        Directory whose entries must be known test case types (e.g., "good", "bad"),
        each containing only the subtype directories registered for that type.
        Entries named README.md are ignored at both levels.

    Raises
    ------
    NotImplementedError
        If an unknown test case type or subtype entry is found.
    """
    readme_name = "README.md"
    for type_name in os.listdir(test_cases_root):
        if type_name == readme_name:
            continue
        schema_test_type = _SCHEMA_TEST_TYPES.get(type_name)
        if schema_test_type is None:
            raise NotImplementedError(f"Unhandled test case type: {type_name}")
        allowed_subtypes = schema_test_type.test_case_subtypes
        type_dir = os.path.join(test_cases_root, type_name)
        for subtype_name in os.listdir(type_dir):
            if subtype_name != readme_name and subtype_name not in allowed_subtypes:
                raise NotImplementedError(
                    f"Unhandled test case subtype {subtype_name} "
                    f"for test case type {type_name}"
                )
@dataclass
class TestCases:
    """A container for loaded schema test cases, indexed several ways."""

    # This is a data container, not a test class. Its name starts with "Test" and it has
    # an __init__, so without this flag pytest tries (and warns about failing) to collect
    # it from any test module that imports it. Not annotated, so it is not a field.
    __test__ = False

    # All test cases, keyed by test case file path.
    by_path: Dict[str, SchemaTestCaseInfo]
    # Test case type (e.g., "good") -> {test case file path -> test case}.
    by_type: Dict[str, Dict[str, SchemaTestCaseInfo]]
    # Test case subtype (e.g., "full") -> {test case file path -> test case}.
    by_subtype: Dict[str, Dict[str, SchemaTestCaseInfo]]
def get_schema_test_cases(test_cases_root: str) -> TestCases:
    """
    Gets the schema test cases found under the given root.

    Parameters
    ----------
    test_cases_root : str
        Root directory of the test cases, laid out as ``<type>/<subtype>/``
        subdirectories (e.g., ``good/full`` or ``bad/invalid``).

    Returns
    -------
    TestCases
        The loaded test cases, indexed by file path, by type, and by subtype.

    Raises
    ------
    NotImplementedError
        If the directory layout contains an unexpected test case type or subtype.
    RuntimeError
        If a test case file cannot be parsed.
    """
    test_cases = TestCases(
        by_path={},
        by_type={x: {} for x in _SCHEMA_TEST_TYPES},
        by_subtype={
            y: {}
            for x in _SCHEMA_TEST_TYPES
            for y in sorted(_SCHEMA_TEST_TYPES[x].test_case_subtypes)
        },
    )
    check_schema_dir_layout(test_cases_root)
    # Note: we sort the test cases so that we can deterministically test them in parallel.
    # The subtypes are stored in a set, whose iteration order for strings can differ
    # between interpreter processes (hash randomization), so they (and the files found
    # for them) are sorted explicitly to make every parallel test worker collect the same
    # test cases in the same order.
    for test_case_type, schema_test_type in _SCHEMA_TEST_TYPES.items():
        for test_case_subtype in sorted(schema_test_type.test_case_subtypes):
            test_case_files = locate_config_examples(
                test_cases_root, os.path.join(test_case_type, test_case_subtype)
            )
            for test_case_file in sorted(test_case_files):
                with open(test_case_file, mode="r", encoding="utf-8") as test_case_fh:
                    try:
                        config = json5.load(test_case_fh)
                    except Exception as ex:
                        raise RuntimeError("Failed to load test case: " + test_case_file) from ex
                test_case_info = SchemaTestCaseInfo(
                    config=config,
                    test_case_file=test_case_file,
                    test_case_type=test_case_type,
                    test_case_subtype=test_case_subtype,
                )
                test_cases.by_path[test_case_file] = test_case_info
                test_cases.by_type[test_case_type][test_case_file] = test_case_info
                test_cases.by_subtype[test_case_subtype][test_case_file] = test_case_info

    # Sanity check that both top-level categories were actually populated.
    assert len(test_cases.by_type["good"]) > 0
    assert len(test_cases.by_type["bad"]) > 0

    return test_cases
def check_test_case_against_schema(
    test_case: SchemaTestCaseInfo,
    schema_type: ConfigSchema,
) -> None:
    """
    Validates a test case's config against a schema and checks the outcome matches
    the test case's category.

    "good" test cases must validate cleanly. "bad" test cases must fail validation
    with a ``jsonschema.ValidationError``.

    Parameters
    ----------
    test_case : SchemaTestCaseInfo
        Schema test case to check.
    schema_type : ConfigSchema
        Schema to check against, e.g., ENVIRONMENT or SERVICE.

    Raises
    ------
    NotImplementedError
        If the test case type is neither "good" nor "bad".
    """
    case_kind = test_case.test_case_type
    if case_kind not in ("good", "bad"):
        raise NotImplementedError(f"Unknown test case type: {test_case.test_case_type}")
    if case_kind == "bad":
        with pytest.raises(jsonschema.ValidationError):
            schema_type.validate(test_case.config)
        return
    schema_type.validate(test_case.config)
def check_test_case_config_with_extra_param(
    test_case: SchemaTestCaseInfo,
    schema_type: ConfigSchema,
) -> None:
    """
    Checks that an otherwise valid config is rejected once an unexpected param is
    added, either at the top level or inside its "config" section.

    Parameters
    ----------
    test_case : SchemaTestCaseInfo
        Schema test case whose config must validate as-is. It is not modified;
        a deep copy of its config is mutated instead.
    schema_type : ConfigSchema
        Schema to check against.
    """
    mutated = deepcopy(test_case.config)
    # Sanity check: the unmodified config must be valid to begin with.
    schema_type.validate(mutated)

    # An unknown attribute at the top level must be rejected.
    outer_key = "extra_outer_attr"
    mutated[outer_key] = "should not be here"
    with pytest.raises(jsonschema.ValidationError):
        schema_type.validate(mutated)
    del mutated[outer_key]

    # An unknown attribute inside the nested "config" section must be rejected too.
    # Create the section when it is missing (or falsy) so the attribute has a home.
    inner_section = mutated.get("config")
    if not inner_section:
        inner_section = mutated["config"] = {}
    inner_section["extra_config_attr"] = "should not be here"
    with pytest.raises(jsonschema.ValidationError):
        schema_type.validate(mutated)