From cdfc746e1e7f57ac5b45b328416f42ff62cfa3ae Mon Sep 17 00:00:00 2001
From: sldesnoo-Delft <s.l.desnoo@tudelft.nl>
Date: Wed, 30 Aug 2023 17:35:57 +0200
Subject: [PATCH] Removed lazy reset_time. After other improvements and fixing
 the bugs the lazy reset had no significant improvement on the performance.

---
 pulse_lib/segments/segment_acquisition.py     | 55 +++++++----------
 pulse_lib/segments/segment_base.py            | 60 +++++++------------
 pulse_lib/segments/segment_container.py       | 17 +++---
 .../utility/data_handling_functions.py        | 36 ++++++++---
 4 files changed, 78 insertions(+), 90 deletions(-)

diff --git a/pulse_lib/segments/segment_acquisition.py b/pulse_lib/segments/segment_acquisition.py
index 31cbceb2..75f9ea5b 100644
--- a/pulse_lib/segments/segment_acquisition.py
+++ b/pulse_lib/segments/segment_acquisition.py
@@ -7,7 +7,7 @@ import logging
 import numpy as np
 import matplotlib.pyplot as plt
 
-from pulse_lib.segments.utility.data_handling_functions import loop_controller
+from pulse_lib.segments.utility.data_handling_functions import loop_controller, use_end_time_cache
 from pulse_lib.segments.data_classes.data_generic import data_container
 from pulse_lib.segments.data_classes.data_acquisition import acquisition_data, acquisition
 from pulse_lib.segments.utility.looping import loop_obj
@@ -39,7 +39,10 @@ class segment_acquisition():
 
         # store data in numpy looking object for easy operator access.
         self.data = data_container(acquisition_data())
-        self._end_times = np.zeros(1)
+        if use_end_time_cache:
+            self._end_times = np.zeros(1)
+        else:
+            self._end_times = None
 
         # local copy of self that will be used to count up the virtual gates.
         self._pulse_data_all = None
@@ -120,7 +123,6 @@ class segment_acquisition():
             other (segment) : the segment to be appended
             time (double/loop_obj) : add at the given time. if None, append at t_start of the segment)
         '''
-        other._lazy_reset_time()
         if other.shape != (1,):
             other_loopobj = loop_obj()
             other_loopobj.add_data(other.data, axis=list(range(other.data.ndim -1,-1,-1)),
@@ -132,31 +134,13 @@ class segment_acquisition():
 
         return self
 
+    @loop_controller
     def reset_time(self, time=None):
         '''
         resets the time back to zero after a certain point
         Args:
             time (double) : (optional), after time to reset back to 0. Note that this is absolute time and not rescaled time.
         '''
-        if self.is_slice or time is None:
-            self._reset_time(time)
-        else:
-            if self._pending_reset_time is not None:
-                time = np.fmax(time, self._pending_reset_time)
-            self._pending_reset_time = time
-
-    def _lazy_reset_time(self):
-        if self._pending_reset_time is not None:
-            if self.is_slice:
-                msg = 'Pulse-lib error. Mixed use of slicing and reset_time()'
-                logger.error(msg)
-                raise Exception(msg)
-            time = self._pending_reset_time
-            self._pending_reset_time = None
-            self._reset_time(time)
-
-    @loop_controller
-    def _reset_time(self, time=None):
         self.data_tmp.reset_time(time)
         return self.data_tmp
 
@@ -184,7 +168,6 @@ class segment_acquisition():
         Args:
             *key (int/slice object) : key of the element -- just use numpy style accessing (slicing supported)
         '''
-        self._lazy_reset_time()
         data_item = self.data[key[0]]
         if not isinstance(data_item, data_container):
             # If the slice contains only 1 element, then it's not a data_container anymore.
@@ -200,11 +183,13 @@ class segment_acquisition():
         self.data = data_org
 
         item.data = data_item
-        i = key[0]
-        if len(self.shape) == 1:
-            item._end_times = self._end_times[i:i+1]
-        else:
-            item._end_times = self._end_times[i]
+        if use_end_time_cache:
+            i = key[0]
+            # Note: the numpy slice uses the same memory!
+            if len(self.shape) == 1:
+                item._end_times = self._end_times[i:i+1]
+            else:
+                item._end_times = self._end_times[i]
         item.is_slice = True
         return item
 
