Changes in sasmodels/mixture.py [7ae2b7f:fe496dd] in sasmodels


Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/mixture.py

    r7ae2b7f rfe496dd  
    1111*ProductModel(P, S)*. 
    1212""" 
     13from __future__ import print_function 
     14 
    1315from copy import copy 
    1416import numpy as np  # type: ignore 
     
    1618from .modelinfo import Parameter, ParameterTable, ModelInfo 
    1719from .kernel import KernelModel, Kernel 
     20from .details import make_details 
    1821 
    1922try: 
    2023    from typing import List 
    21     from .details import CallDetails 
    2224except ImportError: 
    2325    pass 
     
    2628    # type: (List[ModelInfo]) -> ModelInfo 
    2729    """ 
    28     Create info block for product model. 
     30    Create info block for mixture model. 
    2931    """ 
    3032    flatten = [] 
     
    3840    # Build new parameter list 
    3941    combined_pars = [] 
     42    demo = {} 
    4043    for k, part in enumerate(parts): 
    4144        # Parameter prefix per model, A_, B_, ... 
     
    4346        # to support vector parameters 
    4447        prefix = chr(ord('A')+k) + '_' 
    45         combined_pars.append(Parameter(prefix+'scale')) 
     48        scale =  Parameter(prefix+'scale', default=1.0, 
     49                           description="model intensity for " + part.name) 
     50        combined_pars.append(scale) 
    4651        for p in part.parameters.kernel_parameters: 
    4752            p = copy(p) 
     
    5156                p.length_control = prefix + p.length_control 
    5257            combined_pars.append(p) 
     58        demo.update((prefix+k, v) for k, v in part.demo.items() 
     59                    if k != "background") 
     60    #print("pars",combined_pars) 
    5361    parameters = ParameterTable(combined_pars) 
     62    parameters.max_pd = sum(part.parameters.max_pd for part in parts) 
    5463 
    5564    model_info = ModelInfo() 
     
    7079    # Remember the component info blocks so we can build the model 
    7180    model_info.composition = ('mixture', parts) 
     81    model_info.demo = demo 
     82    return model_info 
    7283 
    7384 
     
    7889        self.parts = parts 
    7990 
    80     def __call__(self, q_vectors): 
     91    def make_kernel(self, q_vectors): 
    8192        # type: (List[np.ndarray]) -> MixtureKernel 
    8293        # Note: may be sending the q_vectors to the n times even though they 
     
    104115        self.info =  model_info 
    105116        self.kernels = kernels 
    106  
    107     def __call__(self, call_details, value, weight, cutoff): 
    108         # type: (CallDetails, np.ndarray, np.ndarry, float) -> np.ndarray 
    109         scale, background = value[0:2] 
     117        self.dtype = self.kernels[0].dtype 
     118 
     119    def __call__(self, call_details, values, cutoff, magnetic): 
     120        # type: (CallDetails, np.ndarray, np.ndarry, float, bool) -> np.ndarray 
     121        scale, background = values[0:2] 
    110122        total = 0.0 
    111123        # remember the parts for plotting later 
    112124        self.results = [] 
    113         for kernel, kernel_details in zip(self.kernels, call_details.parts): 
    114             part_result = kernel(kernel_details, value, weight, cutoff) 
    115             total += part_result 
    116             self.results.append(part_result) 
     125        offset = 2 # skip scale & background 
     126        parts = MixtureParts(self.info, self.kernels, call_details, values) 
     127        for kernel, kernel_details, kernel_values in parts: 
     128            #print("calling kernel", kernel.info.name) 
     129            result = kernel(kernel_details, kernel_values, cutoff, magnetic) 
     130            #print(kernel.info.name, result) 
     131            total += result 
     132            self.results.append(result) 
    117133 
    118134        return scale*total + background 
     
    123139            k.release() 
    124140 
     141 
     142class MixtureParts(object): 
     143    def __init__(self, model_info, kernels, call_details, values): 
     144        # type: (ModelInfo, List[Kernel], CallDetails, np.ndarray) -> None 
     145        self.model_info = model_info 
     146        self.parts = model_info.composition[1] 
     147        self.kernels = kernels 
     148        self.call_details = call_details 
     149        self.values = values 
     150        self.spin_index = model_info.parameters.npars + 2 
     151        #call_details.show(values) 
     152 
     153    def __iter__(self): 
     154        # type: () -> PartIterable 
     155        self.part_num = 0 
     156        self.par_index = 2 
     157        self.mag_index = self.spin_index + 3 
     158        return self 
     159 
     160    def next(self): 
     161        # type: () -> Tuple[List[Callable], CallDetails, np.ndarray] 
     162        if self.part_num >= len(self.parts): 
     163            raise StopIteration() 
     164        info = self.parts[self.part_num] 
     165        kernel = self.kernels[self.part_num] 
     166        call_details = self._part_details(info, self.par_index) 
     167        values = self._part_values(info, self.par_index, self.mag_index) 
     168        values = values.astype(kernel.dtype) 
     169        #call_details.show(values) 
     170 
     171        self.part_num += 1 
     172        self.par_index += info.parameters.npars + 1 
     173        self.mag_index += 3 * len(info.parameters.magnetism_index) 
     174 
     175        return kernel, call_details, values 
     176 
     177    def _part_details(self, info, par_index): 
     178        # type: (ModelInfo, int) -> CallDetails 
     179        full = self.call_details 
     180        # par_index is index into values array of the current parameter, 
     181        # which includes the initial scale and background parameters. 
     182        # We want the index into the weight length/offset for each parameter. 
     183        # Exclude the initial scale and background, so subtract two, but each 
     184        # component has its own scale factor which we need to skip when 
     185        # constructing the details for the kernel, so add one, giving a 
     186        # net subtract one. 
     187        index = slice(par_index - 1, par_index - 1 + info.parameters.npars) 
     188        length = full.length[index] 
     189        offset = full.offset[index] 
     190        # The complete weight vector is being sent to each part so that 
     191        # offsets don't need to be adjusted. 
     192        part = make_details(info, length, offset, full.num_weights) 
     193        return part 
     194 
     195    def _part_values(self, info, par_index, mag_index): 
     196        # type: (ModelInfo, int, int) -> np.ndarray 
     197        #print(info.name, par_index, self.values[par_index:par_index + info.parameters.npars + 1]) 
     198        scale = self.values[par_index] 
     199        pars = self.values[par_index + 1:par_index + info.parameters.npars + 1] 
     200        nmagnetic = len(info.parameters.magnetism_index) 
     201        if nmagnetic: 
     202            spin_state = self.values[self.spin_index:self.spin_index + 3] 
     203            mag_index = self.values[mag_index:mag_index + 3 * nmagnetic] 
     204        else: 
     205            spin_state = [] 
     206            mag_index = [] 
     207        nvalues = self.model_info.parameters.nvalues 
     208        nweights = self.call_details.num_weights 
     209        weights = self.values[nvalues:nvalues+2*nweights] 
     210        zero = self.values.dtype.type(0.) 
     211        values = [[scale, zero], pars, spin_state, mag_index, weights] 
     212        # Pad value array to a 32 value boundary 
     213        spacer = (32 - sum(len(v) for v in values)%32)%32 
     214        values.append([zero]*spacer) 
     215        values = np.hstack(values).astype(self.kernels[0].dtype) 
     216        return values 
Note: See TracChangeset for help on using the changeset viewer.