Changeset d5ac45f in sasmodels for sasmodels/generate.py


Ignore:
Timestamp:
Mar 20, 2016 5:32:04 PM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
03cac08, 3a45c2c
Parents:
4f9d3fd
Message:

refactoring generate/kernel_template in process…

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/generate.py

    r9ef9dd9 rd5ac45f  
    2121 
    2222    *VR(p1, p2, ...)* returns the volume ratio for core-shell style forms. 
     23 
     24    #define INVALID(v) (expr)  returns False if v.parameter is invalid 
     25    for some parameter or other (e.g., v.bell_radius < v.radius).  If 
     26    necessary, the expression can call a function. 
    2327 
    2428These functions are defined in a kernel module .py script and an associated 
     
    216220import sys 
    217221from os.path import abspath, dirname, join as joinpath, exists, basename, \ 
    218     splitext 
     222    splitext, getmtime 
    219223import re 
    220224import string 
     
    224228import numpy as np 
    225229 
     230# TODO: promote Parameter and model_info to classes 
    226231PARAMETER_FIELDS = ['name', 'units', 'default', 'limits', 'type', 'description'] 
    227232Parameter = namedtuple('Parameter', PARAMETER_FIELDS) 
     
    230235#__all__ = ["model_info", "make_doc", "make_source", "convert_type"] 
    231236 
    232 C_KERNEL_TEMPLATE_PATH = joinpath(dirname(__file__), 'kernel_template.c') 
     237TEMPLATE_ROOT = dirname(__file__) 
    233238 
    234239F16 = np.dtype('float16') 
     
    338343    raise ValueError("%r not found in %s" % (filename, search_path)) 
    339344 
     345 
    340346def model_sources(model_info): 
    341347    """ 
     
    346352    return [_search(search_path, f) for f in model_info['source']] 
    347353 
    348 # Pragmas for enable OpenCL features.  Be sure to protect them so that they 
    349 # still compile even if OpenCL is not present. 
    350 _F16_PRAGMA = """\ 
    351 #if defined(__OPENCL_VERSION__) // && !defined(cl_khr_fp16) 
    352 #  pragma OPENCL EXTENSION cl_khr_fp16: enable 
    353 #endif 
    354 """ 
    355  
    356 _F64_PRAGMA = """\ 
    357 #if defined(__OPENCL_VERSION__) // && !defined(cl_khr_fp64) 
    358 #  pragma OPENCL EXTENSION cl_khr_fp64: enable 
    359 #endif 
    360 """ 
    361354 
    362355def convert_type(source, dtype): 
     
    369362    if dtype == F16: 
    370363        fbytes = 2 
    371         source = _F16_PRAGMA + _convert_type(source, "half", "f") 
     364        source = _convert_type(source, "float", "f") 
    372365    elif dtype == F32: 
    373366        fbytes = 4 
     
    375368    elif dtype == F64: 
    376369        fbytes = 8 
    377         source = _F64_PRAGMA + source  # Source is already double 
     370        # no need to convert if it is already double 
    378371    elif dtype == F128: 
    379372        fbytes = 16 
     
    418411 
    419412 
    420 LOOP_OPEN = """\ 
    421 for (int %(name)s_i=0; %(name)s_i < N%(name)s; %(name)s_i++) { 
    422   const double %(name)s = loops[2*(%(name)s_i%(offset)s)]; 
    423   const double %(name)s_w = loops[2*(%(name)s_i%(offset)s)+1];\ 
     413_template_cache = {} 
     414def load_template(filename): 
     415    path = joinpath(TEMPLATE_ROOT, filename) 
     416    mtime = getmtime(path) 
     417    if filename not in _template_cache or mtime > _template_cache[filename][0]: 
     418        with open(path) as fid: 
     419            _template_cache[filename] = (mtime, fid.read()) 
     420    return _template_cache[filename][1] 
     421 
     422def _gen_fn(name, pars, body): 
     423    """ 
     424    Generate a function given pars and body. 
     425 
     426    Returns the following string:: 
     427 
     428         double fn(double a, double b, ...); 
     429         double fn(double a, double b, ...) { 
     430             .... 
     431         } 
     432    """ 
     433    template = """\ 
     434double %(name)s(%(pars)s); 
     435double %(name)s(%(pars)s) { 
     436    %(body)s 
     437} 
     438 
     439 
    424440""" 
    425 def build_polydispersity_loops(pd_pars): 
    426     """ 
    427     Build polydispersity loops 
    428  
    429     Returns loop opening and loop closing 
    430     """ 
    431     depth = 4 
    432     offset = "" 
    433     loop_head = [] 
    434     loop_end = [] 
    435     for name in pd_pars: 
    436         subst = {'name': name, 'offset': offset} 
    437         loop_head.append(indent(LOOP_OPEN % subst, depth)) 
    438         loop_end.insert(0, (" "*depth) + "}") 
    439         offset += '+N' + name 
    440         depth += 2 
    441     return "\n".join(loop_head), "\n".join(loop_end) 
    442  
    443 C_KERNEL_TEMPLATE = None 
     441    par_decl = ', '.join('double ' + p for p in pars) if pars else 'void' 
     442    return template % {'name': name, 'body': body, 'pars': par_decl} 
     443 
     444def _gen_call_pars(name, pars): 
     445    name += "." 
     446    return ",".join(name+p for p in pars) 
     447 
    444448def make_source(model_info): 
    445449    """ 
     
    461465 
    462466    # Load template 
    463     global C_KERNEL_TEMPLATE 
    464     if C_KERNEL_TEMPLATE is None: 
    465         with open(C_KERNEL_TEMPLATE_PATH) as fid: 
    466             C_KERNEL_TEMPLATE = fid.read() 
     467    source = [load_template('kernel_header.c')] 
    467468 
    468469    # Load additional sources 
    469     source = [open(f).read() for f in model_sources(model_info)] 
     470    source += [open(f).read() for f in model_sources(model_info)] 
    470471 
    471472    # Prepare defines 
    472473    defines = [] 
    473     partype = model_info['partype'] 
    474     pd_1d = partype['pd-1d'] 
    475     pd_2d = partype['pd-2d'] 
    476     fixed_1d = partype['fixed-1d'] 
    477     fixed_2d = partype['fixed-1d'] 
    478474 
    479475    iq_parameters = [p.name 
    480476                     for p in model_info['parameters'][2:]  # skip scale, background 
    481                      if p.name in set(fixed_1d + pd_1d)] 
     477                     if p.name in model_info['par_set']['1d']] 
    482478    iqxy_parameters = [p.name 
    483479                       for p in model_info['parameters'][2:]  # skip scale, background 
    484                        if p.name in set(fixed_2d + pd_2d)] 
    485     volume_parameters = [p.name 
    486                          for p in model_info['parameters'] 
    487                          if p.type == 'volume'] 
    488  
    489     # Fill in defintions for volume parameters 
     480                       if p.name in model_info['par_set']['2d']] 
     481    volume_parameters = model_info['par_type']['volume'] 
     482 
     483    # Generate form_volume function, etc. from body only 
     484    if model_info['form_volume'] is not None: 
     485        pnames = [p.name for p in volume_parameters] 
     486        source.append(_gen_fn('form_volume', pnames, model_info['form_volume'])) 
     487    if model_info['Iq'] is not None: 
     488        pnames = ['q'] + [p.name for p in iq_parameters] 
     489        source.append(_gen_fn('Iq', pnames, model_info['Iq'])) 
     490    if model_info['Iqxy'] is not None: 
     491        pnames = ['qx', 'qy'] + [p.name for p in iqxy_parameters] 
     492        source.append(_gen_fn('Iqxy', pnames, model_info['Iqxy'])) 
     493 
     494    # Fill in definitions for volume parameters 
    490495    if volume_parameters: 
    491         defines.append(('VOLUME_PARAMETERS', 
    492                         ','.join(volume_parameters))) 
    493         defines.append(('VOLUME_WEIGHT_PRODUCT', 
    494                         '*'.join(p + '_w' for p in volume_parameters))) 
    495  
    496     # Generate form_volume function from body only 
    497     if model_info['form_volume'] is not None: 
    498         if volume_parameters: 
    499             vol_par_decl = ', '.join('double ' + p for p in volume_parameters) 
    500         else: 
    501             vol_par_decl = 'void' 
    502         defines.append(('VOLUME_PARAMETER_DECLARATIONS', 
    503                         vol_par_decl)) 
    504         fn = """\ 
    505 double form_volume(VOLUME_PARAMETER_DECLARATIONS); 
    506 double form_volume(VOLUME_PARAMETER_DECLARATIONS) { 
    507     %(body)s 
    508 } 
    509 """ % {'body':model_info['form_volume']} 
    510         source.append(fn) 
     496        deref_vol = ",".join("v."+p.name for p in volume_parameters) 
     497        defines.append(('CALL_VOLUME(v)', 'form_volume(%s)\n'%deref_vol)) 
     498    else: 
     499        # Model doesn't have volume.  We could make the kernel run a little 
     500        # faster by not using/transferring the volume normalizations, but 
     501        # the ifdef's reduce readability more than is worthwhile. 
     502        defines.append(('CALL_VOLUME(v)', '0.0')) 
    511503 
    512504    # Fill in definitions for Iq parameters 
    513     defines.append(('IQ_KERNEL_NAME', model_info['name'] + '_Iq')) 
     505    defines.append(('KERNEL_NAME', model_info['name'])) 
    514506    defines.append(('IQ_PARAMETERS', ', '.join(iq_parameters))) 
    515507    if fixed_1d: 
    516508        defines.append(('IQ_FIXED_PARAMETER_DECLARATIONS', 
    517509                        ', \\\n    '.join('const double %s' % p for p in fixed_1d))) 
    518     if pd_1d: 
    519         defines.append(('IQ_WEIGHT_PRODUCT', 
    520                         '*'.join(p + '_w' for p in pd_1d))) 
    521         defines.append(('IQ_DISPERSION_LENGTH_DECLARATIONS', 
    522                         ', \\\n    '.join('const int N%s' % p for p in pd_1d))) 
    523         defines.append(('IQ_DISPERSION_LENGTH_SUM', 
    524                         '+'.join('N' + p for p in pd_1d))) 
    525         open_loops, close_loops = build_polydispersity_loops(pd_1d) 
    526         defines.append(('IQ_OPEN_LOOPS', 
    527                         open_loops.replace('\n', ' \\\n'))) 
    528         defines.append(('IQ_CLOSE_LOOPS', 
    529                         close_loops.replace('\n', ' \\\n'))) 
    530     if model_info['Iq'] is not None: 
    531         defines.append(('IQ_PARAMETER_DECLARATIONS', 
    532                         ', '.join('double ' + p for p in iq_parameters))) 
    533         fn = """\ 
    534 double Iq(double q, IQ_PARAMETER_DECLARATIONS); 
    535 double Iq(double q, IQ_PARAMETER_DECLARATIONS) { 
    536     %(body)s 
    537 } 
    538 """ % {'body':model_info['Iq']} 
    539         source.append(fn) 
    540  
    541510    # Fill in definitions for Iqxy parameters 
    542511    defines.append(('IQXY_KERNEL_NAME', model_info['name'] + '_Iqxy')) 
     
    557526        defines.append(('IQXY_CLOSE_LOOPS', 
    558527                        close_loops.replace('\n', ' \\\n'))) 
    559     if model_info['Iqxy'] is not None: 
    560         defines.append(('IQXY_PARAMETER_DECLARATIONS', 
    561                         ', '.join('double ' + p for p in iqxy_parameters))) 
    562         fn = """\ 
    563 double Iqxy(double qx, double qy, IQXY_PARAMETER_DECLARATIONS); 
    564 double Iqxy(double qx, double qy, IQXY_PARAMETER_DECLARATIONS) { 
    565     %(body)s 
    566 } 
    567 """ % {'body':model_info['Iqxy']} 
    568         source.append(fn) 
    569  
    570528    # Need to know if we have a theta parameter for Iqxy; it is not there 
    571529    # for the magnetic sphere model, for example, which has a magnetic 
     
    584542def categorize_parameters(pars): 
    585543    """ 
     544    Categorize the parameters by use: 
     545 
     546    * *pd* list of polydisperse parameters in order; gui should test whether 
     547      they are in *2d* or *magnetic* as appropriate for the data 
     548    * *1d* set of parameters that are used to compute 1D patterns 
     549    * *2d* set of parameters that are used to compute 2D patterns (which 
     550      includes all 1D parameters) 
     551    * *magnetic* set of parameters that are used to compute magnetic 
     552      patterns (which includes all 1D and 2D parameters) 
     553    * *sesans* set of parameters that are used to compute sesans patterns 
     554     (which is just 1D without background) 
     555    * *pd-relative* is the set of parameters with relative distribution 
     556      width (e.g., radius +/- 10%) rather than absolute distribution 
     557      width (e.g., theta +/- 6 degrees). 
     558    """ 
     559    par_set = {} 
     560    par_set['1d'] = [p for p in pars if p.type not in ('orientation', 'magnetic')] 
     561    par_set['2d'] = [p for p in pars if p.type != 'magnetic'] 
     562    par_set['magnetic'] = [p for p in pars] 
     563    par_set['pd'] = [p for p in pars if p.type in ('volume', 'orientation')] 
     564    par_set['pd_relative'] = [p for p in pars if p.type == 'volume'] 
     565    return par_set 
     566 
     567def collect_types(pars): 
     568    """ 
    586569    Build parameter categories out of the the parameter definitions. 
    587570 
     
    596579    * *orientation* list of orientation parameters 
    597580    * *magnetic* list of magnetic parameters 
    598     * *<empty string>* list of parameters that have no type info 
     581    * *sld* list of parameters that have no type info 
     582    * *other* list of parameters that have no type info 
    599583 
    600584    Each parameter is in one and only one category. 
    601  
    602     The following derived categories are created: 
    603  
    604     * *fixed-1d* list of non-polydisperse parameters for 1D models 
    605     * *pd-1d* list of polydisperse parameters for 1D models 
    606     * *fixed-2d* list of non-polydisperse parameters for 2D models 
    607     * *pd-d2* list of polydisperse parameters for 2D models 
    608     """ 
    609     partype = { 
    610         'volume': [], 'orientation': [], 'magnetic': [], 'sld': [], '': [], 
    611         'fixed-1d': [], 'fixed-2d': [], 'pd-1d': [], 'pd-2d': [], 
    612         'pd-rel': set(), 
     585    """ 
     586    par_type = { 
     587        'volume': [], 'orientation': [], 'magnetic': [], 'sld': [], 'other': [], 
    613588    } 
    614  
    615589    for p in pars: 
    616         if p.type == 'volume': 
    617             partype['pd-1d'].append(p.name) 
    618             partype['pd-2d'].append(p.name) 
    619             partype['pd-rel'].add(p.name) 
    620         elif p.type == 'magnetic': 
    621             partype['fixed-2d'].append(p.name) 
    622         elif p.type == 'orientation': 
    623             partype['pd-2d'].append(p.name) 
    624         elif p.type in ('', 'sld'): 
    625             partype['fixed-1d'].append(p.name) 
    626             partype['fixed-2d'].append(p.name) 
    627         else: 
    628             raise ValueError("unknown parameter type %r" % p.type) 
    629         partype[p.type].append(p.name) 
    630  
    631     return partype 
     590        par_type[p.type if p.type else 'other'].append(p.name) 
     591    return  par_type 
     592 
    632593 
    633594def process_parameters(model_info): 
     
    647608    partype = categorize_parameters(pars) 
    648609    model_info['limits'] = dict((p.name, p.limits) for p in pars) 
    649     model_info['partype'] = partype 
     610    model_info['par_type'] = collect_types(pars) 
     611    model_info['par_set'] = categorize_parameters(pars) 
    650612    model_info['defaults'] = dict((p.name, p.default) for p in pars) 
    651613    if model_info.get('demo', None) is None: 
Note: See TracChangeset for help on using the changeset viewer.