Just let each thread pull pairs of (a,b) and compute the partial result until all pairs are consumed.
Main thread:
output = 0
iter = iterator_over(A,B)
// start threads and wait until done
answer = output / size(A) / size(B)
return answer
Each thread:
res = 0
while true:
synchronized:
if iter.hasNext():
a,b = it.next()
else:
break
res += f(a, b, n)
synchronized:
output += res
For optimal performance, the amount of threads should be the same as the amount of virtual cores.