From d62d2a81fdec85fa631de0a4478bb826a4638bea Mon Sep 17 00:00:00 2001
From: Sander de Snoo <59472150+sldesnoo-Delft@users.noreply.github.com>
Date: Thu, 14 Dec 2023 14:40:39 +0100
Subject: [PATCH] Added update_end(stop) to segment to extend segment till at
 least `stop`

---
 .../data_classes/data_HVI_variables.py        |  3 +++
 .../segments/data_classes/data_acquisition.py |  4 ++++
 .../segments/data_classes/data_generic.py     |  4 ++++
 .../segments/data_classes/data_markers.py     |  4 ++++
 pulse_lib/segments/data_classes/data_pulse.py | 18 +++++++++--------
 pulse_lib/segments/segment_acquisition.py     | 10 ++++++++++
 pulse_lib/segments/segment_base.py            | 11 ++++++++++
 pulse_lib/segments/segment_container.py       | 20 ++++++++++++++++++-
 pulse_lib/segments/segment_pulse.py           |  4 +++-
 9 files changed, 68 insertions(+), 10 deletions(-)

diff --git a/pulse_lib/segments/data_classes/data_HVI_variables.py b/pulse_lib/segments/data_classes/data_HVI_variables.py
index f73556b6..73e5bc31 100644
--- a/pulse_lib/segments/data_classes/data_HVI_variables.py
+++ b/pulse_lib/segments/data_classes/data_HVI_variables.py
@@ -42,6 +42,9 @@ class marker_HVI_variable(parent_data):
     def wait(self, time):
         raise NotImplementedError()
 
+    def update_end_time(self, end):
+        raise NotImplementedError()
+
     def integrate_waveform(self, sample_rate):
         raise NotImplementedError()
 
diff --git a/pulse_lib/segments/data_classes/data_acquisition.py b/pulse_lib/segments/data_classes/data_acquisition.py
index 07bed3af..6105bdea 100644
--- a/pulse_lib/segments/data_classes/data_acquisition.py
+++ b/pulse_lib/segments/data_classes/data_acquisition.py
@@ -77,6 +77,10 @@ class acquisition_data(parent_data):
         """
         self.end_time += time
 
+    def update_end_time(self, end):
+        if end + self.start_time > self.end_time:
+            self.end_time = end + self.start_time
+
     @property
     def total_time(self):
         '''
diff --git a/pulse_lib/segments/data_classes/data_generic.py b/pulse_lib/segments/data_classes/data_generic.py
index 83a8e17d..72776bd7 100644
--- a/pulse_lib/segments/data_classes/data_generic.py
+++ b/pulse_lib/segments/data_classes/data_generic.py
@@ -57,6 +57,10 @@ class parent_data(ABC):
     def wait(self, time):
         raise NotImplementedError()
 
+    @abstractmethod
+    def update_end_time(self, end):
+        raise NotImplementedError()
+
     @abstractmethod
     def integrate_waveform(self, sample_rate):
         '''
diff --git a/pulse_lib/segments/data_classes/data_markers.py b/pulse_lib/segments/data_classes/data_markers.py
index 6d483a29..e5d51a99 100644
--- a/pulse_lib/segments/data_classes/data_markers.py
+++ b/pulse_lib/segments/data_classes/data_markers.py
@@ -64,6 +64,10 @@ class marker_data(parent_data):
         """
         self.end_time += time
 
+    def update_end_time(self, t):
+        if t + self.start_time > self.end_time:
+            self.end_time = t + self.start_time
+
     @property
     def total_time(self):
         '''
diff --git a/pulse_lib/segments/data_classes/data_pulse.py b/pulse_lib/segments/data_classes/data_pulse.py
index 4decf811..8e1f2b44 100644
--- a/pulse_lib/segments/data_classes/data_pulse.py
+++ b/pulse_lib/segments/data_classes/data_pulse.py
@@ -186,9 +186,9 @@ class pulse_data(parent_data):
             self.pulse_deltas.append(delta)
             self._consolidated = False
         # always update end time
-        self._update_end_time(delta.time)
+        self.update_end_time(delta.time)
 
-    def _update_end_time(self, t):
+    def update_end_time(self, t):
         if t != np.inf and t > self.end_time:
             self.end_time = t
 
@@ -200,21 +200,21 @@ class pulse_data(parent_data):
             MW_data_object (IQ_data_single) : description MW pulse (see pulse_lib.segments.data_classes.data_IQ)
         """
         self.MW_pulse_data.append(MW_data_object)
-        self._update_end_time(MW_data_object.stop)
+        self.update_end_time(MW_data_object.stop)
 
     def add_chirp(self, chirp):
         self.chirp_data.append(chirp)
-        self._update_end_time(chirp.stop)
+        self.update_end_time(chirp.stop)
 
     def add_custom_pulse_data(self, custom_pulse: custom_pulse_element):
         self.custom_pulse_data.append(custom_pulse)
-        self._update_end_time(custom_pulse.stop)
+        self.update_end_time(custom_pulse.stop)
 
     def add_phase_shift(self, phase_shift: PhaseShift):
         if not phase_shift.is_near_zero:
             self._phase_shifts_consolidated = False
             self.phase_shifts.append(phase_shift)
-        self._update_end_time(phase_shift.time)
+        self.update_end_time(phase_shift.time)
 
     @property
     def total_time(self):
@@ -235,7 +235,7 @@ class pulse_data(parent_data):
         if time is None:
             time = self.total_time
         else:
-            self._update_end_time(time)
+            self.update_end_time(time)
 
         self.start_time = time
 
@@ -290,7 +290,7 @@ class pulse_data(parent_data):
 
         self._consolidated = False
         self._phase_shifts_consolidated = False
-        self._update_end_time(time + other.total_time)
+        self.update_end_time(time + other.total_time)
 
     def shift_MW_frequency(self, frequency):
         '''
