3
\$\begingroup\$

Given a class BasicDataFeed whose purpose is to feed questions and answers into an artificial neural network, which is tested in a non-negotiable manner as follows:

import numpy as np
import matplotlib.pyplot as plt


dt = 0.001
dims = 4
t_len = 0.1
pause = 0.01
n_items = 4

cor = np.eye(n_items)


def dataset_func(idx, t):
    return cor[idx] * (t*10)


df = BasicDataFeed(dataset_func, np.eye(n_items), t_len, dims, n_items, pause)

t_steps = list(np.arange(0, 2*n_items*(t_len+pause), dt))
df_out = []
ans_out = []

for tt in t_steps:
    df_out.append(df.feed(tt))
    ans_out.append(df.get_answer(tt))

plt.plot(df_out)

plt.gca().set_prop_cycle(None)

plt.plot(ans_out)
plt.show()

With the desired output:

basic_desired_output

How do I write my state-machine in a more elegant manner? I feel like I'm keeping track of time (and thus the state transitions) incorrectly.

import numpy as np

from random import shuffle

dt = 0.001

class BasicDataFeed(object):

    def __init__(self, dataset, correct, t_len: float, dims: int, n_items: int, pause: float):
        self.data_index = 0
        self.paused = False

        self.time = 0.0
        self.sig_time = 0

        self.pause_time = pause
        self.q_duration = t_len
        self.ans_duration = self.q_duration + self.pause_time

        self.correct = correct
        self.qs = dataset
        self.num_items = n_items
        self.dims = dims
        self.indices = list(range(self.num_items))

    def get_answer(self, t):
        """Signal for correct answer"""
        if self.pause_time < self.time < self.ans_duration:
            return self.correct[self.indices[self.data_index]]
        else:
            return np.zeros(self.num_items)

    def feed(self, t):
        """Feed the question into the network
        this is the main state machine of the network"""
        self.time += dt

        if self.time > self.pause_time and self.sig_time > self.q_duration:

            if self.data_index < self.num_items - 1:
                self.data_index += 1
            else:
                shuffle(self.indices)
                self.data_index = 0

            self.time = 0.0
            self.sig_time = 0.0

        elif self.time > self.pause_time:
            self.paused = False

            q_idx = self.indices[self.data_index]
            return_val = self.qs(q_idx, self.sig_time)
            self.sig_time += dt
            return return_val

        else:
            self.paused = True

        return np.zeros(self.dims)

Note: This is a continuation of my previous question about time-dependent state machine. Also, don't worry about the type annotations, they're just there to help me debug and act as documentation for the code.

\$\endgroup\$
1
  • \$\begingroup\$ I would split the behavior of the state machine with the input of events. In other words, use a context that does all the time related stuff, and feed events at the right moment to a queue. Let the state machine listen to the queue and process events whenever they are triggered. \$\endgroup\$ Commented May 21, 2019 at 17:34

1 Answer 1

3
\$\begingroup\$

tested in a non-negotiable manner as follows

Cool. That's not a test - it's not reproducible because you haven't set a random seed, and you haven't asserted anything. It's just a demo.

From the top:

Move all of the variables dt, dims etc. to a local namespace; make dt a parameter on BasicDataFeed. dataset_func can accept a third parameter cor that you bind from the demo function with a partial.

t_steps should not be cast to a list.

Add PEP484 type hints.

tt is not used - you can drop it from the loop and the parameter to both feed and get_answer.

Convert the matplotlib calls from non-re-entrant style to re-entrant style.

indices should be an ndarray, not a list.

df_out and ans_out should not be lists; they should also be pre-allocated ndarray.

Don't use the Python random module; use a Numpy generator.

All together,

import functools
import typing

import numpy as np
import matplotlib.pyplot as plt


class DatasetFunc(typing.Protocol):
    def __call__(self, idx: int, t: float) -> np.ndarray:
        ...


class BasicDataFeed:
    def __init__(
        self,
        dataset: DatasetFunc,
        correct: np.ndarray,
        rand: np.random.Generator,
        t_len: float,
        dims: int,
        n_items: int,
        pause: float,
        dt: float,
    ) -> None:
        self.data_index = 0
        self.paused = False

        self.time = 0.
        self.sig_time = 0.
        self.dt = dt

        self.pause_time = pause
        self.q_duration = t_len
        self.ans_duration = self.q_duration + self.pause_time

        self.correct = correct
        self.rand = rand
        self.qs = dataset
        self.num_items = n_items
        self.dims = dims
        self.indices = np.arange(self.num_items)

    def get_answer(self) -> np.ndarray:
        """Signal for correct answer"""
        if self.pause_time < self.time < self.ans_duration:
            return self.correct[self.indices[self.data_index]]
        return np.zeros(shape=self.num_items, dtype=self.correct.dtype)

    def feed(self) -> np.ndarray:
        """Feed the question into the network
        this is the main state machine of the network"""
        self.time += self.dt

        if self.time > self.pause_time and self.sig_time > self.q_duration:
            if self.data_index < self.num_items - 1:
                self.data_index += 1
            else:
                self.rand.shuffle(self.indices)
                self.data_index = 0

            self.time = 0.
            self.sig_time = 0.

        elif self.time > self.pause_time:
            self.paused = False

            q_idx = self.indices[self.data_index]
            return_val = self.qs(idx=q_idx, t=self.sig_time)
            self.sig_time += self.dt
            return return_val

        else:
            self.paused = True

        return np.zeros(shape=self.dims)


def dataset_func(idx: int, t: float, cor: np.ndarray) -> np.ndarray:
    return cor[idx, :] * (t*10)


def demo() -> None:
    dt = 0.001
    dims = 4
    t_len = 0.1
    pause = 0.01
    n_items = 4

    cor = np.eye(n_items, dtype=np.uint8)
    rand = np.random.default_rng(seed=0)

    df = BasicDataFeed(
        dataset=functools.partial(dataset_func, cor=cor),
        correct=cor, rand=rand, t_len=t_len, dims=dims,
        n_items=n_items, pause=pause, dt=dt,
    )

    t_steps = np.arange(0, 2*n_items*(t_len + pause), dt)
    df_out = np.empty(shape=(t_steps.size, n_items), dtype=np.float64)
    ans_out = np.empty(shape=(t_steps.size, n_items), dtype=cor.dtype)

    for i in range(t_steps.size):
        df_out[i, :] = df.feed()
        ans_out[i, :] = df.get_answer()

    fig, ax = plt.subplots()
    ax.plot(df_out)  # Lower sawtooth
    ax.set_prop_cycle(None)
    ax.plot(ans_out)  # Upper squares
    plt.show()


if __name__ == '__main__':
    demo()
\$\endgroup\$

You must log in to answer this question.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.