@@ -275,11 +260,13 @@ class segment_acquisition():
     @property
     def total_time(self):
         if not self.render_mode:
-            # use end time from numpy array instead of individual lookup of data elements.
-            if self._pending_reset_time is not None:
-                return np.fmax(self._pending_reset_time, self._end_times)
-            return self._end_times
-#            return self.data.total_time
+            if use_end_time_cache:
+                # use end time from numpy array instead of individual lookup of data elements.
+                if self._pending_reset_time is not None:
+                    return np.fmax(self._pending_reset_time, self._end_times)
+                return self._end_times
+            else:
+                return self.data.total_time
         else:
             return self.pulse_data_all.total_time
 
@@ -291,7 +278,6 @@ class segment_acquisition():
             return self.pulse_data_all.start_time
 
     def enter_rendering_mode(self):
-        self._lazy_reset_time()
         self.render_mode = True
         # make a pre-render of all the pulse data (e.g. compose channels, do not render in full).
         if self.type == 'render':
@@ -313,7 +299,6 @@ class segment_acquisition():
             render full (bool) : do full render (e.g. also get data form virtual channels). Put True if you want to see the waveshape send to the AWG.
             sample_rate (float): standard 1 Gs/s
         '''
-        self._lazy_reset_time()
         if render_full == True:
             pulse_data_curr_seg = self._get_data_all_at(index)
         else:
diff --git a/pulse_lib/segments/segment_base.py b/pulse_lib/segments/segment_base.py
index ca0d04cf..b3ca96cb 100644
--- a/pulse_lib/segments/segment_base.py
+++ b/pulse_lib/segments/segment_base.py
@@ -7,7 +7,7 @@ import logging
 import numpy as np
 import matplotlib.pyplot as plt
 
-from pulse_lib.segments.utility.data_handling_functions import loop_controller
+from pulse_lib.segments.utility.data_handling_functions import loop_controller, use_end_time_cache
 from pulse_lib.segments.data_classes.data_generic import data_container
 from pulse_lib.segments.utility.looping import loop_obj
 from pulse_lib.segments.utility.setpoint_mgr import setpoint_mgr
@@ -37,8 +37,11 @@ class segment_base():
 
         # store data in numpy looking object for easy operator access.
         self.data = data_container(data_object)
-        # end time for every index. Effectively this is a cache of data[index].end_time
-        self._end_times = np.zeros(1)
+        if use_end_time_cache:
+            # end time for every index. Effectively this is a cache of data[index].end_time
+            self._end_times = np.zeros(1)
+        else:
+            self._end_times = None
 
         # references to other channels (for virtual gates).
         self.reference_channels = []
@@ -56,10 +59,10 @@ class segment_base():
         self.is_slice = False
 
     def _copy(self, cpy):
-        self._lazy_reset_time()
         cpy.type = copy.copy(self.type)
         cpy.data = copy.copy(self.data)
-        cpy._end_times = self._end_times.copy()
+        if use_end_time_cache:
+            cpy._end_times = self._end_times.copy()
 
         # note that the container objecet needs to take care of these. By default it will refer to the old references.
         cpy.reference_channels = copy.copy(self.reference_channels)
@@ -71,31 +74,13 @@ class segment_base():
 
         return cpy
 
+    @loop_controller
     def reset_time(self, time=None):
         '''
         resets the time back to zero after a certain point
         Args:
             time (double) : (optional), after time to reset back to 0. Note that this is absolute time and not rescaled time.
         '''
-        if self.is_slice or time is None:
-            self._reset_time(time)
-        else:
-            if self._pending_reset_time is not None:
-                time = np.fmax(time, self._pending_reset_time)
-            self._pending_reset_time = time
-
-    def _lazy_reset_time(self):
-        if self._pending_reset_time is not None:
-            if self.is_slice:
-                msg = 'Pulse-lib error. Mixed use of slicing and reset_time()'
-                logger.error(msg)
-                raise Exception(msg)
-            time = self._pending_reset_time
-            self._pending_reset_time = None
-            self._reset_time(time)
-
-    @loop_controller
-    def _reset_time(self, time=None):
         self.data_tmp.reset_time(time)
         return self.data_tmp
 
@@ -124,7 +109,6 @@ class segment_base():
         Args:
             *key (int/slice object) : key of the element -- just use numpy style accessing (slicing supported)
         '''
