Skip to content
Snippets Groups Projects
Commit 541999c5 authored by Sander de Snoo's avatar Sander de Snoo
Browse files

Raise clearer exception when state threshold is applied to time trace data

parent 0068fcce
No related branches found
No related tags found
No related merge requests found
......@@ -12,14 +12,14 @@ logger = logging.getLogger(__name__)
@dataclass
class SetpointsSingle:
name : str
label : str
unit : str
shape : Tuple[int] = field(default_factory=tuple)
setpoints : Tuple[Tuple[float]] = field(default_factory=tuple)
setpoint_names : Tuple[str] = field(default_factory=tuple)
setpoint_labels : Tuple[str] = field(default_factory=tuple)
setpoint_units : Tuple[str] = field(default_factory=tuple)
name: str
label: str
unit: str
shape: Tuple[int] = field(default_factory=tuple)
setpoints: Tuple[Tuple[float]] = field(default_factory=tuple)
setpoint_names: Tuple[str] = field(default_factory=tuple)
setpoint_labels: Tuple[str] = field(default_factory=tuple)
setpoint_units: Tuple[str] = field(default_factory=tuple)
def __post_init__(self):
self.name = self.name.replace(' ', '_')
......@@ -58,7 +58,8 @@ class SetpointsMulti:
spm = setpoints_multi()
param = MultiParameter(..., **spm.__dict__)
'''
def __init__(self, sps_list:List[SetpointsSingle]):
def __init__(self, sps_list: List[SetpointsSingle]):
self.names = tuple(sps.name for sps in sps_list)
self.labels = tuple(sps.label for sps in sps_list)
self.units = tuple(sps.unit for sps in sps_list)
......@@ -116,8 +117,8 @@ class MeasurementParameter(MultiParameter):
label = name
# check the shape returned by the derived parameter
dummy_data = {name:np.zeros(shape)
for name,shape in zip(self.names,self.shapes)}
dummy_data = {name: np.zeros(shape)
for name, shape in zip(self.names, self.shapes)}
dp_shape = np.shape(func(dummy_data))
n_dim = len(dp_shape)
n_rep = self._mc.n_rep
......@@ -205,7 +206,7 @@ class MeasurementParameter(MultiParameter):
d = data[m_name]
return np.histogram(d, bins=binedges)[0]/d.shape[0]
binedges = np.linspace(range[0], range[1], bins+1)
binedges = np.linspace(range[0], range[1], bins+1)
bincenters = (binedges[1:] + binedges[:-1])/2
setpoints = (tuple(bincenters),)
setpoint_names = ('sensor_val',)
......@@ -226,8 +227,9 @@ class MeasurementParameter(MultiParameter):
data = self._mc.get_measurement_data(self._data_selection)
if len(self._derived_params) > 0:
data_map = {name:values for name,values in zip(self.names, data)}
for name,dp in self._derived_params.items():
# TODO use custom dict that raise a more useful exception instead of KeyError.
data_map = {name: values for name, values in zip(self.names, data)}
for name, dp in self._derived_params.items():
dp_data = dp(data_map)
data.append(dp_data)
data_map[name] = dp_data
......@@ -283,12 +285,14 @@ class MeasurementConverter:
channel = digitizer_channels[channel_name]
self._raw_is_iq.append(channel.iq_out)
def _generate_setpoints(self):
for m in self._description.measurements:
if isinstance(m, measurement_acquisition) and not m.has_threshold:
# do not add to result
continue
if isinstance(m, measurement_acquisition):
if not m.has_threshold:
# do not add to result
continue
if m.interval is not None and m.aggregate_func is None:
raise Exception(f'State threshold cannot be applied on time trace ({m.name})')
name = f'{m.name}_state'
label = name
......@@ -331,7 +335,7 @@ class MeasurementConverter:
if not isinstance(data_offset, Number):
data_offset = data_offset[tuple(index)]
if m.interval is None:
channel_raw = channel_data[...,data_offset]
channel_raw = channel_data[..., data_offset]
thresholded = self._channel_raw.get(m.acquisition_channel+'.thresholded', None)
if thresholded is not None:
self._hw_thresholded[len(self._raw)] = thresholded[..., data_offset]
......@@ -342,11 +346,11 @@ class MeasurementConverter:
shape = channel_data.shape[:-1]+(max(n_samples),)
channel_raw = np.full(shape, np.nan)
n_samples = n_samples[tuple(index)]
channel_raw[...,:n_samples] = channel_data[...,data_offset:data_offset+n_samples]
channel_raw[..., :n_samples] = channel_data[..., data_offset:data_offset+n_samples]
else:
channel_raw = channel_data[...,data_offset:data_offset+n_samples]
channel_raw = channel_data[..., data_offset:data_offset+n_samples]
if m.aggregate_func:
t_start = 0 # TODO @@@ get from measurement
t_start = 0 # TODO @@@ get from measurement
# aggregate time series
channel_raw = m.aggregate_func(t_start, channel_raw)
self._raw.append(channel_raw)
......@@ -359,7 +363,7 @@ class MeasurementConverter:
last_result = {}
n_rep = self.n_rep if self.n_rep else 1
accepted_mask = np.ones(n_rep, dtype=int)
for i,m in enumerate(self._description.measurements):
for i, m in enumerate(self._description.measurements):
if isinstance(m, measurement_acquisition):
if not m.has_threshold:
# do not add to result
......@@ -370,8 +374,8 @@ class MeasurementConverter:
result = result.astype(int)
hw_thresholded = self._hw_thresholded.get(i, None)
if hw_thresholded is not None and np.any(result != hw_thresholded):
logger.warning(f'{np.sum(result != hw_thresholded)} differences between hardware and software threshold. '
f'({np.where(result != hw_thresholded)})')
logger.warning(f'{np.sum(result != hw_thresholded)} differences between hardware and software '
f'threshold. (indices: {np.where(result != hw_thresholded)})')
elif isinstance(m, measurement_expression):
result = m.expression.evaluate(last_result)
else:
......@@ -393,7 +397,6 @@ class MeasurementConverter:
self._total_selected = [total_selected]
self._selectors = selectors
if total_selected > 0:
# Note: for time traces the threshold should not be set.
self._values = [np.sum(result*accepted_mask)/total_selected for result in values_unfiltered]
else:
logger.warning('No shot is accepted')
......@@ -407,14 +410,14 @@ class MeasurementConverter:
def get_setpoints(self, selection):
sp_list = []
if selection.raw:
for sp,is_iq in zip(self.sp_raw, self._raw_is_iq):
for sp, is_iq in zip(self.sp_raw, self._raw_is_iq):
if not is_iq:
sp_list.append(sp)
else:
funcs = iq_mode2func(selection.iq_mode)
if isinstance(funcs, list):
for postfix,_ in funcs:
unit = 'rad' if postfix == '_phase' else 'mV'
for postfix, _ in funcs:
unit = 'rad' if postfix == '_phase' else 'mV'
sp_new = sp.with_attributes(name=sp.name+postfix, unit=unit)
sp_list.append(sp_new)
else:
......@@ -438,13 +441,13 @@ class MeasurementConverter:
def get_measurement_data(self, selection):
data = []
if selection.raw:
for raw,is_iq in zip(self._raw, self._raw_is_iq):
for raw, is_iq in zip(self._raw, self._raw_is_iq):
if not is_iq:
data.append(raw)
else:
funcs = iq_mode2func(selection.iq_mode)
if isinstance(funcs, list):
for _,func in funcs:
for _, func in funcs:
data.append(func(raw))
else:
data.append(funcs(raw))
......@@ -462,9 +465,6 @@ class MeasurementConverter:
def get_measurements(self, selection):
result = {}
for name,value in zip(self._get_names(selection),
self.get_measurement_data(selection)):
for name, value in zip(self._get_names(selection), self.get_measurement_data(selection)):
result[name] = value
return result
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment