From 541999c5d85bb9542015f5b2617eaeea52e8a70a Mon Sep 17 00:00:00 2001
From: Sander de Snoo <59472150+sldesnoo-Delft@users.noreply.github.com>
Date: Wed, 4 Oct 2023 14:28:44 +0200
Subject: [PATCH] Raise clearer exception when state threshold is applied to
 time trace data

---
 .../acquisition/measurement_converter.py      | 70 +++++++++----------
 1 file changed, 35 insertions(+), 35 deletions(-)

diff --git a/pulse_lib/acquisition/measurement_converter.py b/pulse_lib/acquisition/measurement_converter.py
index d821ee84..83ce113d 100644
--- a/pulse_lib/acquisition/measurement_converter.py
+++ b/pulse_lib/acquisition/measurement_converter.py
@@ -12,14 +12,14 @@ logger = logging.getLogger(__name__)
 
 @dataclass
 class SetpointsSingle:
-    name : str
-    label : str
-    unit : str
-    shape : Tuple[int] = field(default_factory=tuple)
-    setpoints : Tuple[Tuple[float]] = field(default_factory=tuple)
-    setpoint_names : Tuple[str] = field(default_factory=tuple)
-    setpoint_labels : Tuple[str] = field(default_factory=tuple)
-    setpoint_units : Tuple[str] = field(default_factory=tuple)
+    name: str
+    label: str
+    unit: str
+    shape: Tuple[int] = field(default_factory=tuple)
+    setpoints: Tuple[Tuple[float]] = field(default_factory=tuple)
+    setpoint_names: Tuple[str] = field(default_factory=tuple)
+    setpoint_labels: Tuple[str] = field(default_factory=tuple)
+    setpoint_units: Tuple[str] = field(default_factory=tuple)
 
     def __post_init__(self):
         self.name = self.name.replace(' ', '_')