-        self._lazy_reset_time()
         data_item = self.data[key[0]]
         if not isinstance(data_item, data_container):
             # If the slice contains only 1 element, then it's not a data_container anymore.
@@ -140,11 +124,13 @@ class segment_base():
 
         item.data = data_item
         item.is_slice = True
-        i = key[0]
-        if len(self.shape) == 1:
-            item._end_times = self._end_times[i:i+1]
-        else:
-            item._end_times = self._end_times[i]
+        if use_end_time_cache:
+            i = key[0]
+            # Note: the numpy slice uses the same memory!
+            if len(self.shape) == 1:
+                item._end_times = self._end_times[i:i+1]
+            else:
+                item._end_times = self._end_times[i]
         return item
 
     def append(self, other):
@@ -160,7 +146,6 @@ class segment_base():
             other (segment) : the segment to be appended
             time (double/loop_obj) : add at the given time. if None, append at t_start of the segment)
         '''
-        other._lazy_reset_time()
         if other.shape != (1,):
             other_loopobj = loop_obj()
             other_loopobj.add_data(other.data, axis=list(range(other.data.ndim -1,-1,-1)),
@@ -237,11 +222,13 @@ class segment_base():
     @property
     def total_time(self):
         if not self.render_mode:
-            if self._pending_reset_time is not None:
-                return np.fmax(self._pending_reset_time, self._end_times)
-            # use end time from numpy array instead of individual lookup of data elements.
-            return self._end_times
-#            return self.data.total_time
+            if use_end_time_cache:
+                if self._pending_reset_time is not None:
+                    return np.fmax(self._pending_reset_time, self._end_times)
+                # use end time from numpy array instead of individual lookup of data elements.
+                return self._end_times
+            else:
+                return self.data.total_time
         else:
             return self.pulse_data_all.total_time
 
@@ -253,7 +240,6 @@ class segment_base():
             return self.pulse_data_all.start_time
 
     def enter_rendering_mode(self):
-        self._lazy_reset_time()
         self.render_mode = True
         # make a pre-render of all the pulse data (e.g. compose channels, do not render in full).
         if self.type == 'render':
diff --git a/pulse_lib/segments/segment_container.py b/pulse_lib/segments/segment_container.py
index e1535fd5..cfe66fcc 100644
--- a/pulse_lib/segments/segment_container.py
+++ b/pulse_lib/segments/segment_container.py
@@ -11,7 +11,8 @@ from pulse_lib.segments.segment_acquisition import segment_acquisition
 from pulse_lib.segments.segment_measurements import segment_measurements
 
 import pulse_lib.segments.utility.looping as lp
-from pulse_lib.segments.utility.data_handling_functions import find_common_dimension, update_dimension, reduce_arr
+from pulse_lib.segments.utility.data_handling_functions import (
+        find_common_dimension, update_dimension, reduce_arr, use_end_time_cache)
 from pulse_lib.segments.utility.setpoint_mgr import setpoint_mgr, setpoint
 from pulse_lib.segments.data_classes.data_generic import map_index
 
@@ -114,7 +115,8 @@ class segment_container():
                 setattr(new, name,new_chan)
                 new.channels[name] = new_chan
 
-            new._software_markers = self._software_markers[index]
+            # No HVI variables on slices.
+            new._software_markers = None
             new._setpoints = self._setpoints # @@@ -1 setpoint...
             new._shape = self._shape[1:]
             if new._shape == ():
@@ -204,9 +206,6 @@ class segment_container():
             dim = channel.shape
             my_shape = find_common_dimension(my_shape, dim)
 
-        dim = self._software_markers.shape
-        my_shape = find_common_dimension(my_shape, dim)
-
         if self.render_mode:
             self._render_shape = my_shape
         return my_shape
@@ -286,8 +285,6 @@ class segment_container():
         for channel in self.channels.values():
             comb_setpoints += channel.setpoints
 
-        comb_setpoints += self._software_markers.setpoints
-
         return comb_setpoints
 
     def reset_time(self):
@@ -346,10 +343,12 @@ class segment_container():
             raise Exception('extend_dim may not be called in render mode')
         for channel in self.channels.values():
             channel.data = update_dimension(channel.data, shape)
-            channel._end_times = np.zeros(shape) + channel._end_times
+            if use_end_time_cache:
+                channel._end_times = np.zeros(shape) + channel._end_times
 
         self._software_markers.data = update_dimension(self._software_markers.data, shape)
-        self._software_markers._end_times = np.zeros(shape) + self._software_markers._end_times
+        if use_end_time_cache:
+            self._software_markers._end_times = np.zeros(shape) + self._software_markers._end_times
 
     def wait(self, time, channels=None, reset_time=False):
         '''
