Coverage for mlos_core/mlos_core/tests/__init__.py: 100%
19 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Common functions for mlos_core Optimizer tests."""
7from importlib import import_module
8from pkgutil import walk_packages
9from types import ModuleType
10from typing import TypeVar
# Every test shares one fixed seed. This keeps tests that run in a
# non-deterministic, parallel order from intermingling their seeds, and avoids
# hard-to-track race conditions caused by per-test seeding.
SEED: int = 42

T = TypeVar("T")  # Generic class type used by the subclass-enumeration helpers.
def get_all_submodules(pkg: ModuleType) -> list[str]:
    """
    Imports all submodules for a package and returns their names.

    Useful for dynamically enumerating subclasses: ``cls.__subclasses__()`` only
    reports classes whose defining module has actually been imported.

    Parameters
    ----------
    pkg : ModuleType
        An already-imported package (it must expose ``__path__``).

    Returns
    -------
    list[str]
        Fully qualified names of every submodule and subpackage found beneath
        ``pkg``, in the order yielded by :func:`pkgutil.walk_packages`. Names of
        modules that failed to import are still included.
    """
    submodules = []
    for _, submodule_name, is_pkg in walk_packages(
        pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None
    ):
        # walk_packages() imports only *packages* (so it can recurse into them);
        # plain modules are merely listed. Import those explicitly so that the
        # classes they define become visible to __subclasses__().
        if not is_pkg:
            try:
                import_module(submodule_name)
            except ImportError:
                # Best-effort, consistent with the package import errors that
                # the onerror callback above ignores: e.g. a module with a
                # missing optional dependency is skipped instead of aborting.
                pass
        submodules.append(submodule_name)
    return submodules
def _get_all_subclasses(cls: type[T]) -> set[type[T]]:
    """
    Collects every direct and indirect subclass of ``cls``.

    ``cls`` itself is not part of the result. Useful for dynamically
    enumerating expected test cases.
    """
    seen: set[type[T]] = set()
    # Walk the subclass graph with an explicit worklist; a class reachable
    # through several parents (diamond inheritance) is expanded only once.
    frontier = list(cls.__subclasses__())
    while frontier:
        candidate = frontier.pop()
        if candidate not in seen:
            seen.add(candidate)
            frontier.extend(candidate.__subclasses__())
    return seen
def get_all_concrete_subclasses(cls: type[T], pkg_name: str | None = None) -> list[type[T]]:
    """
    Gets a sorted list of all of the concrete subclasses of the given class.

    Useful for dynamically enumerating expected test cases.

    Parameters
    ----------
    cls : type[T]
        The base class whose direct and indirect subclasses are enumerated.
    pkg_name : str | None
        If given, this package is imported and ``get_all_submodules`` is called
        on it first. Subclasses defined in the modules imported that way can
        then be discovered. An AssertionError is raised if no submodules are
        found.

    Returns
    -------
    list[type[T]]
        Every subclass of ``cls`` without unimplemented abstract methods,
        sorted by ``(__module__, __name__, __qualname__)``.

    Notes
    -----
    For abstract types, mypy will complain at the call site.
    Use "# type: ignore[type-abstract]" to suppress the warning.
    See Also: https://github.com/python/mypy/issues/4717
    """
    if pkg_name is not None:
        pkg = import_module(pkg_name)
        # Called for its import side effect; an empty result most likely means
        # pkg_name does not name a package with any submodules.
        submodules = get_all_submodules(pkg)
        assert submodules, f"No submodules found in package {pkg_name!r}"
    concrete = [
        subclass
        for subclass in _get_all_subclasses(cls)
        # ABCs with unimplemented abstract methods carry a non-empty
        # __abstractmethods__; plain classes have no such attribute at all.
        if not getattr(subclass, "__abstractmethods__", None)
    ]
    # __qualname__ breaks ties between same-named classes in one module (e.g.
    # nested classes). Without it their relative order would follow set
    # iteration order, which is not stable across processes.
    return sorted(concrete, key=lambda c: (c.__module__, c.__name__, c.__qualname__))