Coverage for mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for environment schema validation.""" 

6 

7from os import path 

8 

9import pytest 

10 

11from mlos_bench.config.schemas import ConfigSchema 

12from mlos_bench.environments.base_environment import Environment 

13from mlos_bench.environments.composite_env import CompositeEnv 

14from mlos_bench.environments.script_env import ScriptEnv 

15from mlos_bench.tests import try_resolve_class_name 

16from mlos_bench.tests.config.schemas import ( 

17 check_test_case_against_schema, 

18 check_test_case_config_with_extra_param, 

19 get_schema_test_cases, 

20) 

21from mlos_core.tests import get_all_concrete_subclasses 

22 

23# General testing strategy: 

24# - hand code a set of good/bad configs (useful to test editor schema checking) 

25# - enumerate and try to check that we've covered all the cases 

26# - for each config, load and validate against expected schema 

27 

# Load the hand-written test-case configs from the "test-cases" directory
# that sits next to this test module.
TEST_CASES = get_schema_test_cases(
    path.join(
        path.dirname(__file__),
        "test-cases",
    )
)

29 

30 

31# Dynamically enumerate some of the cases we want to make sure we cover. 

32 

# Environment classes that are deliberately excluded from the config coverage
# checks below.
NON_CONFIG_ENV_CLASSES = {
    # ScriptEnv is ABCMeta abstract, but there's no good way to test that
    # dynamically in Python.
    ScriptEnv,
}

# Fully qualified ("module.ClassName") names of every concrete Environment
# subclass found in the mlos_bench package, minus the exclusions above.
expected_environment_class_names = [
    f"{env_cls.__module__}.{env_cls.__name__}"
    for env_cls in get_all_concrete_subclasses(Environment, pkg_name="mlos_bench")
    if env_cls not in NON_CONFIG_ENV_CLASSES
]
# Guard against the enumeration silently coming back empty.
assert expected_environment_class_names

44 

# Fully qualified name of the composite Environment class.
COMPOSITE_ENV_CLASS_NAME = f"{CompositeEnv.__module__}.{CompositeEnv.__name__}"

# All expected Environment class names except the composite one.
expected_leaf_environment_class_names = [
    class_name
    for class_name in expected_environment_class_names
    if class_name != COMPOSITE_ENV_CLASS_NAME
]

51 

52 

53# Do the full cross product of all the test cases and all the Environment types. 

@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype))
@pytest.mark.parametrize("env_class", expected_environment_class_names)
def test_case_coverage_mlos_bench_environment_type(test_case_subtype: str, env_class: str) -> None:
    """Verify that the given test case subtype has at least one test case for the
    given mlos_bench Environment class.

    Parameters
    ----------
    test_case_subtype : str
        Key into ``TEST_CASES.by_subtype`` selecting the group of cases to scan.
    env_class : str
        Fully qualified Environment class name that must be covered.

    Raises
    ------
    NotImplementedError
        If no test case in the subtype resolves to ``env_class``.
    """
    candidates = TEST_CASES.by_subtype[test_case_subtype].values()
    # any() stops at the first match, just like an early return from a loop.
    is_covered = any(
        try_resolve_class_name(candidate.config.get("class")) == env_class
        for candidate in candidates
    )
    if is_covered:
        return
    raise NotImplementedError(
        f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}"
    )

66 

67 

68# Now we actually perform all of those validation tests. 

69 

70 

@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path))
def test_environment_configs_against_schema(test_case_name: str) -> None:
    """Check the named environment test case against both the ENVIRONMENT schema
    and the UNIFIED schema.

    Parameters
    ----------
    test_case_name : str
        Key into ``TEST_CASES.by_path`` identifying the test case to check.
    """
    test_case = TEST_CASES.by_path[test_case_name]
    # Same order as before: the environment schema first, then the unified one.
    for schema_type in (ConfigSchema.ENVIRONMENT, ConfigSchema.UNIFIED):
        check_test_case_against_schema(test_case, schema_type)

76 

77 

@pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"]))
def test_environment_configs_with_extra_param(test_case_name: str) -> None:
    """Check that a known-good environment config fails to validate when extra
    params are present in certain places.

    The check is run against both the ENVIRONMENT schema and the UNIFIED schema.

    Parameters
    ----------
    test_case_name : str
        Key into ``TEST_CASES.by_type["good"]`` identifying the test case.
    """
    good_case = TEST_CASES.by_type["good"][test_case_name]
    # Same order as before: the environment schema first, then the unified one.
    for schema_type in (ConfigSchema.ENVIRONMENT, ConfigSchema.UNIFIED):
        check_test_case_config_with_extra_param(good_case, schema_type)