diff --git a/pulse_lib/segments/utility/data_handling_functions.py b/pulse_lib/segments/utility/data_handling_functions.py
index 27106961..a0b721f3 100644
--- a/pulse_lib/segments/utility/data_handling_functions.py
+++ b/pulse_lib/segments/utility/data_handling_functions.py
@@ -8,6 +8,8 @@ from pulse_lib.segments.utility.looping import loop_obj
 from pulse_lib.segments.data_classes.data_generic import data_container
 from pulse_lib.segments.utility.setpoint_mgr import setpoint
 
+use_end_time_cache = True
+
 def find_common_dimension(dim_1, dim_2):
     '''
     finds the union of two dimensions
@@ -194,7 +196,8 @@ def _update_segment_dims(segment, lp, rendering=False):
                 raise Exception(f'Cannot resize data in slice (Indexing). '
                                 'All loop axes must be added before indexing segment.')
             data = update_dimension(data, new_shape)
-            segment._end_times = np.zeros(new_shape) + segment._end_times
+            if use_end_time_cache:
+                segment._end_times = np.zeros(new_shape) + segment._end_times
         axes[i] = axis
 
         if not lp.no_setpoints and lp.setvals is not None:
@@ -232,9 +235,6 @@ def loop_controller(func):
         if _in_loop:
             raise Exception('NESTED LOOPS')
 
-        # Lazy reset_time
-        obj._lazy_reset_time()
-
         try:
             _in_loop = True
 
@@ -256,7 +256,8 @@ def loop_controller(func):
             if data.shape == (1,):
                 obj.data_tmp = data[0]
                 data[0] = func(obj, *args, **kwargs)
-                obj._end_times[0] = data[0].end_time
+                if use_end_time_cache:
+                    obj._end_times[0] = data[0].end_time
             elif len(loop_info_args) == 0 and len(loop_info_kwargs) == 0:
                 loop_over_data(func, obj, data, obj._end_times, args, kwargs)
             else:
@@ -345,10 +346,15 @@ def loop_over_data_lp(func, obj, data, end_times, args, args_info, kwargs, kwarg
             # we are at the lowest level of the loop.
             obj.data_tmp = data[i]
             data[i] = func(obj, *args_cpy, **kwargs_cpy)
-            end_times[i] = data[i].end_time
+            if use_end_time_cache:
+                end_times[i] = data[i].end_time
         else:
+            if use_end_time_cache:
+                et = end_times[i]
+            else:
+                et = None
             # clean up args, kwargs
-            loop_over_data_lp(func, obj, data[i], end_times[i], args_cpy, args_info, kwargs_cpy, kwargs_info)
+            loop_over_data_lp(func, obj, data[i], et, args_cpy, args_info, kwargs_cpy, kwargs_info)
 
 
 def loop_over_data(func, obj, data, end_times, args, kwargs):
@@ -371,9 +377,14 @@ def loop_over_data(func, obj, data, end_times, args, kwargs):
             # we are at the lowest level of the loop.
             obj.data_tmp = data[i]
             data[i] = func(obj, *args, **kwargs)
-            end_times[i] = data[i].end_time
+            if use_end_time_cache:
+                end_times[i] = data[i].end_time
         else:
-            loop_over_data(func, obj, data[i], end_times[i], args, kwargs)
+            if use_end_time_cache:
+                et = end_times[i]
+            else:
+                et = None
+            loop_over_data(func, obj, data[i], et, args, kwargs)
 
 
 def reduce_arr(arr):
@@ -389,6 +400,13 @@ def reduce_arr(arr):
     """
     shape = arr.shape
     ndim = len(shape)
+    if ndim == 1:
+        mn = np.min(arr)
+        mx = np.max(arr)
+        if mn == mx:
+            return mn, []
+        else:
+            return arr, [0]
     data_axis = []
     slice_array = ()
     for i in range(ndim):
-- 
GitLab