Coverage for mlos_core/mlos_core/tests/__init__.py: 100%

19 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Common functions for mlos_core Optimizer tests.""" 

6 

7from importlib import import_module 

8from pkgutil import walk_packages 

9from types import ModuleType 

10from typing import TypeVar 

11 

# Common seed for tests to use. Pinning a single well-known value keeps
# results reproducible even though tests may run in parallel and in a
# non-deterministic order, and it avoids having to untangle race conditions
# caused by seeds intermingling across tests.
SEED: int = 42

# Generic type variable for the class-enumeration helpers in this module.
T = TypeVar("T")

17 

18 

def get_all_submodules(pkg: ModuleType) -> list[str]:
    """
    Imports all submodules for a package and returns their names.

    Useful for dynamically enumerating subclasses: a class only shows up in
    ``cls.__subclasses__()`` once the module that defines it has been imported.

    Parameters
    ----------
    pkg : ModuleType
        An already-imported package (a module with a ``__path__``).

    Returns
    -------
    list[str]
        Fully qualified names of every module and subpackage found beneath
        ``pkg``, in the order ``pkgutil.walk_packages`` yields them.

    Notes
    -----
    ``walk_packages`` itself only imports *subpackages*, because it needs their
    ``__path__`` to recurse. Plain modules are merely listed, so each one is
    imported explicitly here.

    Any exception raised while importing a subpackage is still ignored via
    ``onerror`` (best-effort discovery), whereas an error importing a plain
    module propagates.

    ``__main__`` modules are listed but never imported, since importing one
    may execute a command-line entry point.
    """
    submodules = []
    for _, submodule_name, is_pkg in walk_packages(
        pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None
    ):
        submodules.append(submodule_name)
        # Subpackages are imported by walk_packages itself right after being
        # yielded (failures are routed to ``onerror``), so don't import them
        # here. That would bypass the best-effort handling.
        if is_pkg or submodule_name.rpartition(".")[2] == "__main__":
            continue
        import_module(submodule_name)
    return submodules

31 

32 

33def _get_all_subclasses(cls: type[T]) -> set[type[T]]: 

34 """ 

35 Gets the set of all of the subclasses of the given class. 

36 

37 Useful for dynamically enumerating expected test cases. 

38 """ 

39 return set(cls.__subclasses__()).union( 

40 s for c in cls.__subclasses__() for s in _get_all_subclasses(c) 

41 ) 

42 

43 

def get_all_concrete_subclasses(cls: type[T], pkg_name: str | None = None) -> list[type[T]]:
    """
    Gets a sorted list of all of the concrete subclasses of the given class.

    Useful for dynamically enumerating expected test cases.

    Parameters
    ----------
    cls : type[T]
        The base class whose direct and indirect subclasses are enumerated.
    pkg_name : str | None
        Optional dotted name of a package to load first. When given, the
        package is imported and walked with ``get_all_submodules`` so that
        subclasses defined in its submodules can be discovered (to the extent
        that helper imports them). The result is *not* restricted to this
        package: every currently loaded concrete subclass of ``cls`` is
        returned.

    Returns
    -------
    list[type[T]]
        Subclasses whose ``__abstractmethods__`` is missing or empty, sorted
        by ``(__module__, __name__)`` so the order is deterministic.

    Raises
    ------
    AssertionError
        If ``pkg_name`` is given but no submodules are found beneath it
        (unless assertions are disabled with ``python -O``).

    Notes
    -----
    For abstract types, mypy will complain at the call site.
    Use "# type: ignore[type-abstract]" to suppress the warning.
    See Also: https://github.com/python/mypy/issues/4717
    """
    if pkg_name is not None:
        pkg = import_module(pkg_name)
        submodules = get_all_submodules(pkg)
        assert submodules, f"No submodules found in package {pkg_name!r}"
    return sorted(
        (
            subclass
            for subclass in _get_all_subclasses(cls)
            # ABCs with unimplemented abstract methods carry a non-empty
            # ``__abstractmethods__``; plain classes have no such attribute.
            if not getattr(subclass, "__abstractmethods__", None)
        ),
        key=lambda c: (c.__module__, c.__name__),
    )