...
 
Commits (2)
import ast
import astor
class IndexArithmetic(ast.NodeTransformer):
def __init__(self, index_map):
self.index_map = index_map
def visit_Name(self, node):
if node.id not in self.index_map:
return node
index = self.index_map[node.id]
new_node = ast.Subscript(
value=ast.Name(id='X', ctx=ast.Load()),
slice=ast.Index(value=index),
ctx=node.ctx
)
return new_node
def index_dfun(model):
n_svar = len(model.state)
svar_i = dict(zip(model.state,range(n_svar)))
dfun_asts = []
for drift in model.drift:
dfun_asts.append(ast.parse(drift).body[0].value)
svar_index = dict.fromkeys(model.state)
for i, var in enumerate(model.state):
index = ast.Tuple(
ctx=ast.Load(),
elts=[ ast.Num(n=i),
ast.Name(id='t', ctx=ast.Load())
])
svar_index[var] = index
dfun_idx_asts = []
for i, var in enumerate(model.state):
index = svar_index[var]
dX = ast.Subscript(
value=ast.Name(id='dX', ctx=ast.Load()),
slice=ast.Index(value=index),
ctx=ast.Store() )
dfun_idx_ast = ast.Assign(
targets=[dX],
value=IndexArithmetic(svar_index).generic_visit(dfun_asts[i]))
dfun_idx_asts.append( dfun_idx_ast )
nowrap = lambda x:''.join(x)
dfun_strings = list(map(lambda x: astor.to_source(x, pretty_source=nowrap),
dfun_idx_asts ) )
return dfun_strings
if __name__=="__main__":
from ppi import G2DO
from mako.template import Template
dfuns = index_dfun(G2DO)
template = Template(filename='numba_model.py')
print(template.render(dfuns=dfuns, const=G2DO.const, cvar=G2DO.input))
......@@ -5,6 +5,6 @@ template = Template(filename='template.py')
print(template.render( name='Generic2D',
const=G2DO.const,
limit=G2DO.limit,
sv=G2DO.state.split(),
sv=G2DO.state,
drift=G2DO.drift,
input=G2DO.input.split() ) )
import numpy
from numba import cuda, float32, guvectorize, float64
def make_model():
"Construct CUDA device function for the model."
# parameters
%for c,val in const.items():
${c}=float32(${val})
%endfor
@cuda.jit(device=True)
def f(dX, X, ${cvar}):
t = cuda.threadIdx.x
% for dX in dfuns:
${dX}
%endfor
return f
......@@ -8,7 +8,7 @@ TRAJ_STEPS = 4096
class G2DO:
"Generic nonlinear 2-D (phase plane) oscillator."
state = 'W V'
state = 'W', 'V'
limit = (-5, 5), (-5, 5)
input = 'c_0'
param = 'a'
......@@ -76,10 +76,10 @@ if __name__ == '__main__':
def clear(event):
ax.clear()
Q = ax.quiver(Y1, Y2, u, v, color='r')
Q = ax.quiver(Y1, Y2, u, v, color='r')
ax.set_xlabel('$y_1$')
ax.set_ylabel('$y_2$')
ax.set_xlim(-2.0,2.0)
ax.set_xlim(-2.0,2.0)
ax.set_ylim(-2.0,2.0)
plt.draw()
......
import ast
import pprint
def pformat_ast(node, include_attrs=False, **kws):
return pprint.pformat(ast2tree(node, include_attrs), **kws)
def ast2tree(node, include_attrs=True):
def _transform(node):
if isinstance(node, ast.AST):
fields = ((a, _transform(b))
for a, b in ast.iter_fields(node))
if include_attrs:
attrs = ((a, _transform(getattr(node, a)))
for a in node._attributes
if hasattr(node, a))
return (node.__class__.__name__, dict(fields), dict(attrs))
return (node.__class__.__name__, dict(fields))
elif isinstance(node, list):
return [_transform(x) for x in node]
elif isinstance(node, str):
return repr(node)
return node
if not isinstance(node, ast.AST):
raise TypeError('expected AST, got %r' % node.__class__.__name__)
return _transform(node)