diff --git a/pulse_lib/segments/data_classes/data_pulse.py b/pulse_lib/segments/data_classes/data_pulse.py index 39ffdcb0680a6a9cd80d71195d7c0f9cc7c0b01d..50ed0cb0b453b809e761df0e8819e835d3e35cd1 100644 --- a/pulse_lib/segments/data_classes/data_pulse.py +++ b/pulse_lib/segments/data_classes/data_pulse.py @@ -10,7 +10,7 @@ from typing import Any, Dict, Callable, List from pulse_lib.segments.utility.rounding import iround from pulse_lib.segments.data_classes.data_generic import parent_data -from pulse_lib.segments.data_classes.data_IQ import envelope_generator, IQ_data_single +from pulse_lib.segments.data_classes.data_IQ import envelope_generator total_pulse_deltas = 0 @@ -123,6 +123,23 @@ class PhaseShift: def stop(self): return self.time + def __add__(self, other): + if isinstance(other, PhaseShift): + if other.channel_name != self.channel_name: + raise Exception(f'Segment corruption: {other.channel_name} != {self.channel_name}') + return PhaseShift( + self.time, + self.phase_shift + other.phase_shift, + self.channel_name) + else: + raise Exception(f'Cannot add PhaseShift to {type(other)}') + + @property + def is_near_zero(self): + # near zero if |shift| < 2*pi/2**31 + eps = 2*np.pi/2**32 + return -eps < self.phase_shift < eps + # Changed [v1.6.0] @dataclass class OffsetRamp: @@ -155,6 +172,7 @@ class pulse_data(parent_data): self._end_time = 0 self._consolidated = False self._preprocessed = False + self._phase_shifts_consolidated = False def add_delta(self, delta): if not delta.is_near_zero: @@ -182,6 +200,7 @@ class pulse_data(parent_data): self._update_end_time(custom_pulse.stop) def add_phase_shift(self, phase_shift:PhaseShift): + self._phase_shifts_consolidated = False self.phase_shifts.append(phase_shift) self._update_end_time(phase_shift.time) @@ -243,6 +262,7 @@ class pulse_data(parent_data): self.phase_shifts += other_phase_shifts self._consolidated = False + self._phase_shifts_consolidated = False self._update_end_time(time + other.total_time) def repeat(self, n): @@ -281,6 +301,7 @@ class pulse_data(parent_data): self.phase_shifts = new_phase_shifts self._consolidated = False + self._phase_shifts_consolidated = False self._end_time = (n+1) * time def shift_MW_frequency(self, frequency): @@ -321,6 +342,7 @@ class pulse_data(parent_data): my_copy.start_time = copy.copy(self.start_time) my_copy._end_time = self._end_time my_copy._consolidated = self._consolidated + my_copy._phase_shifts_consolidated = self._phase_shifts_consolidated return my_copy @@ -365,6 +387,7 @@ class pulse_data(parent_data): self.phase_shifts += other.phase_shifts self.custom_pulse_data += other.custom_pulse_data self._end_time = max(self._end_time, other._end_time) + self._phase_shifts_consolidated = False elif isinstance(other, Number): self.pulse_deltas.insert(0, pulse_delta(0, other, 0)) @@ -402,7 +425,8 @@ class pulse_data(parent_data): new_data.phase_shifts = copy.copy(self.phase_shifts) new_data._end_time = self._end_time new_data.start_time = self.start_time - new_data._consolidated = True + new_data._consolidated = self._consolidated + new_data._phase_shifts_consolidated = self._phase_shifts_consolidated else: raise TypeError(f'Cannot multiply pulse_data with {type(other)}') @@ -494,9 +518,32 @@ class pulse_data(parent_data): return integrated_value + def _consolidate_phase_shifts(self): + if self._phase_shifts_consolidated: + return + + if len(self.phase_shifts) > 1: + self.phase_shifts.sort(key=lambda p:p.time) + new_shifts = [] + last = self.phase_shifts[0] + for phase_shift in self.phase_shifts[1:]: + if phase_shift.time == last.time: + last = last + phase_shift + else: + if not last.is_near_zero: + new_shifts.append(last) + last = phase_shift + if not last.is_near_zero: + new_shifts.append(last) + + self.pulse_deltas = new_shifts + + self._phase_shifts_consolidated = True + def get_data_elements(self): elements = [] self._pre_process() + self._consolidate_phase_shifts() for time, duration, v_start, v_stop in zip(self._times, self._intervals, self._amplitudes, self._amplitudes_end): elements.append(OffsetRamp(time, time+duration, v_start, v_stop))