@@ -58,7 +58,8 @@ class SetpointsMulti:
         spm = setpoints_multi()
         param = MultiParameter(..., **spm.__dict__)
     '''
-    def __init__(self, sps_list:List[SetpointsSingle]):
+
+    def __init__(self, sps_list: List[SetpointsSingle]):
         self.names = tuple(sps.name for sps in sps_list)
         self.labels = tuple(sps.label for sps in sps_list)
         self.units = tuple(sps.unit for sps in sps_list)
@@ -116,8 +117,8 @@ class MeasurementParameter(MultiParameter):
             label = name
 
         # check the shape returned by the derived parameter
-        dummy_data = {name:np.zeros(shape)
-                      for name,shape in zip(self.names,self.shapes)}
+        dummy_data = {name: np.zeros(shape)
+                      for name, shape in zip(self.names, self.shapes)}
         dp_shape = np.shape(func(dummy_data))
         n_dim = len(dp_shape)
         n_rep = self._mc.n_rep
@@ -205,7 +206,7 @@ class MeasurementParameter(MultiParameter):
                 d = data[m_name]
             return np.histogram(d, bins=binedges)[0]/d.shape[0]
 
-        binedges  = np.linspace(range[0], range[1], bins+1)
+        binedges = np.linspace(range[0], range[1], bins+1)
         bincenters = (binedges[1:] + binedges[:-1])/2
         setpoints = (tuple(bincenters),)
         setpoint_names = ('sensor_val',)
@@ -226,8 +227,9 @@ class MeasurementParameter(MultiParameter):
         data = self._mc.get_measurement_data(self._data_selection)
 
         if len(self._derived_params) > 0:
-            data_map = {name:values for name,values in zip(self.names, data)}
-            for name,dp in self._derived_params.items():
+            # TODO use custom dict that raise a more useful exception instead of KeyError.
+            data_map = {name: values for name, values in zip(self.names, data)}
+            for name, dp in self._derived_params.items():
                 dp_data = dp(data_map)
                 data.append(dp_data)
                 data_map[name] = dp_data
@@ -283,12 +285,14 @@ class MeasurementConverter:
             channel = digitizer_channels[channel_name]
             self._raw_is_iq.append(channel.iq_out)
 
-
     def _generate_setpoints(self):
         for m in self._description.measurements:
-            if isinstance(m, measurement_acquisition) and not m.has_threshold:
-                # do not add to result
-                continue
+            if isinstance(m, measurement_acquisition):
+                if not m.has_threshold:
+                    # do not add to result
+                    continue
+                if m.interval is not None and m.aggregate_func is None:
+                    raise Exception(f'State threshold cannot be applied on time trace ({m.name})')
 
             name = f'{m.name}_state'
             label = name
@@ -331,7 +335,7 @@ class MeasurementConverter:
                 if not isinstance(data_offset, Number):
                     data_offset = data_offset[tuple(index)]
                 if m.interval is None:
-                    channel_raw = channel_data[...,data_offset]
+                    channel_raw = channel_data[..., data_offset]
                     thresholded = self._channel_raw.get(m.acquisition_channel+'.thresholded', None)
                     if thresholded is not None:
                         self._hw_thresholded[len(self._raw)] = thresholded[..., data_offset]
@@ -342,11 +346,11 @@ class MeasurementConverter:
                         shape = channel_data.shape[:-1]+(max(n_samples),)
                         channel_raw = np.full(shape, np.nan)
                         n_samples = n_samples[tuple(index)]
-                        channel_raw[...,:n_samples] = channel_data[...,data_offset:data_offset+n_samples]
+                        channel_raw[..., :n_samples] = channel_data[..., data_offset:data_offset+n_samples]
                     else:
-                        channel_raw = channel_data[...,data_offset:data_offset+n_samples]
+                        channel_raw = channel_data[..., data_offset:data_offset+n_samples]
                     if m.aggregate_func:
-                        t_start = 0 # TODO @@@ get from measurement
+                        t_start = 0  # TODO @@@ get from measurement
                         # aggregate time series
                         channel_raw = m.aggregate_func(t_start, channel_raw)
                 self._raw.append(channel_raw)
@@ -359,7 +363,7 @@ class MeasurementConverter:
         last_result = {}
         n_rep = self.n_rep if self.n_rep else 1
         accepted_mask = np.ones(n_rep, dtype=int)
-        for i,m in enumerate(self._description.measurements):
+        for i, m in enumerate(self._description.measurements):
             if isinstance(m, measurement_acquisition):
                 if not m.has_threshold:
                     # do not add to result
@@ -370,8 +374,8 @@ class MeasurementConverter:
                 result = result.astype(int)
                 hw_thresholded = self._hw_thresholded.get(i, None)
                 if hw_thresholded is not None and np.any(result != hw_thresholded):
-                    logger.warning(f'{np.sum(result != hw_thresholded)} differences between hardware and software threshold. '
-                                   f'({np.where(result != hw_thresholded)})')
+                    logger.warning(f'{np.sum(result != hw_thresholded)} differences between hardware and software '
+                                   f'threshold. (indices: {np.where(result != hw_thresholded)})')
             elif isinstance(m, measurement_expression):
                 result = m.expression.evaluate(last_result)
             else:
@@ -393,7 +397,6 @@ class MeasurementConverter:
             self._total_selected = [total_selected]
         self._selectors = selectors
         if total_selected > 0:
-            # Note: for time traces the threshold should not be set.
             self._values = [np.sum(result*accepted_mask)/total_selected for result in values_unfiltered]
         else:
             logger.warning('No shot is accepted')
@@ -407,14 +410,14 @@ class MeasurementConverter:
     def get_setpoints(self, selection):
         sp_list = []
         if selection.raw:
-            for sp,is_iq in zip(self.sp_raw, self._raw_is_iq):
+            for sp, is_iq in zip(self.sp_raw, self._raw_is_iq):
                 if not is_iq:
                     sp_list.append(sp)
                 else:
                     funcs = iq_mode2func(selection.iq_mode)
                     if isinstance(funcs, list):
-                        for postfix,_ in funcs:
-                            unit = 'rad' if postfix == '_phase' else  'mV'
+                        for postfix, _ in funcs:
+                            unit = 'rad' if postfix == '_phase' else 'mV'
                             sp_new = sp.with_attributes(name=sp.name+postfix, unit=unit)
                             sp_list.append(sp_new)
                     else:
@@ -438,13 +441,13 @@ class MeasurementConverter:
     def get_measurement_data(self, selection):
         data = []
         if selection.raw:
-            for raw,is_iq in zip(self._raw, self._raw_is_iq):
+            for raw, is_iq in zip(self._raw, self._raw_is_iq):
                 if not is_iq:
                     data.append(raw)
                 else:
                     funcs = iq_mode2func(selection.iq_mode)
                     if isinstance(funcs, list):
-                        for _,func in funcs:
+                        for _, func in funcs:
                             data.append(func(raw))
                     else:
                         data.append(funcs(raw))
@@ -462,9 +465,6 @@ class MeasurementConverter:
 
     def get_measurements(self, selection):
         result = {}
-        for name,value in zip(self._get_names(selection),
-                              self.get_measurement_data(selection)):
+        for name, value in zip(self._get_names(selection), self.get_measurement_data(selection)):
             result[name] = value
         return result
-
-
-- 
GitLab