@@ -503,6 +503,8 @@ class pulse_data(parent_data):
                     # symmetric 2nd order correction used in v1.7+
                     samples[i] += - dt*(t_sample-dt)*delta.ramp/2/2
                     samples2[i] = - dt*(t_sample-dt)*delta.ramp/2/2
+                    # samples[i] += - dt*(t_sample-dt)*delta.ramp*(dt)*0.5
+                    # samples2[i] = - dt*(t_sample-dt)*delta.ramp*(1-dt)*0.5
             else:
                 for i, delta in enumerate(self.pulse_deltas):
                     times[i] = delta.time
diff --git a/pulse_lib/segments/segment_acquisition.py b/pulse_lib/segments/segment_acquisition.py
index fbf538de..700148ad 100644
--- a/pulse_lib/segments/segment_acquisition.py
+++ b/pulse_lib/segments/segment_acquisition.py
@@ -156,6 +156,16 @@ class segment_acquisition():
             self.data_tmp.reset_time(None)
         return self.data_tmp
 
+    @loop_controller
+    def update_end(self, stop):
+        '''
+        Sets the end of the segment to at least stop (relative to current start time).
+        This has an effect similar to add_block(0, stop, 0.0), but works on all
+        Args:
+            stop (float) : minimum end time of segment.
+        '''
+        self.data_tmp.update_end_time(stop)
+        return self.data_tmp
 
     @property
     def setpoints(self):
diff --git a/pulse_lib/segments/segment_base.py b/pulse_lib/segments/segment_base.py
index b3ca96cb..96db13fd 100644
--- a/pulse_lib/segments/segment_base.py
+++ b/pulse_lib/segments/segment_base.py
@@ -84,6 +84,17 @@ class segment_base():
         self.data_tmp.reset_time(time)
         return self.data_tmp
 
+    @loop_controller
+    def update_end(self, stop):
+        '''
+        Sets the end of the segment to at least stop (relative to current start time).
+        This has an effect similar to add_block(0, stop, 0.0), but works on all
+        Args:
+            stop (float) : minimum end time of segment.
+        '''
+        self.data_tmp.update_end_time(stop)
+        return self.data_tmp
+
     @loop_controller
     def wait(self, time, reset_time=False):
         '''
diff --git a/pulse_lib/segments/segment_container.py b/pulse_lib/segments/segment_container.py
index 1acfc5ec..9ebaa60f 100644
--- a/pulse_lib/segments/segment_container.py
+++ b/pulse_lib/segments/segment_container.py
@@ -321,7 +321,6 @@ class segment_container():
             for channel in self.channels.values():
                 channel.reset_time(time)
 
-
     def get_waveform(self, channel, index = [0], sample_rate=1e9, ref_channel_states=None):
         '''
         function to get the raw data of a waveform,
@@ -367,6 +366,21 @@ class segment_container():
         if reset_time:
             self.reset_time()
 
+    def update_end(self, stop, channels=None):
+        '''
+        Sets the end of the segment to at least stop (relative to current start time).
+        This has an effect similar to add_block(0, stop, 0.0), but works on all
+        Args:
+            stop (float, loop_obj) : minimum end time of segment.
+            channels (List[str]): channels to add the wait to. If None add to all channels.
+        '''
+        if channels is None:
+            for channel in self.channels.values():
+                channel.update_end(stop)
+        else:
+            for channel in channels:
+                self[channel].update_end(stop)
+
     def add_block(self, start, stop, channels, amplitudes, reset_time=False):
         '''
         Adds a block to each of the specified channels.
@@ -380,6 +394,8 @@ class segment_container():
         for channel, amplitude in zip(channels, amplitudes):
             self[channel].add_block(start, stop, amplitude)
         if reset_time:
+            if len(channels) == 0:
+                self.update_end(stop)
             self.reset_time()
 
     def add_ramp(self, start, stop, channels, start_amplitudes, stop_amplitudes, keep_amplitude=False, reset_time=False):
@@ -397,6 +413,8 @@ class segment_container():
         for channel, start_amp, stop_amp in zip(channels, start_amplitudes, stop_amplitudes):
             self[channel].add_ramp_ss(start, stop, start_amp, stop_amp, keep_amplitude=keep_amplitude)
         if reset_time:
+            if len(channels) == 0:
+                self.update_end(stop)
             self.reset_time()
 
     def add_HVI_variable(self, marker_name, value):
diff --git a/pulse_lib/segments/segment_pulse.py b/pulse_lib/segments/segment_pulse.py
index ef76ce15..88488376 100644
--- a/pulse_lib/segments/segment_pulse.py
+++ b/pulse_lib/segments/segment_pulse.py
@@ -32,7 +32,7 @@ class segment_pulse(segment_base):
         super().__init__(name, pulse_data(hres=hres), segment_type)
 
     @loop_controller
-    def add_block(self,start,stop, amplitude):
+    def add_block(self, start, stop, amplitude):
         '''
         add a block pulse on top of the existing pulse.
         '''
@@ -95,6 +95,8 @@ class segment_pulse(segment_base):
                                                 step=stop_amplitude))
             self.data_tmp.add_delta(pulse_delta(np.inf,
                                                 step=-stop_amplitude))
+        else:
+            self.data_tmp.update_end_time(stop + self.data_tmp.start_time)
 
         return self.data_tmp
 
-- 
GitLab