
I have some pydantic BaseModels that I want to populate with values.

from enum import Enum
from typing import Dict, List, Literal, Type, Union, overload    

from pydantic import BaseModel


class Document(BaseModel):
    name: str
    pages: int


class DocumentA(Document):
    reviewer: str


class DocumentB(Document):
    columns: Dict[str, str]


class DocumentC(Document):
    reviewer: str
    tools: List[str]

Example Values:

db = {
    "A": {
        0: {"name": "Document 1", "pages": 2, "reviewer": "Person A"},
        1: {"name": "Document 2", "pages": 3, "reviewer": "Person B"},
    },
    "B": {
        0: {"name": "Document 1", "pages": 1, "columns": {"colA": "A", "colB": "B"}},
        1: {"name": "Document 2", "pages": 5, "columns": {"colC": "C", "colD": "D"}},
    },
    "C": {
        0: {"name": "Document 1", "pages": 7, "reviewer": "Person C", "tools": ["hammer"]},
        1: {"name": "Document 2", "pages": 2, "reviewer": "Person A", "tools": ["hammer", "chisel"]},
    },
}
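The example `columns` values are plain strings, so the `columns` annotation needs to be `Dict[str, str]` for pydantic to accept these rows; with `Dict[str, Dict]`, every inner value must itself be a dict. A minimal check of that, using two hypothetical stand-in models (`DocumentBDictValues`, `DocumentBStrValues`) that differ only in that annotation:

```python
from typing import Dict

from pydantic import BaseModel, ValidationError


class DocumentBDictValues(BaseModel):
    # Hypothetical model: each column value must itself be a dict, so "A" is rejected.
    name: str
    pages: int
    columns: Dict[str, Dict]


class DocumentBStrValues(BaseModel):
    # Hypothetical model: each column value is a plain string, matching the example data.
    name: str
    pages: int
    columns: Dict[str, str]


row = {"name": "Document 1", "pages": 1, "columns": {"colA": "A", "colB": "B"}}

try:
    DocumentBDictValues(**row)
    dict_values_ok = True
except ValidationError:
    dict_values_ok = False

doc = DocumentBStrValues(**row)
print(dict_values_ok)  # False
print(doc.columns)     # {'colA': 'A', 'colB': 'B'}
```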

To load the values into the correct BaseModel class, I have created a System class, which is also needed elsewhere and has more functionality; I have omitted those details for clarity.

class System(Enum):
    A = ("A", DocumentA)
    B = ("B", DocumentB)
    C = ("C", DocumentC)

    @property
    def key(self) -> str:
        return self.value[0]

    @property
    def Document(self) -> Union[Type[DocumentA], Type[DocumentB], Type[DocumentC]]:
        return self.value[1]

Then, through System["A"].Document I can access DocumentA directly. To load the values, I use this function (disregard handling KeyErrors for now):

def load_document(db: Dict, idx: int, system: System) -> Union[DocumentA, DocumentB, DocumentC]:
    data = db[system.key][idx]
    return system.Document(**data)

Now, I need to handle some data of type B, which I load directly in the handler function.

def handle_document_B(db: Dict, idx: int):
    doc = load_document(db=db, idx=idx, system=System.B)
    # Following line raises mypy errors
    # Item "DocumentA" of "Union[DocumentA, DocumentB, DocumentC]" has no attribute "columns"
    # Item "DocumentC" of "Union[DocumentA, DocumentB, DocumentC]" has no attribute "columns"
    print(doc.columns)

Running mypy raises errors on the line print(doc.columns), because load_document's return type is annotated as Union[DocumentA, DocumentB, DocumentC], and DocumentA and DocumentC have no columns attribute. But the only Document type that can actually be loaded here is DocumentB anyway.
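A stopgap that needs no overloads is to narrow the Union inside the handler with isinstance, which mypy treats as a type guard. The sketch below reproduces the question's models and System enum (with columns annotated as Dict[str, str] to match the example data) and, as an assumption for easier testing, has handle_document_B return the columns instead of printing them. It costs one check per handler, and it also raises at runtime if the System mapping is ever wrong.

```python
from enum import Enum
from typing import Dict, List, Type, Union

from pydantic import BaseModel


class Document(BaseModel):
    name: str
    pages: int


class DocumentA(Document):
    reviewer: str


class DocumentB(Document):
    columns: Dict[str, str]


class DocumentC(Document):
    reviewer: str
    tools: List[str]


class System(Enum):
    A = ("A", DocumentA)
    B = ("B", DocumentB)
    C = ("C", DocumentC)

    @property
    def key(self) -> str:
        return self.value[0]

    @property
    def Document(self) -> Union[Type[DocumentA], Type[DocumentB], Type[DocumentC]]:
        return self.value[1]


def load_document(db: Dict, idx: int, system: System) -> Union[DocumentA, DocumentB, DocumentC]:
    data = db[system.key][idx]
    return system.Document(**data)


def handle_document_B(db: Dict, idx: int) -> Dict[str, str]:
    doc = load_document(db=db, idx=idx, system=System.B)
    # isinstance narrows the Union to DocumentB for mypy and
    # fails loudly at runtime if System.B ever maps to another class.
    if not isinstance(doc, DocumentB):
        raise TypeError(f"expected DocumentB, got {type(doc).__name__}")
    return doc.columns  # mypy accepts this: doc is DocumentB here


db = {"B": {0: {"name": "Document 1", "pages": 1, "columns": {"colA": "A", "colB": "B"}}}}
print(handle_document_B(db, 0))  # {'colA': 'A', 'colB': 'B'}
```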

I know I could load the Document outside of the handler function and pass it instead, but I would prefer to load it in the handler.

I worked around the type issue by overloading the load_document function with the correct Document class for each System, but this seems tedious, since I need to manually add an overload for every System member that might be added in the future.

Is it possible to conditionally type hint a function's return value based on an Enum input value?

  • It would also be great to make db (and the db items) TypedDicts to get more type-checking benefits. Commented Jun 17, 2022 at 13:31
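A minimal sketch of what that TypedDict suggestion could look like for the example db; the class names DocumentRow, DocumentARow, DocumentBRow, DocumentCRow and Database are hypothetical. One caveat, stated with some hedging: mypy only narrows a TypedDict lookup when the key is a string literal, so inside load_document the expression db[system.key] (where key is a plain str) is reported as an error rather than narrowed. This helps with literal lookups but does not replace the overloads.

```python
from typing import Dict, List, TypedDict


class DocumentRow(TypedDict):
    name: str
    pages: int


class DocumentARow(DocumentRow):
    reviewer: str


class DocumentBRow(DocumentRow):
    columns: Dict[str, str]


class DocumentCRow(DocumentRow):
    reviewer: str
    tools: List[str]


class Database(TypedDict):
    # One key per System.key value; each maps a row index to a typed row.
    A: Dict[int, DocumentARow]
    B: Dict[int, DocumentBRow]
    C: Dict[int, DocumentCRow]


db: Database = {
    "A": {0: {"name": "Document 1", "pages": 2, "reviewer": "Person A"}},
    "B": {0: {"name": "Document 1", "pages": 1, "columns": {"colA": "A", "colB": "B"}}},
    "C": {0: {"name": "Document 1", "pages": 7, "reviewer": "Person C", "tools": ["hammer"]}},
}

# With a literal key, mypy infers DocumentBRow here, so "columns" typechecks.
columns = db["B"][0]["columns"]
print(columns)  # {'colA': 'A', 'colB': 'B'}
```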

1 Answer


You can explicitly annotate the return type for each selected option:

from typing import Literal, overload

@overload
def load_document(db: Dict, idx: int, system: Literal[System.A]) -> DocumentA: ...

@overload
def load_document(db: Dict, idx: int, system: Literal[System.B]) -> DocumentB: ...

@overload
def load_document(db: Dict, idx: int, system: Literal[System.C]) -> DocumentC: ...

def load_document(db: Dict, idx: int, system: System) -> Union[DocumentA, DocumentB, DocumentC]:
    data = db[system.key][idx]
    return system.Document(**data)

The rest of your code typechecks now (playground).


7 Comments

Literal means just that: a literal. load_document(some_dict, 6, x) won't be accepted, because x is not one of the literals System.A, System.B, or System.C (even if its runtime value is one of those values.)
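Regarding calls with a plain System variable: one pattern that keeps such calls accepted is a final catch-all overload taking system: System and returning the full Union, so literal members get precise types and everything else falls back. This is a minimal sketch trimmed to two System members for brevity; mypy allows the overlap here because each Literal overload's return type is a subtype of the catch-all's return type. The "# mypy:" comments describe the expected static types, not runtime output.

```python
from enum import Enum
from typing import Dict, Literal, Type, Union, overload

from pydantic import BaseModel


class Document(BaseModel):
    name: str
    pages: int


class DocumentA(Document):
    reviewer: str


class DocumentB(Document):
    columns: Dict[str, str]


class System(Enum):
    # Trimmed to two members for brevity.
    A = ("A", DocumentA)
    B = ("B", DocumentB)

    @property
    def key(self) -> str:
        return self.value[0]

    @property
    def Document(self) -> Union[Type[DocumentA], Type[DocumentB]]:
        return self.value[1]


@overload
def load_document(db: Dict, idx: int, system: Literal[System.A]) -> DocumentA: ...
@overload
def load_document(db: Dict, idx: int, system: Literal[System.B]) -> DocumentB: ...
# Catch-all: matches when the argument is only known to be some System (e.g. a variable).
@overload
def load_document(db: Dict, idx: int, system: System) -> Union[DocumentA, DocumentB]: ...
def load_document(db: Dict, idx: int, system: System) -> Union[DocumentA, DocumentB]:
    data = db[system.key][idx]
    return system.Document(**data)


db = {
    "A": {0: {"name": "Document 1", "pages": 2, "reviewer": "Person A"}},
    "B": {0: {"name": "Document 1", "pages": 1, "columns": {"colA": "A"}}},
}

b = load_document(db, 0, System.B)  # mypy: DocumentB
x: System = System["A"]
some = load_document(db, 0, x)      # mypy: Union[DocumentA, DocumentB]
print(type(b).__name__, type(some).__name__)  # DocumentB DocumentA
```

This still requires one Literal overload per member; it only removes the rejection of non-literal arguments.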
overload is meant to overload on argument types, not argument values.
Thanks for your answer! I should have stated more clearly that I had already found the overload solution (see the second-to-last paragraph of my question). But with it I need to manually write a new overload for each enum value. I would like to map this automatically somehow.
@chepner overload works fine here if the call is made in the form load_document(..., system=System.A): it overloads on types, and when the argument is the literal System.A, it works. Actually, with system: System in the annotation, a call like load_document(..., system='A') won't pass type checking, and that is expected behavior.
@fstermann There is no such option. I'll search for the related question (I'm sure I've seen it before); the resolution was "no, you can only write all the overloads manually".
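If keying on the enum is not essential, a different technique (not an enum-based solution) avoids per-member overloads: pass the model class itself and use a TypeVar bound to Document, so the return type follows the argument. This sketch assumes a hypothetical `system_key` ClassVar on each model to find its section of db (pydantic does not treat ClassVar annotations as fields); `load_document_as` is likewise a hypothetical name, and handle_document_B returns the columns instead of printing them to make it testable.

```python
from typing import ClassVar, Dict, Type, TypeVar

from pydantic import BaseModel


class Document(BaseModel):
    # Hypothetical class-level key tying each model to its db section;
    # ClassVar keeps pydantic from treating it as a field.
    system_key: ClassVar[str]
    name: str
    pages: int


class DocumentA(Document):
    system_key: ClassVar[str] = "A"
    reviewer: str


class DocumentB(Document):
    system_key: ClassVar[str] = "B"
    columns: Dict[str, str]


# Bound TypeVar: the return type is whichever Document subclass is passed in.
DocT = TypeVar("DocT", bound=Document)


def load_document_as(db: Dict, idx: int, doc_type: Type[DocT]) -> DocT:
    data = db[doc_type.system_key][idx]
    return doc_type(**data)


def handle_document_B(db: Dict, idx: int) -> Dict[str, str]:
    doc = load_document_as(db, idx, DocumentB)  # mypy infers DocumentB
    return doc.columns


db = {
    "A": {0: {"name": "Document 1", "pages": 2, "reviewer": "Person A"}},
    "B": {0: {"name": "Document 1", "pages": 1, "columns": {"colA": "A", "colB": "B"}}},
}
print(handle_document_B(db, 0))  # {'colA': 'A', 'colB': 'B'}
```

The trade-off is that callers pass the class rather than a System member, so adding a new document type needs no new overload, only a model with its own key.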
