From 30e192609512f1664ebd073e5b5462fcabe9f07f Mon Sep 17 00:00:00 2001
From: Sander de Snoo <59472150+sldesnoo-Delft@users.noreply.github.com>
Date: Fri, 4 Aug 2023 09:15:00 +0200
Subject: [PATCH] Added aggregate_func to aggregate n samples of 1 measurement
 to a single value with a user defined function.

---
 pulse_lib/acquisition/acquisition_conf.py      | 12 +++++++++++-
 pulse_lib/acquisition/measurement_converter.py |  8 +++++---
 pulse_lib/segments/segment_measurements.py     |  8 +++++++-
 pulse_lib/sequencer.py                         |  8 +++++++-
 4 files changed, 30 insertions(+), 6 deletions(-)

diff --git a/pulse_lib/acquisition/acquisition_conf.py b/pulse_lib/acquisition/acquisition_conf.py
index 06163d9e..b3af5133 100644
--- a/pulse_lib/acquisition/acquisition_conf.py
+++ b/pulse_lib/acquisition/acquisition_conf.py
@@ -1,5 +1,7 @@
 from dataclasses import dataclass
-from typing import Optional, Union, List
+from typing import Optional, Union, List, Callable
+
+import numpy as np
 
 from pulse_lib.segments.utility.looping import loop_obj
 
@@ -15,22 +17,30 @@ class AcquisitionConf:
     measurement time in ns.
     If None it must be set in acquire()
     '''
+
     channels: Optional[List[str]] = None
     '''
     Channels to retrieve data from specified by name.
     If None it is defined by acquire()
     '''
+
     sample_rate: Optional[float] = None
     '''
     Sample rate of data in Hz. When not None, the data should not be averaged,
     but downsampled with specified rate. Useful for time traces and Elzerman readout.
     Downsampling uses block average.
     '''
+
     average_repetitions: bool = False
     '''
     Average acquisition data over the sequence repetitions.
     '''
 
+    aggregate_func: Callable[[np.ndarray], np.ndarray] = None
+    '''
+    Function aggregating data on time axis to new value.
+    '''
+
     # TODO are the options needed?
     # options: Optional[Dict[str,Any]] = None
     # '''
diff --git a/pulse_lib/acquisition/measurement_converter.py b/pulse_lib/acquisition/measurement_converter.py
index cfa7ba5c..a9a0fb85 100644
--- a/pulse_lib/acquisition/measurement_converter.py
+++ b/pulse_lib/acquisition/measurement_converter.py
@@ -272,7 +272,7 @@ class MeasurementConverter:
             sp_raw = SetpointsSingle(name, label, 'mV')
             if n_rep:
                 sp_raw.append(np.arange(n_rep), 'repetition', 'repetition', '')
-            if m.interval is not None:
+            if m.interval is not None and m.aggregate_func is None:
                 n_samples = m.n_samples
                 if not isinstance(n_samples, Number):
                     n_samples = max(n_samples)
@@ -322,7 +322,6 @@ class MeasurementConverter:
 
     def _set_data_raw(self, index):
         self._raw = []
-        self._raw_split = []
         for m in self._description.measurements:
             if isinstance(m, measurement_acquisition):
                 channel_name = m.acquisition_channel
@@ -335,13 +334,16 @@ class MeasurementConverter:
                 else:
                     n_samples = m.n_samples
                     if not isinstance(n_samples, Number):
+                        # NOTE: n_samples is an array (loop_obj)
                         shape = channel_data.shape[:-1]+(max(n_samples),)
                         channel_raw = np.full(shape, np.nan)
                         n_samples = n_samples[tuple(index)]
                         channel_raw[...,:n_samples] = channel_data[...,data_offset:data_offset+n_samples]
                     else:
                         channel_raw = channel_data[...,data_offset:data_offset+n_samples]
-
+                    if m.aggregate_func:
+                        # aggregate time series
+                        channel_raw = m.aggregate_func(channel_raw)
                 self._raw.append(channel_raw)
 
     def _set_states(self):
diff --git a/pulse_lib/segments/segment_measurements.py b/pulse_lib/segments/segment_measurements.py
index 0184ad34..7c459fd3 100644
--- a/pulse_lib/segments/segment_measurements.py
+++ b/pulse_lib/segments/segment_measurements.py
@@ -2,7 +2,9 @@
 Measurement channel implementation.
 """
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Callable
+
+import numpy as np
 
 from .utility.measurement_ref import MeasurementExpressionBase, MeasurementRef
 
@@ -25,6 +27,10 @@ class measurement_acquisition(measurement_base):
     '''  Number of samples when using time traces. Value set by sequencer when downsampling. '''
     data_offset: int = 0
     ''' Offset of data in acquired channel data. '''
+    aggregate_func: Callable[[np.ndarray], np.ndarray] = None
+    '''
+    Function aggregating data on time axis to new value.
+    '''
 
     @property
     def has_threshold(self):
diff --git a/pulse_lib/sequencer.py b/pulse_lib/sequencer.py
index a289b747..d933ab7a 100644
--- a/pulse_lib/sequencer.py
+++ b/pulse_lib/sequencer.py
@@ -337,6 +337,7 @@ class sequencer():
                         sample_rate=None,
                         channels=[],
                         average_repetitions=None,
+                        aggregate_func=None
                         ):
         '''
         Args:
@@ -347,6 +348,8 @@ class sequencer():
                 but sampled with specified rate. Useful for time traces and Elzerman readout.
                 Does not change digitizer DAC rate. Data is down-sampled using block averages.
             average_repetitions (bool): Average data over the sequence repetitions.
+            aggregate_func:
+                Function aggregating data on time axis to new value. Must be used with sample_rate.
         '''
         if self._measurement_converter is not None:
             raise Exception('Acquisition parameters cannot be changed after calling  '
@@ -359,7 +362,9 @@ class sequencer():
         if channels != []:
             conf.channels = channels
         if average_repetitions is not None:
-            conf.average_repetitions = average_repetitions # @@@ implement Keysight
+            conf.average_repetitions = average_repetitions
+        if aggregate_func is not None:
+            conf.aggregate_func = aggregate_func
 
     def _set_num_samples(self):
         default_t_measure = self._acquisition_conf.t_measure
@@ -399,6 +404,7 @@ class sequencer():
                     m.n_samples = self.uploader.get_num_samples(
                             m.acquisition_channel, t_measure, sample_rate)
                     m.interval = round(1e9/sample_rate)
+                m.aggregate_func = self._acquisition_conf.aggregate_func
             else:
                 m.n_samples = 1
 
-- 
GitLab