I have been building a deep learning framework modelled after PyTorch in pure Python. I have made good progress, but the framework is currently not seamless to use because the end user has to pass backend= every time they create a tensor or a module. This is not ergonomic: it means threading backend=my_backend through every constructor and a lot of repeated code.
It would be better to do something similar to PyTorch and have the framework maintain some kind of global state to avoid this.
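To make the pain concrete, here is roughly what usage looks like today versus what I am aiming for (Tensor and Linear are hypothetical stand-ins for my framework's classes; the signatures are simplified):

    backend = NumpyBackend()

    # Today: backend= has to be threaded through every constructor by hand.
    w = Tensor([1.0, 2.0, 3.0], backend=backend)   # Tensor is a hypothetical stand-in for my class
    layer = Linear(4, 2, backend=backend)          # Linear likewise

    # What I want: constructors fall back to a framework-wide default,
    # so the backend only needs to be mentioned when it changes.
    w = Tensor([1.0, 2.0, 3.0])
    layer = Linear(4, 2)

Here is the global-state module I came up with to support that: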
"""
To make the framework more ergonomic to use, it maintains a global state.
This allows us to omit `backend=` arguments and fall back to a default backend, which is a module-level global.
It can be overridden at any time with the `with` statement, e.g.

    with state(BackendType.CUPY):
        ... do things ...

and it can be reset at any time using reset_state().
The default backend is numpy.
TODO: Add thread safety to this file.
"""
from contextlib import contextmanager
from framework.autograd.backend import (
    Backend,
    CupyBackend,
    NumpyBackend,
    BackendType,
)

global_backend = NumpyBackend()
# later: global_device (when cupy introduced)

def set_state(backend_type: BackendType):
    global global_backend
    if backend_type == BackendType.NUMPY:
        global_backend = NumpyBackend()
    elif backend_type == BackendType.CUPY:
        global_backend = CupyBackend()
    else:
        raise ValueError(f"backend_type not recognised: {backend_type}")

def reset_state():
    global global_backend
    global_backend = NumpyBackend()

@contextmanager
def state(backend: BackendType):
    # Context manager enter: remember the current backend, then switch.
    global global_backend
    prev_backend_type = global_backend.backend_type
    set_state(backend)
    try:
        yield
    finally:
        # Context manager exit: restore the previous backend, even if the
        # body of the `with` block raised.
        set_state(prev_backend_type)
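With that module in place (I keep it at framework/state.py; the exact path is an assumption here), usage looks roughly like this:

    from framework.autograd.backend import BackendType
    import framework.state as fstate  # assumed module path for the file above

    # The default backend is numpy.
    assert fstate.global_backend.backend_type == BackendType.NUMPY

    # Temporarily switch backends for a block; the previous backend is restored on exit.
    with fstate.state(BackendType.CUPY):
        assert isinstance(fstate.global_backend, fstate.CupyBackend)

    assert fstate.global_backend.backend_type == BackendType.NUMPY

    # set_state() switches globally; reset_state() goes back to the numpy default.
    fstate.set_state(BackendType.CUPY)
    fstate.reset_state()
    assert fstate.global_backend.backend_type == BackendType.NUMPY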
My questions:
1. Am I likely to run into issues later on in development with this design?
2. Is my current design inherently not thread-safe, meaning that, for example, I can no longer run tests in parallel?
Point (2) arises because, when I looked at how PyTorch handles this, it appears to keep this state thread-local; I sketch below what I understand by that.
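Concretely, this is the kind of thread-local / context-local alternative I think point (2) is hinting at, sketched with contextvars (none of this exists in my framework yet; get_backend and _BACKENDS are names invented for the sketch):

    from contextlib import contextmanager
    from contextvars import ContextVar

    from framework.autograd.backend import BackendType, CupyBackend, NumpyBackend

    _BACKENDS = {BackendType.NUMPY: NumpyBackend, BackendType.CUPY: CupyBackend}

    # Each thread (and each asyncio task) sees its own value, so tests running
    # in parallel and switching backends would not interfere with one another.
    _current_backend = ContextVar("current_backend", default=NumpyBackend())

    def get_backend():
        return _current_backend.get()

    @contextmanager
    def state(backend_type: BackendType):
        token = _current_backend.set(_BACKENDS[backend_type]())
        try:
            yield
        finally:
            _current_backend.reset(token)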
Note that my state module depends on autograd/backend.py, which I show below.
"""
A Backend contains a registry of primitive operations.
Calls to the backend are delegated to the appropriate Numpy/Cupy method implementations.
As a rule of thumb, `import numpy as np` / `import cupy as cp` should only appear in this file.
"""
from enum import Enum
from typing import Union, Tuple

import numpy as np

# Later: Union[np.ndarray, cp.ndarray]
type BackendArray = Union[np.ndarray, "cp.ndarray"]


class BackendType(Enum):
    NUMPY = "numpy"
    CUPY = "cupy"

class Backend:
    def as_array(self, a) -> BackendArray:
        """
        Attempt to coerce `a` into a backend-suitable array type.
        """
        raise NotImplementedError

    # Arithmetic operations
    def add(self, a: BackendArray, b: BackendArray) -> BackendArray: raise NotImplementedError
    def sub(self, a: BackendArray, b: BackendArray) -> BackendArray: raise NotImplementedError
    def mul(self, a: BackendArray, b: BackendArray) -> BackendArray: raise NotImplementedError
    def matmul(self, a: BackendArray, b: BackendArray) -> BackendArray: raise NotImplementedError
    def true_div(self, a: BackendArray, b: BackendArray) -> BackendArray: raise NotImplementedError
    def neg(self, a: BackendArray) -> BackendArray: raise NotImplementedError

    # Shape operations
    def transpose(self, a: BackendArray) -> BackendArray: raise NotImplementedError
    def reshape(self, a: BackendArray, newshape: Tuple[int, ...]) -> BackendArray: raise NotImplementedError
    def flatten(self, a: BackendArray) -> BackendArray: raise NotImplementedError
    def ndim(self, a: BackendArray) -> int: raise NotImplementedError
    def shape(self, a: BackendArray) -> Tuple[int, ...]: raise NotImplementedError

    # Reduction operations
    def sum(self, a: BackendArray, axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) -> BackendArray: raise NotImplementedError
    def maximum(self, a: BackendArray) -> BackendArray: raise NotImplementedError
    def minimum(self, a: BackendArray) -> BackendArray: raise NotImplementedError
    def max_eltwise(self, a: BackendArray, b) -> BackendArray: raise NotImplementedError
    def min_eltwise(self, a: BackendArray, b) -> BackendArray: raise NotImplementedError

    # BackendArray creation
    def ones(self, shape: Tuple[int, ...]) -> BackendArray: raise NotImplementedError
    def zeros(self, shape: Tuple[int, ...]) -> BackendArray: raise NotImplementedError
    def randn(self, *shape: int) -> BackendArray: raise NotImplementedError
    def where(self, condition: BackendArray, x: BackendArray, y: BackendArray) -> BackendArray: raise NotImplementedError

    # Elementwise mathematical functions
    def exp(self, a: BackendArray) -> BackendArray: raise NotImplementedError
    def log(self, a: BackendArray) -> BackendArray: raise NotImplementedError
    def sqrt(self, a: BackendArray) -> BackendArray: raise NotImplementedError

    # Backend type
    @property
    def backend_type(self) -> BackendType: raise NotImplementedError

class NumpyBackend(Backend):
    def as_array(self, a) -> np.ndarray:
        try:
            return np.array(a, dtype=np.float32) if not isinstance(a, np.ndarray) else a
        except Exception as exc:
            raise ValueError("Could not coerce `a` into a numpy array.") from exc
    # Arithmetic operations
    def add(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: return np.add(a, b)
    def sub(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: return np.subtract(a, b)
    def mul(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: return np.multiply(a, b)
    def matmul(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: return np.matmul(a, b)
    def true_div(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: return np.true_divide(a, b)
    def neg(self, a: np.ndarray) -> np.ndarray: return np.negative(a)

    # Shape operations
    def transpose(self, a: np.ndarray) -> np.ndarray: return np.transpose(a)
    def reshape(self, a: np.ndarray, newshape: Tuple[int, ...]) -> np.ndarray: return np.reshape(a, newshape)
    def flatten(self, a: np.ndarray) -> np.ndarray: return a.flatten()
    def ndim(self, a: np.ndarray) -> int: return a.ndim
    def shape(self, a: np.ndarray) -> Tuple[int, ...]: return a.shape

    # Reduction operations
    def sum(self, a: np.ndarray, axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) -> np.ndarray: return np.sum(a, axis=axis, keepdims=keepdims)
    def maximum(self, a: np.ndarray) -> np.ndarray: return np.max(a)
    def minimum(self, a: np.ndarray) -> np.ndarray: return np.min(a)
    def max_eltwise(self, a: np.ndarray, b) -> np.ndarray: return np.maximum(a, b)
    def min_eltwise(self, a: np.ndarray, b) -> np.ndarray: return np.minimum(a, b)

    # BackendArray creation
    def ones(self, shape: Tuple[int, ...]) -> np.ndarray: return np.ones(shape, dtype=np.float32)
    def zeros(self, shape: Tuple[int, ...]) -> np.ndarray: return np.zeros(shape, dtype=np.float32)
    def randn(self, *shape: int) -> np.ndarray: return np.random.randn(*shape).astype(np.float32)
    def where(self, condition: np.ndarray, x: np.ndarray, y: np.ndarray) -> np.ndarray: return np.where(condition, x, y)

    # Elementwise mathematical functions
    def exp(self, a: np.ndarray) -> np.ndarray: return np.exp(a)
    def log(self, a: np.ndarray) -> np.ndarray: return np.log(a)
    def sqrt(self, a: np.ndarray) -> np.ndarray: return np.sqrt(a)

    # Backend type
    @property
    def backend_type(self) -> BackendType: return BackendType.NUMPY


class CupyBackend(Backend):
    pass
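For completeness, this is how the registry is meant to be used: everything outside this file goes through the backend object rather than calling numpy directly, e.g.

    from framework.autograd.backend import BackendType, NumpyBackend

    backend = NumpyBackend()

    a = backend.as_array([[1.0, 2.0], [3.0, 4.0]])
    b = backend.ones((2, 2))

    c = backend.matmul(a, b)                        # delegates to np.matmul
    s = backend.sum(c, axis=0)                      # delegates to np.sum
    print(backend.shape(c), backend.backend_type)   # (2, 2) BackendType.NUMPY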