From 541999c5d85bb9542015f5b2617eaeea52e8a70a Mon Sep 17 00:00:00 2001 From: Sander de Snoo <59472150+sldesnoo-Delft@users.noreply.github.com> Date: Wed, 4 Oct 2023 14:28:44 +0200 Subject: [PATCH] Raise clearer exception when state threshold is applied to time trace data --- .../acquisition/measurement_converter.py | 70 +++++++++---------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/pulse_lib/acquisition/measurement_converter.py b/pulse_lib/acquisition/measurement_converter.py index d821ee84..83ce113d 100644 --- a/pulse_lib/acquisition/measurement_converter.py +++ b/pulse_lib/acquisition/measurement_converter.py @@ -12,14 +12,14 @@ logger = logging.getLogger(__name__) @dataclass class SetpointsSingle: - name : str - label : str - unit : str - shape : Tuple[int] = field(default_factory=tuple) - setpoints : Tuple[Tuple[float]] = field(default_factory=tuple) - setpoint_names : Tuple[str] = field(default_factory=tuple) - setpoint_labels : Tuple[str] = field(default_factory=tuple) - setpoint_units : Tuple[str] = field(default_factory=tuple) + name: str + label: str + unit: str + shape: Tuple[int] = field(default_factory=tuple) + setpoints: Tuple[Tuple[float]] = field(default_factory=tuple) + setpoint_names: Tuple[str] = field(default_factory=tuple) + setpoint_labels: Tuple[str] = field(default_factory=tuple) + setpoint_units: Tuple[str] = field(default_factory=tuple) def __post_init__(self): self.name = self.name.replace(' ', '_') @@ -58,7 +58,8 @@ class SetpointsMulti: spm = setpoints_multi() param = MultiParameter(..., **spm.__dict__) ''' - def __init__(self, sps_list:List[SetpointsSingle]): + + def __init__(self, sps_list: List[SetpointsSingle]): self.names = tuple(sps.name for sps in sps_list) self.labels = tuple(sps.label for sps in sps_list) self.units = tuple(sps.unit for sps in sps_list) @@ -116,8 +117,8 @@ class MeasurementParameter(MultiParameter): label = name # check the shape returned by the derived parameter - dummy_data = {name:np.zeros(shape) - for name,shape in zip(self.names,self.shapes)} + dummy_data = {name: np.zeros(shape) + for name, shape in zip(self.names, self.shapes)} dp_shape = np.shape(func(dummy_data)) n_dim = len(dp_shape) n_rep = self._mc.n_rep @@ -205,7 +206,7 @@ class MeasurementParameter(MultiParameter): d = data[m_name] return np.histogram(d, bins=binedges)[0]/d.shape[0] - binedges = np.linspace(range[0], range[1], bins+1) + binedges = np.linspace(range[0], range[1], bins+1) bincenters = (binedges[1:] + binedges[:-1])/2 setpoints = (tuple(bincenters),) setpoint_names = ('sensor_val',) @@ -226,8 +227,9 @@ class MeasurementParameter(MultiParameter): data = self._mc.get_measurement_data(self._data_selection) if len(self._derived_params) > 0: - data_map = {name:values for name,values in zip(self.names, data)} - for name,dp in self._derived_params.items(): + # TODO use custom dict that raise a more useful exception instead of KeyError. + data_map = {name: values for name, values in zip(self.names, data)} + for name, dp in self._derived_params.items(): dp_data = dp(data_map) data.append(dp_data) data_map[name] = dp_data @@ -283,12 +285,14 @@ class MeasurementConverter: channel = digitizer_channels[channel_name] self._raw_is_iq.append(channel.iq_out) - def _generate_setpoints(self): for m in self._description.measurements: - if isinstance(m, measurement_acquisition) and not m.has_threshold: - # do not add to result - continue + if isinstance(m, measurement_acquisition): + if not m.has_threshold: + # do not add to result + continue + if m.interval is not None and m.aggregate_func is None: + raise Exception(f'State threshold cannot be applied on time trace ({m.name})') name = f'{m.name}_state' label = name @@ -331,7 +335,7 @@ class MeasurementConverter: if not isinstance(data_offset, Number): data_offset = data_offset[tuple(index)] if m.interval is None: - channel_raw = channel_data[...,data_offset] + channel_raw = channel_data[..., data_offset] thresholded = self._channel_raw.get(m.acquisition_channel+'.thresholded', None) if thresholded is not None: self._hw_thresholded[len(self._raw)] = thresholded[..., data_offset] @@ -342,11 +346,11 @@ class MeasurementConverter: shape = channel_data.shape[:-1]+(max(n_samples),) channel_raw = np.full(shape, np.nan) n_samples = n_samples[tuple(index)] - channel_raw[...,:n_samples] = channel_data[...,data_offset:data_offset+n_samples] + channel_raw[..., :n_samples] = channel_data[..., data_offset:data_offset+n_samples] else: - channel_raw = channel_data[...,data_offset:data_offset+n_samples] + channel_raw = channel_data[..., data_offset:data_offset+n_samples] if m.aggregate_func: - t_start = 0 # TODO @@@ get from measurement + t_start = 0 # TODO @@@ get from measurement # aggregate time series channel_raw = m.aggregate_func(t_start, channel_raw) self._raw.append(channel_raw) @@ -359,7 +363,7 @@ class MeasurementConverter: last_result = {} n_rep = self.n_rep if self.n_rep else 1 accepted_mask = np.ones(n_rep, dtype=int) - for i,m in enumerate(self._description.measurements): + for i, m in enumerate(self._description.measurements): if isinstance(m, measurement_acquisition): if not m.has_threshold: # do not add to result @@ -370,8 +374,8 @@ class MeasurementConverter: result = result.astype(int) hw_thresholded = self._hw_thresholded.get(i, None) if hw_thresholded is not None and np.any(result != hw_thresholded): - logger.warning(f'{np.sum(result != hw_thresholded)} differences between hardware and software threshold. ' - f'({np.where(result != hw_thresholded)})') + logger.warning(f'{np.sum(result != hw_thresholded)} differences between hardware and software ' + f'threshold. (indices: {np.where(result != hw_thresholded)})') elif isinstance(m, measurement_expression): result = m.expression.evaluate(last_result) else: @@ -393,7 +397,6 @@ class MeasurementConverter: self._total_selected = [total_selected] self._selectors = selectors if total_selected > 0: - # Note: for time traces the threshold should not be set. self._values = [np.sum(result*accepted_mask)/total_selected for result in values_unfiltered] else: logger.warning('No shot is accepted') @@ -407,14 +410,14 @@ class MeasurementConverter: def get_setpoints(self, selection): sp_list = [] if selection.raw: - for sp,is_iq in zip(self.sp_raw, self._raw_is_iq): + for sp, is_iq in zip(self.sp_raw, self._raw_is_iq): if not is_iq: sp_list.append(sp) else: funcs = iq_mode2func(selection.iq_mode) if isinstance(funcs, list): - for postfix,_ in funcs: - unit = 'rad' if postfix == '_phase' else 'mV' + for postfix, _ in funcs: + unit = 'rad' if postfix == '_phase' else 'mV' sp_new = sp.with_attributes(name=sp.name+postfix, unit=unit) sp_list.append(sp_new) else: @@ -438,13 +441,13 @@ class MeasurementConverter: def get_measurement_data(self, selection): data = [] if selection.raw: - for raw,is_iq in zip(self._raw, self._raw_is_iq): + for raw, is_iq in zip(self._raw, self._raw_is_iq): if not is_iq: data.append(raw) else: funcs = iq_mode2func(selection.iq_mode) if isinstance(funcs, list): - for _,func in funcs: + for _, func in funcs: data.append(func(raw)) else: data.append(funcs(raw)) @@ -462,9 +465,6 @@ class MeasurementConverter: def get_measurements(self, selection): result = {} - for name,value in zip(self._get_names(selection), - self.get_measurement_data(selection)): + for name, value in zip(self._get_names(selection), self.get_measurement_data(selection)): result[name] = value return result - - -- GitLab