Skip to main content
2 of 2
deleted 255 characters in body
Maarten Fabré
  • 9.4k
  • 1
  • 16
  • 27

Code style

  • instead of iterating over the indices like this for j in range(start_id_S2, len(S2.timeline)): you can use enumerate like this: for i, s1 in enumerate(S1.timeline[start_id_S1:], start_id_S1):
  • there is no need for all the continue statements and the final else clause in your last if-part

performance

You can also use the fact that the pulses are in order to iterate over the signals concurrently. That way you wont have n! comparisons where you iterate over the whole tree

To do this, I made a small change to your Signal class:

class Signal:
    def __init__(self, freq, start, end, width=0.3):
        self.freq = freq                                    # frequency in Hz
        self.width = float(width)                       # cathodic phase width in ms
        self.start = start                                    # Instant of the first pulse in ms
        self.end = end                                    # End point in ms

        # List of instant at which a stim pulse is triggered in ms
        self.timeline = np.round(np.arange(start, end, 1000/freq), 3)
        self.pulses = np.stack((self.timeline, self.timeline + self.width), axis=1)
        
    def find_closest_t(self, t):
        val = min(self.timeline, key=lambda x:abs(x-t))
        id = np.where(self.timeline==val)[0][0]

        if val <= t or id == 0:
            return val, id
        else:
            return self.timeline[id-1], id-1

    def find_closest_t_np(self, t):
        idx = max(np.searchsorted(self.timeline, t) - 1, 0)
        return idx
        
    def __iter__(self):
        return iter(self.pulses)
        # or yield from map(tuple, self.pulses) # if you need tuples
    
    def __bool__(self):
        return bool(self.timeline.size)

To iterate over all the signals at the same time, we,

  1. assemble a dictionary with the signals and their first pulses.
  2. look for overlap in the pulses
  3. if this is a new overlap, yield it
  4. advance the iteration on the signal where the end of the pulse comes first
  5. if this iterator is exhausted, remove the signal from the dict with signals

code:

from collections import namedtuple

def duo_overlap_iter(signals, perc=0):
    pulse = namedtuple('Pulse', 'name iter index start end ')
    iters = ((i, iter(signal)) for i, signal in enumerate(signals) if signal)
    iters = {name: pulse(name, it, 0, *next(it)) for name, it in iters}
    seen = set()
    
    while iters:
        for overlap in find_overlap(iters.values()):
            if overlap not in seen:
                yield overlap
                seen.add(overlap)
        try:
            p0 = min(iters.values(), key=lambda x: (x.end, -x.start))
            iters[p0.name] = pulse(p0.name, p0.iter, p0.index + 1, *next(p0.iter))
        except StopIteration:
            del iters[p0.name]

To find the overlap, we use itertools.combinations. We yield an overlap as a frozenset with the name and index of the corresponding signals

def find_overlap(pulses):
    for p0, p1 in combinations(pulses, 2):
        p = frozenset(((p0.name, p0.index), (p1.name, p1.index)))
        if p1.start <= p0.end and p0.start <= p1.end:
            yield p

Sample Data

S0 = Signal(20 , 100, 0,)  # empty
S1 = Signal(50, 0, 250)
S2 = Signal(30, 10, 300, 2)
S3 = Signal(20, -10, 280, 2)
signals = S0, S1, S2, S3

sample result

list(duo_overlap_iter(signals))
[frozenset({(1, 2), (3, 1)}), frozenset({(3, 3), (1, 7)}), frozenset({(1, 12), (3, 5)})]

Final results

To get the results in the way your code presents it, you can do something like this:

def overlap_duo_comb_iter(signals, perc=0):

    overlap = {i: [] for i, _ in enumerate(signals)}

    for (s0, i0), (s1, i1) in duo_overlap_iter(signals):
        overlap[s0].append(i0)
        overlap[s1].append(i1)

    return overlap

numpy code

During my revision of the timing, I noticed in your edited code using numpy, you still don't use the fact the signals are sorted. You iterate over the whole timeline instead of stopping once the 2nd signal can not overlap anymore.

def duo_overlap_np(S1, S2, perc):
    p1_overlapping = np.zeros_like(S1.timeline)
    p2_overlapping = np.zeros_like(S2.timeline)

    start = max(S1.start, S2.start)
    start_id_S1 = S1.find_closest_t_np(start)

    stop = min(S1.pulses[-1][1], S2.pulses[-1][1])
    for i, (s1, s1_end) in enumerate(S1.pulses[start_id_S1:], start_id_S1):
        if s1 > stop:
            break

        start_id_S2 = S2.find_closest_t_np(s1)
        for j, (s2, s2_end) in enumerate(S2.pulses[start_id_S2:], start_id_S2):
            if s2 > s1_end:
                break
            if s1 > s2_end:
                continue
            p1_overlapping[i] = 1
            p2_overlapping[j] = 1
                
    return list(np.nonzero(p1_overlapping)[0]), list(np.nonzero(p2_overlapping)[0])

Timings

print(overlap_duo_combination(signals))
% timeit overlap_duo_combination(signals)
{0: [], 1: [2, 7, 12], 2: [], 3: [1, 3, 5]}
1.33 ms ± 72.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
print(overlap_duo_combination(signals, func=duo_overlap_np))
assert overlap_duo_combination(signals) == overlap_duo_combination(signals, func=duo_overlap_np)
% timeit overlap_duo_combination(signals, func=duo_overlap_np)
{0: [], 1: [2, 7, 12], 2: [], 3: [1, 3, 5]}
267 µs ± 4.75 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
print(overlap_duo_comb_iter(signals))
assert overlap_duo_combination(signals) == overlap_duo_comb_iter(signals,)
% timeit overlap_duo_comb_iter(signals)
{0: [], 1: [2, 7, 12], 2: [], 3: [1, 3, 5]}
600 µs ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
print(list(duo_overlap_iter(signals)))
% timeit list(duo_overlap_iter(signals))
[frozenset({(1, 2), (3, 1)}), frozenset({(3, 3), (1, 7)}), frozenset({(1, 12), (3, 5)})]
605 µs ± 33.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Code can all be found on github

Conclusion

The adapted numpy code is the fasted for this set of signals, but only works with fixed-length pulses. So if your pulses are fixed length, use the numpy code, if you just get pulses, with a start and end, my code can handle those too

Maarten Fabré
  • 9.4k
  • 1
  • 16
  • 27