Changeset 765d025 in sasmodels for sasmodels/kerneldll.py


Ignore:
Timestamp:
Oct 30, 2018 12:05:27 PM (6 years ago)
Author:
Paul Kienzle <pkienzle@…>
Children:
646eeaa
Parents:
1662ebe (diff), aa8c6e0 (diff)
Note: this is a merge changeset, the changes displayed below correspond to the merge itself.
Use the (diff) links above to see all the changes relative to each parent.
Message:

Merge remote-tracking branch 'upstream/beta_approx'

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/kerneldll.py

    r1662ebe re44432d  
    9999    pass 
    100100# pylint: enable=unused-import 
     101 
     102# Compiler output is a byte stream that needs to be decode in python 3 
     103decode = (lambda s: s) if sys.version_info[0] < 3 else (lambda s: s.decode('utf8')) 
     104 
     105if "SAS_DLL_PATH" in os.environ: 
     106    SAS_DLL_PATH = os.environ["SAS_DLL_PATH"] 
     107else: 
     108    # Assume the default location of module DLLs is in .sasmodels/compiled_models. 
     109    SAS_DLL_PATH = os.path.join(os.path.expanduser("~"), ".sasmodels", "compiled_models") 
    101110 
    102111if "SAS_COMPILER" in os.environ: 
     
    161170        return CC + [source, "-o", output, "-lm"] 
    162171 
    163 # Assume the default location of module DLLs is in .sasmodels/compiled_models. 
    164 DLL_PATH = os.path.join(os.path.expanduser("~"), ".sasmodels", "compiled_models") 
    165  
    166172ALLOW_SINGLE_PRECISION_DLLS = True 
    167173 
     
    181187        subprocess.check_output(command, shell=shell, stderr=subprocess.STDOUT) 
    182188    except subprocess.CalledProcessError as exc: 
    183         raise RuntimeError("compile failed.\n%s\n%s" 
    184                            % (command_str, exc.output.decode())) 
     189        output = decode(exc.output) 
     190        raise RuntimeError("compile failed.\n%s\n%s"%(command_str, output)) 
    185191    if not os.path.exists(output): 
    186192        raise RuntimeError("compile failed.  File is in %r"%source) 
     
    201207        return path 
    202208 
    203     return joinpath(DLL_PATH, basename) 
     209    return joinpath(SAS_DLL_PATH, basename) 
    204210 
    205211 
     
    210216    exist yet if it hasn't been compiled. 
    211217    """ 
    212     return os.path.join(DLL_PATH, dll_name(model_info, dtype)) 
     218    return os.path.join(SAS_DLL_PATH, dll_name(model_info, dtype)) 
    213219 
    214220 
     
    229235    models are not allowed as DLLs. 
    230236 
    231     Set *sasmodels.kerneldll.DLL_PATH* to the compiled dll output path. 
     237    Set *sasmodels.kerneldll.SAS_DLL_PATH* to the compiled dll output path. 
     238    Alternatively, set the environment variable *SAS_DLL_PATH*. 
    232239    The default is in ~/.sasmodels/compiled_models. 
    233240    """ 
     
    248255    if need_recompile: 
    249256        # Make sure the DLL path exists 
    250         if not os.path.exists(DLL_PATH): 
    251             os.makedirs(DLL_PATH) 
     257        if not os.path.exists(SAS_DLL_PATH): 
     258            os.makedirs(SAS_DLL_PATH) 
    252259        basename = splitext(os.path.basename(dll))[0] + "_" 
    253260        system_fd, filename = tempfile.mkstemp(suffix=".c", prefix=basename) 
     
    312319 
    313320        # int, int, int, int*, double*, double*, double*, double*, double 
    314         argtypes = [ct.c_int32]*3 + [ct.c_void_p]*4 + [float_type] 
     321        argtypes = [ct.c_int32]*3 + [ct.c_void_p]*4 + [float_type, ct.c_int32] 
    315322        names = [generate.kernel_name(self.info, variant) 
    316323                 for variant in ("Iq", "Iqxy", "Imagnetic")] 
     
    372379    def __init__(self, kernel, model_info, q_input): 
    373380        # type: (Callable[[], np.ndarray], ModelInfo, PyInput) -> None 
     381        #,model_info,q_input) 
    374382        self.kernel = kernel 
    375383        self.info = model_info 
     
    377385        self.dtype = q_input.dtype 
    378386        self.dim = '2d' if q_input.is_2d else '1d' 
    379         self.result = np.empty(q_input.nq+1, q_input.dtype) 
     387        # leave room for f1/f2 results in case we need to compute beta for 1d models 
     388        nout = 2 if self.info.have_Fq else 1 
     389        # +4 for total weight, shell volume, effective radius, form volume 
     390        self.result = np.empty(q_input.nq*nout + 4, self.dtype) 
    380391        self.real = (np.float32 if self.q_input.dtype == generate.F32 
    381392                     else np.float64 if self.q_input.dtype == generate.F64 
    382393                     else np.float128) 
    383394 
    384     def __call__(self, call_details, values, cutoff, magnetic): 
    385         # type: (CallDetails, np.ndarray, np.ndarray, float, bool) -> np.ndarray 
    386  
     395    def _call_kernel(self, call_details, values, cutoff, magnetic, effective_radius_type): 
     396        # type: (CallDetails, np.ndarray, np.ndarray, float, bool, int) -> np.ndarray 
    387397        kernel = self.kernel[1 if magnetic else 0] 
    388398        args = [ 
     
    391401            None, # pd_stop pd_stride[MAX_PD] 
    392402            call_details.buffer.ctypes.data, # problem 
    393             values.ctypes.data,  #pars 
    394             self.q_input.q.ctypes.data, #q 
     403            values.ctypes.data,  # pars 
     404            self.q_input.q.ctypes.data, # q 
    395405            self.result.ctypes.data,   # results 
    396406            self.real(cutoff), # cutoff 
     407            effective_radius_type, # cutoff 
    397408        ] 
    398409        #print("Calling DLL") 
     
    404415            kernel(*args) # type: ignore 
    405416 
    406         #print("returned",self.q_input.q, self.result) 
    407         pd_norm = self.result[self.q_input.nq] 
    408         scale = values[0]/(pd_norm if pd_norm != 0.0 else 1.0) 
    409         background = values[1] 
    410         #print("scale",scale,background) 
    411         return scale*self.result[:self.q_input.nq] + background 
    412  
    413417    def release(self): 
    414418        # type: () -> None 
Note: See TracChangeset for help on using the changeset viewer.