3
\$\begingroup\$

I have been building a deep learning framework modelled after PyTorch in pure Python. I have made good progress, but the framework is not yet seamless to use: the end user has to pass backend= whenever they create a tensor or a module. This is not ergonomic, because passing backend=my_backend to every class leads to a lot of repeated code.

It would be better to do something similar to PyTorch and have the framework maintain some kind of global state to avoid this.

"""
To enhance the experience of using the framework it maintains a global state.
This allows us to omit `backend=` arguments and fall back to a default backend, which is a module-level global.

It can be overridden at any time with a `with` keyword, i.e.

with state(BackendType.CUPY):
    ... do things ...

and it can be reset at any time using reset_state()

The default backend is numpy. 

TODO: Add thread safety to this file. 
"""

from contextlib import contextmanager

from framework.autograd.backend import (
    Backend,
    CupyBackend,
    NumpyBackend,
    BackendType
)

global_backend = NumpyBackend()
# later: global_device (when cupy introduced)

def set_state(backend_type: BackendType):
    global global_backend
    if backend_type == BackendType.NUMPY:
        default_backend = NumpyBackend()
    elif backend_type == BackendType.CUPY:
        default_backend = CupyBackend()
    else:
        print("backend_type not recognised. ")

def reset_state():
    global global_backend
    global_backend = NumpyBackend()

@contextmanager
def state(backend: BackendType):
    # Context manager enter
    global global_backend
    prev_backend_type = global_backend.backend_type
    set_state(backend)

    yield

    # Context manager exit
    set_state(prev_backend_type)

  1. Am I likely to run into issues later on in development by using this design?
  2. Is my current design implicitly not thread safe, meaning I can e.g. no longer run tests in parallel?

Point (2) arises because, when I looked at how PyTorch handles this, its equivalent state appeared to be thread-local.

Note that the above code snippet depends on autograd/backend.py, which I show below:

import numpy as np
from typing import Union, Tuple
from enum import Enum

"""
A Backend contains a registry of primitive operations.
Calls to the backend are delegated to the appropriate Numpy/Cupy method implementations. 

As a rule of thumb - import numpy as np, import cupy as cp, should only appear in this file. 
"""

# Later: Union[np.ndarray, cp.ndarray]
type BackendArray  = Union[np.ndarray, "cp.ndarray"]

class BackendType(Enum):
    NUMPY = "numpy"
    CUPY = "cupy"

class Backend:
    def as_array(self, a) -> BackendArray:
        """
        Attempt to coerce `a` into a backend-suitable array type. 
        """
        raise NotImplementedError

    # Arithmetic operations
    def add(self, a: BackendArray, b: BackendArray) -> BackendArray : raise NotImplementedError
    def sub(self, a: BackendArray, b: BackendArray) -> BackendArray : raise NotImplementedError
    def mul(self, a: BackendArray, b: BackendArray) -> BackendArray : raise NotImplementedError
    def matmul(self, a: BackendArray, b: BackendArray) -> BackendArray : raise NotImplementedError
    def true_div(self, a: BackendArray, b: BackendArray) -> BackendArray : raise NotImplementedError
    def neg(self, a: BackendArray) -> BackendArray : raise NotImplementedError
    
    # Shape operations
    def transpose(self, a: BackendArray) -> BackendArray : raise NotImplementedError
    def reshape(self, a: BackendArray, newshape: Tuple[int, ...]) -> BackendArray : raise NotImplementedError
    def flatten(self, a: BackendArray) -> BackendArray : raise NotImplementedError
    def ndim(self, a: BackendArray) -> int : raise NotImplementedError
    def shape(self, a: BackendArray) -> Tuple[int, ...] : raise NotImplementedError

    # Reduction operations
    def sum(self, a: BackendArray, axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) -> BackendArray : raise NotImplementedError
    def maximum(self, a: BackendArray) -> BackendArray : raise NotImplementedError
    def minimum(self, a: BackendArray) -> BackendArray : raise NotImplementedError
    def max_eltwise(self, a: BackendArray, b) -> BackendArray : raise NotImplementedError
    def min_eltwise(self, a: BackendArray, b) -> BackendArray : raise NotImplementedError

    # BackendArray creation
    def ones(self, shape: Tuple[int, ...]) -> BackendArray : raise NotImplementedError
    def zeros(self, shape: Tuple[int, ...]) -> BackendArray : raise NotImplementedError
    def randn(self, *shape: int) -> BackendArray: raise NotImplementedError
    def where(self, condition: BackendArray, x: BackendArray, y: BackendArray) -> BackendArray : raise NotImplementedError

    # Elementwise mathematical functions
    def exp(self, a: BackendArray) -> BackendArray : raise NotImplementedError
    def log(self, a: BackendArray) -> BackendArray : raise NotImplementedError
    def sqrt(self, a: BackendArray) -> BackendArray : raise NotImplementedError

    # Backend type
    @property
    def backend_type(self) -> BackendType: raise NotImplementedError

class NumpyBackend(Backend):
    def as_array(self, a) -> np.ndarray:
        try:
            return np.array(a, dtype=np.float32) if not isinstance(a, np.ndarray) else a
        except Exception:
            raise ValueError("Could not coerce `a` into a numpy array.")

    # Arithmetic operations
    def add(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: return np.add(a, b)
    def sub(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: return np.subtract(a, b)
    def mul(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: return np.multiply(a, b)
    def matmul(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: return np.matmul(a, b)
    def true_div(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: return np.true_divide(a, b)
    def neg(self, a: np.ndarray) -> np.ndarray: return np.negative(a)
    
    # Shape operations
    def transpose(self, a: np.ndarray) -> np.ndarray: return np.transpose(a)
    def reshape(self, a: np.ndarray, newshape: Tuple[int, ...]) -> np.ndarray: return np.reshape(a, newshape)
    def flatten(self, a: np.ndarray) -> np.ndarray: return a.flatten()
    def ndim(self, a: np.ndarray) -> int: return a.ndim
    def shape(self, a: np.ndarray) -> tuple: return a.shape

    # Reduction operations
    def sum(self, a: np.ndarray, axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) -> np.ndarray: return np.sum(a, axis=axis, keepdims=keepdims)
    def maximum(self, a: np.ndarray) -> np.ndarray: return np.max(a)
    def minimum(self, a: np.ndarray) -> np.ndarray: return np.min(a)
    def max_eltwise(self, a: np.ndarray, b) -> np.ndarray: return np.maximum(a, b)
    def min_eltwise(self, a: np.ndarray, b) -> np.ndarray: return np.minimum(a, b)

    # BackendArray creation
    def ones(self, shape: Tuple[int, ...]) -> np.ndarray: return np.ones(shape, dtype=np.float32)
    def zeros(self, shape: Tuple[int, ...]) -> np.ndarray: return np.zeros(shape, dtype=np.float32)
    def randn(self, *shape: int) -> np.ndarray: return np.random.randn(*shape).astype(np.float32)
    def where(self, condition: np.ndarray, x: np.ndarray, y: np.ndarray) -> np.ndarray: return np.where(condition, x, y)

    # Elementwise mathematical functions
    def exp(self, a: np.ndarray) -> np.ndarray: return np.exp(a)
    def log(self, a: np.ndarray) -> np.ndarray: return np.log(a)
    def sqrt(self, a: np.ndarray) -> np.ndarray: return np.sqrt(a)

    @property
    def backend_type(self) -> BackendType: return BackendType.NUMPY

class CupyBackend(Backend):
    pass
\$\endgroup\$

1 Answer

4
\$\begingroup\$

extra enum

    if backend_type == BackendType.NUMPY:
        default_backend = NumpyBackend()
    elif backend_type == BackendType.CUPY:
        default_backend = CupyBackend()

It's unclear what the .NUMPY and .CUPY enums are buying you there. Is there some design reason for not having the caller pass in NumpyBackend or CupyBackend? If passing the class in is fine, then just call the constructor with () here, similar to SQLAlchemy's session manager.

Suppose there is some design reason for the decoupling; maybe we don't want the calling module to do an expensive import. Then it may still be attractive to replace if elif elif with a dict mapping. Or hang each class reference directly on each enum value, where you initially defined those enums.

naming

def reset_state():
    ...
    global_backend = NumpyBackend()

I found that a bit unintuitive. It is not at all agnostic about backend, so it feels more like def reset_to_numpy_state():. Consider using a keyword default:

def reset_state(Backend=NumpyBackend):
    ...
    global_backend = Backend()

threads

If you're concerned that one of your backends is not currently threadsafe, or that there may be a non-threadsafe backend in future, then just use threading.get_ident() as the key into some global dict, which gives you per-thread storage.

You can e.g. create a new NumpyBackend() for each thread.
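One way this could look, as a sketch only: the _backends dict, the _lock, and the get_backend() / set_backend() helpers are hypothetical names, and NumpyBackend is a stand-in for the real class.

```python
import threading


class NumpyBackend:
    """Stand-in (assumption) for the framework's real backend class."""


# Hypothetical per-thread registry, keyed by thread identifier.
_backends: dict[int, object] = {}
_lock = threading.Lock()


def get_backend():
    """Return the calling thread's backend, creating a default on first use."""
    ident = threading.get_ident()
    with _lock:
        if ident not in _backends:
            _backends[ident] = NumpyBackend()
        return _backends[ident]


def set_backend(backend) -> None:
    """Replace the backend for the calling thread only."""
    with _lock:
        _backends[threading.get_ident()] = backend


results = {}


def worker() -> None:
    results["worker"] = get_backend()


t = threading.Thread(target=worker)
t.start()
t.join()
results["main"] = get_backend()

print(results["worker"] is results["main"])  # False
```

One caveat: thread identifiers may be recycled after a thread exits, so entries for dead threads can linger and be picked up by a later thread. The standard library's threading.local() is an alternative that handles the per-thread bookkeeping for you.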

consistency

def state(backend: BackendType):

The formal parameter seems to have the wrong name. Based on names used in the rest of the code, it appears we want

def state(backend_type: BackendType):

ancient annotations

Thank you for the annotations; they are very helpful to the reader.

from typing import ... Tuple
...
    def ones(self, shape: Tuple[int, ...]) -> BackendArray : ...

I can't imagine why anyone would want to author new code using that ancient notation, given that interpreters before 3.9 are EOL. We used to need such an import, long long ago. Nowadays, prefer shape: tuple[... over shape: Tuple[....

Let's look at a nearby signature:

    def sum(self, a: BackendArray, axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) -> ...

Notice we admit int | None | ... rather than Union[int, None, ...]. That's good; thank you for using the modern notation.

BTW, those np.ndarray declarations are fine. But personally I have obtained better results from mypy and pyright with from numpy.typing import NDArray, as it offers better control over expressing e.g. np.float32 types.
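A small sketch of what that might look like. The Float32Array alias name is an assumption, and the functions are written as free functions rather than Backend methods for brevity.

```python
import numpy as np
from numpy.typing import NDArray

# Hypothetical alias: an ndarray whose element type is float32.
Float32Array = NDArray[np.float32]


def ones(shape: tuple[int, ...]) -> Float32Array:
    return np.ones(shape, dtype=np.float32)


def add(a: Float32Array, b: Float32Array) -> Float32Array:
    return np.add(a, b)


print(add(ones((2, 3)), ones((2, 3))).dtype)  # float32
```

With the element type in the annotation, a type checker has a chance of flagging a float64 array being passed where float32 is expected.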

wrapping an exception

This makes me sad:

    def as_array(self, a) -> np.ndarray:
        try:
            return np.array(a, dtype=np.float32) if not isinstance(a, np.ndarray) else a
        except Exception:
            raise ValueError("Could not coerce `a` into a numpy array.")

Gosh, I wonder what happened? Hard to say, since we discarded the original exception and don't even know its type. Prefer to hang onto exception e, and raise ValueError(...) from e.
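Concretely, something along these lines, shown as a free function rather than a method so it runs on its own:

```python
import numpy as np


def as_array(a) -> np.ndarray:
    # Arrays pass through untouched; anything else is coerced to float32.
    if isinstance(a, np.ndarray):
        return a
    try:
        return np.array(a, dtype=np.float32)
    except Exception as e:
        # "from e" keeps the original exception attached as __cause__.
        raise ValueError(f"Could not coerce {a!r} into a numpy array.") from e


try:
    as_array("abc")
except ValueError as err:
    # The underlying numpy error is still available for debugging.
    print(repr(err.__cause__))
```

The traceback then shows both exceptions, joined by "The above exception was the direct cause of the following exception".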

\$\endgroup\$
