Commit 998f1c68 authored by sdiazpier's avatar sdiazpier

Split into numba cuda template and runner

parent cce8c9de
Pipeline #2434 failed with stages
......@@ -66,10 +66,9 @@ __global__ void integrate(/*{{{*/
for (unsigned int t = i_step; t < (i_step + n_step); t++)
{
for (unsigned int i_node = threadIdx.y; i_node < n_node; i_node+=blockDim.y)
for (unsigned int i_node = 0; i_node < n_node; i_node++)
{
if (i_node >= n_node) continue;
float theta_i = state(t % NH, i_node);
unsigned int i_n = i_node * n_node;
float sum = 0.0f;
......
......@@ -19,44 +19,12 @@ import numba.cuda as _lpy_ncu
import numba as _lpy_numba
from tvb_hpc import utils, network, model
from typing import List
import time
# TODO Add call to the generated numbacuda code
LOG = utils.getLogger('tvb_hpc')
LOG = utils.getLogger('tvb_hpc')
@_lpy_ncu.jit
def Kuramoto_and_Network_and_EulerStep_inner(nstep, nnode, ntime, state, input, param, drift, diffs, obsrv, nnz, delays, row, col, weights, a, i_step_0):
#Get the id for each thread
tcoupling = _lpy_ncu.threadIdx.x
tspeed = _lpy_ncu.blockIdx.x
sid = _lpy_ncu.gridDim.x
idp = tspeed*_lpy_ncu.blockDim.x+tcoupling
#for each simulation step and for each node in the system
for i_step in range(0, nstep):
for i_node in range(0, nnode):
#calculate the node index
idx= idp*nnode+i_node
#get the node params, in this case only omega
omega = param[i_node]
#retrieve the range of connected nodes
j_node_lo = row[i_node]
j_node_hi = row[i_node + 1]
#calculate the input from other nodes at the current step
acc_j_node = 0
for j_node in range(j_node_lo, j_node_hi):
acc_j_node = acc_j_node + weights[j_node]*m.sin(obsrv[((idp*ntime+((i_step + i_step_0) % ntime) + -1*delays[tspeed*nnz+j_node])*nnode+col[j_node])*2] + -1*obsrv[((idp*ntime+((i_step + i_step_0) % ntime))*nnode+i_node)*2])
input[idx] = a[tcoupling]*acc_j_node / nnode
#calculate the whole drift for the simulation step
drift[idx] = omega + input[idx]
#update the state
state[idx] = state[idx] + drift[idx]
#wrap the state within the desired limits
state[idx] = (state[idx] < 0)*(state[idx] + 6.283185307179586) + (state[idx] > 6.283185307179586)*(state[idx] + -6.283185307179586) + (state[idx] >= 0)*(state[idx] <= 6.283185307179586)*state[idx]
theta = state[idx]
#write the state to the observables data structure
obsrv[((idp*ntime + ((i_step + i_step_0) % ntime))*nnode + i_node)*2 + 1] = m.sin(theta)
obsrv[((idp*ntime + ((i_step + i_step_0) % ntime))*nnode + i_node)*2] = theta
def make_data():
c = network.Connectivity.hcp0()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment