Damian's notes – TypeGuards for Union of Callables

Damian Kula

TypeGuards for Union of Callables

Posted on 2023.11.26

I started writing Noiz sometime in 2018. A lot has changed in Python since then, and I have learned a lot too. Fortunately, even back then I was fascinated by type hints and decided to add them wherever I could. This decision has helped a lot with the long-term maintenance of the application, as well as with bringing new people into the project.

Recently, I was working on bypassing some annoying "features" of SQLAlchemy, which led me to look at the code responsible for generic work dispatch. In general, Noiz processes data in the following way:

[Figure: Generic workflow schema in Noiz]

Because the work dispatching routines are generic, their typing is important but also very difficult to get right. There is a reason why you can find # type: ignore in many places in our codebase... Let's fix these typing problems.

Current typing annotations

Let's have a look at what the code looks like right now. Note that the signatures are simplified for the sake of brevity.

InputsForMassCalculations = Union[
    InputTypeA,
    InputTypeB,
    InputTypeC,
    InputTypeD,
]

BulkAddableObjects = Union[
    ResultTypeA,
    ResultTypeB,
    ResultTypeC,
    ResultTypeD,
]

def run_calculate_and_upsert_on_dask(
    inputs: Iterable[InputsForMassCalculations],
    calculation_task: Callable[[InputsForMassCalculations], Tuple[BulkAddableObjects, ...]],
) -> None:
    pass

I call this function with parameters like these:

def _crosscorrelate_wrapper(
    inputs: InputTypeA,
) -> Tuple[ResultTypeA, ...]:
    pass

def _prepare_inputs_for_crosscorrelations(
    **kwargs
) -> Generator[InputTypeA, None, None]:
    pass


run_calculate_and_upsert_on_dask(
    inputs=_prepare_inputs_for_crosscorrelations(),
    calculation_task=_crosscorrelate_wrapper,
)

Calls like this result in the following error from mypy 1.7.0:

 mypy src/noiz/api/crosscorrelations.py
src/noiz/api/crosscorrelations.py:1073: error: Argument "calculation_task" to
"run_calculate_and_upsert_on_dask" has incompatible type
"Callable[[InputTypeA], tuple[ResultTypeA, ...]]"; expected
"Callable[[Union[InputTypeA, InputTypeB, InputTypeC, InputTypeD]],
tuple[Union[ResultTypeA, ResultTypeB, ResultTypeC, ResultTypeD], ...]]"  [arg-type]

In the past, I often resorted to a # type: ignore comment because I was pressed for time. Today I have more time to explore this area, so let's create a self-contained example to help understand the problem.

from typing import Tuple, Union, Callable, Iterable

InsertType = Union[int, str]
ReturnType = Union[str, int]


def first(a: int) -> Tuple[str, ...]:
    return ("aa", "bb")


def second(a: str) -> Tuple[int, ...]:
    return (1, 2)


def calculation_runner(
        calculation_callable: Callable[[Iterable[InsertType]], Iterable[ReturnType]],
        argument: InsertType
):
    calculation_callable(argument)


if __name__ == "__main__":
    calculation_runner(calculation_callable=first, argument=1)
    calculation_runner(calculation_callable=second, argument="a")

Running mypy on this example results in:

 mypy callable_of_unions.py
callable_of_unions.py:19: error: Argument 1 has incompatible type "Union[int, str]";
expected "Iterable[Union[int, str]]"  [arg-type]
callable_of_unions.py:23: error: Argument "calculation_callable" to "calculation_runner" has incompatible type
"Callable[[int], tuple[str, ...]]"; expected
"Callable[[Iterable[Union[int, str]]], Iterable[Union[str, int]]]"  [arg-type]
callable_of_unions.py:24: error: Argument "calculation_callable" to "calculation_runner" has incompatible type
"Callable[[str], tuple[int, ...]]"; expected
"Callable[[Iterable[Union[int, str]]], Iterable[Union[str, int]]]"  [arg-type]
Found 3 errors in 1 file (checked 1 source file)

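OK, but what's wrong with these annotations? First of all, this shouldn't be typed as a callable of unions, but as a union of callables. Callable is contravariant in its parameter types: in the Noiz error above, the expected parameter type Union[InputTypeA, InputTypeB, InputTypeC, InputTypeD] promises that the callable accepts any of the four input types, so _crosscorrelate_wrapper, which only accepts InputTypeA, cannot be passed there. The return side has a similar problem: the annotation Iterable[Union[str, int]] accepts both (1, 2, 3) and (1, "a"), so it says nothing about whether the results are homogeneous. These annotations are not at all precise and don't guarantee what I expected them to in the past. Let's investigate further and find a more precise annotation.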
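A minimal sketch of what contravariance protects against, using simplified, hypothetical stand-ins: an int-only first and a runner whose annotation promises a callable that accepts int or str. If mypy let the call through, the runner would be entitled to pass a str, and first would fail at runtime:

```python
from typing import Callable, Union


def first(a: int) -> int:
    # Only valid for ints: "a" + 1 raises TypeError at runtime.
    return a + 1


def runner(fn: Callable[[Union[int, str]], int]) -> int:
    # The annotation promises that fn accepts str as well as int,
    # so calling it with a str is legal as far as mypy is concerned.
    return fn("a")


if __name__ == "__main__":
    try:
        runner(first)  # type: ignore[arg-type]  # mypy rejects this call
    except TypeError as exc:
        print(f"TypeError: {exc}")
```
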

Union of Callables

A better approach is to type these as a union of callables. Since each callable we use expects an iterable of one exact input type and returns one exact return type, we shouldn't use unions for either of them. To tackle one problem at a time, I will simplify things a bit further and change the callable to take a single value instead of an iterable. Here is the code:

from typing import Tuple, Union, Callable, Dict


UnionOfCallables = Union[
    Callable[[int], Tuple[str, ...]],
    Callable[[str], Dict[int, str]],
]

def first(a: int) -> Tuple[str, ...]:
    return ("aa", "bb")


def second(a: str) -> Dict[int, str]:
    return {1: "a", 2: ""}


def calculation_runner(
        calculation_callable: UnionOfCallables,
        argument: Union[int, str]
):
    calculation_callable(argument)


if __name__ == "__main__":
    calculation_runner(calculation_callable=first, argument=1)
    calculation_runner(calculation_callable=second, argument="a")

It starts to make sense, doesn't it? After all, our calculation functions are quite precise, aren't they? In general, we don't expect to mix input types, or output types for that matter.

Let's run mypy on this code:

 mypy union_of_callables.py
union_of_callables.py:21: error: Argument 1 has incompatible type "Union[int, str]"; expected "int"  [arg-type]
union_of_callables.py:21: error: Argument 1 has incompatible type "Union[int, str]"; expected "str"  [arg-type]
Found 2 errors in 1 file (checked 1 source file)

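Well, that's confusing. Why would mypy report two errors for the same line, and contradictory ones at that? The reason is that when a value of a union type is called, mypy checks the call against every member of the union separately. For the first member the argument must be an int, for the second it must be a str, and Union[int, str] satisfies neither. Nothing in our annotations tells mypy that an int argument always comes with the callable that takes an int. After all, we declared that each provided callable takes either str or int, not a union of the two! We were being precise, but now it has bitten us.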
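For comparison, and not the approach this post takes: when the callable and its argument are passed together, a type variable shared by both parameters expresses exactly this pairing. A sketch reusing the simplified first and second from above; _Arg and _Ret are illustrative names:

```python
from typing import Callable, Dict, Tuple, TypeVar

# Illustrative type variables, not part of the original code.
_Arg = TypeVar("_Arg")
_Ret = TypeVar("_Ret")


def first(a: int) -> Tuple[str, ...]:
    return ("aa", "bb")


def second(a: str) -> Dict[int, str]:
    return {1: "a", 2: ""}


def calculation_runner(
    calculation_callable: Callable[[_Arg], _Ret],
    argument: _Arg,
) -> _Ret:
    # The same type variable appears in both parameters, so mypy checks
    # at each call site that the argument fits the callable.
    return calculation_callable(argument)


if __name__ == "__main__":
    print(calculation_runner(calculation_callable=first, argument=1))
    print(calculation_runner(calculation_callable=second, argument="a"))
    # calculation_runner(calculation_callable=second, argument=1) should be
    # rejected by mypy, because int is not compatible with str.
```

This only helps when each call site passes a concretely typed callable. It does not help once a value is already typed as UnionOfCallables, as it is inside calculation_runner above.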

Type Guards

Our annotations don't carry enough information for mypy, which cannot tell which argument type belongs to which callable. We can fix this with type narrowing. Mypy has a whole section in its documentation dedicated to this. There is also a discussion on mypy's GitHub that touches on something closer to the problem we are dealing with here.

Generally speaking, TypeGuard is a way of convincing mypy that you know what you are doing and are not going to shoot yourself in the foot. And we do know what we are doing, right? Maybe you do, but I have no idea what I am doing here.

Back to the point. TypeGuard lets you annotate a function that returns bool so that, whenever it returns True, the type checker narrows the expression passed as its first parameter to the type given in TypeGuard[...]. The narrowed type can be a specific type or a generic one. Since our case is quite generic, let's have a look at Generic TypeGuards in the documentation. After reading and experimenting a bit, here is the code I came up with:

from typing import Tuple, Union, Callable, Dict, TypeVar, Any, Type
from typing_extensions import TypeGuard
import inspect


UnionOfCallables = Union[
    Callable[[int], Tuple[str, ...]],
    Callable[[str], Dict[int, str]],
]

_T = TypeVar("_T")

def is_argument_type_equal_to_expected_by_callable(
        callable: Callable[[Any], Any],
        arg: _T,
) -> TypeGuard[Callable[[_T], Any]]:
    return inspect.signature(callable).parameters["a"].annotation == type(arg)


def first(a: int) -> Tuple[str, ...]:
    return ("aa", "bb")


def second(a: str) -> Dict[int, str]:
    return {1: "a", 2: ""}


def calculation_runner(
    calculation_callable: UnionOfCallables,
    argument: Union[str, int]
):
    if is_argument_type_equal_to_expected_by_callable(arg=argument, callable=calculation_callable):
        calculation_callable(argument)
    else:
        raise TypeError("Type of argument provided is different than expected by a provided callable")


if __name__ == "__main__":
    calculation_runner(calculation_callable=first, argument=1)
    calculation_runner(calculation_callable=second, argument="a")

What's happening here?

_T = TypeVar("_T")
def is_argument_type_equal_to_expected_by_callable(
        callable: Callable[[Any], Any],
        arg: _T,
) -> TypeGuard[Callable[[_T], Any]]:
    return inspect.signature(callable).parameters["a"].annotation == type(arg)

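This is a type guard function. It inspects the signature of the provided callable and compares the annotation of its parameter a with the type of the provided argument. If they match, it returns True, and inside the if branch mypy treats calculation_callable (passed as the guard's first parameter) as Callable[[_T], Any], where _T is the declared type of argument, so the call type-checks. If they differ, the code falls through to the else branch and raises a TypeError.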

 mypy union_of_callables_with_guard.py
Success: no issues found in 1 source file

Woohoo!

Summary

Unfortunately, this isn't the end of the problem from Noiz's point of view. There are two outstanding issues, but this post is already longer than it should be. First, the signature check in the TypeGuard function can easily raise a KeyError: this happens whenever the provided callable has no parameter named a. Second, Noiz passes a generator of a given type to these callables, not a single instance. This means that if I check the types in the simplest way, by pulling an element out of the generator, I will always lose its first element, which is not acceptable. I'm thinking about using a PushBackIterator but haven't decided yet.

BTW: I discovered this library partly thanks to the PythonBytes podcast, which I listened to yesterday. Michael and Brian mentioned Larry Hastings' library, so I went through what else he has on his GitHub and discovered big. I have to admit, the timing couldn't be better.

Anyway, I might later add links here to the MRs I produce while refactoring these types in Noiz, but I am not sure they will be helpful. I expect the diff to be quite messy.