diff --git a/pulse_lib/segments/segment_base.py b/pulse_lib/segments/segment_base.py index e4e373f3e288f560b2a9ad8508660736b2943096..6dce2e02e124b757d92ed0263c3cd3623a02bd1c 100644 --- a/pulse_lib/segments/segment_base.py +++ b/pulse_lib/segments/segment_base.py @@ -1,6 +1,7 @@ """ File containing the parent class where all segment objects are derived from. """ +import copy import numpy as np import matplotlib.pyplot as plt @@ -10,8 +11,7 @@ from pulse_lib.segments.data_classes.data_generic import data_container from pulse_lib.segments.utility.looping import loop_obj from pulse_lib.segments.utility.setpoint_mgr import setpoint_mgr from pulse_lib.segments.data_classes.data_generic import map_index - -import copy +from pulse_lib.segments.utility.data_handling_functions import update_dimension class segment_base(): @@ -168,6 +168,8 @@ class segment_base(): item.data = data_item if self._data_hvi_variable is not None: if self._data_hvi_variable is not self.data: + # assert segment HVI variables has right shape. + item._data_hvi_variable.data = update_dimension(self._data_hvi_variable.data, self.shape) item._data_hvi_variable = item._data_hvi_variable[key[0]] else: item._data_hvi_variable = item.data @@ -214,6 +216,8 @@ class segment_base(): Args: loop_obj (loop_obj) : loop object with certain dimension to add. ''' + if not isinstance(loop_obj, float): + raise Exception(f'update_dim failed. Reload pulselib!') return self.data_tmp def add_HVI_marker(self, marker_name, t_off = 0): diff --git a/pulse_lib/tests/looping/test_segment_update_dim.py b/pulse_lib/tests/looping/test_segment_update_dim.py new file mode 100644 index 0000000000000000000000000000000000000000..383a125405c26aa1b172c000a298b2c1ca2dc4af --- /dev/null +++ b/pulse_lib/tests/looping/test_segment_update_dim.py @@ -0,0 +1,35 @@ + +from pulse_lib.tests.configurations.test_configuration import context +import pulse_lib.segments.utility.looping as lp + +def test(): + pulse = context.init_pulselib(n_gates=2) + + n_pulses = lp.array([1,2,4,9], axis=0, name='n_pulses') + + s = pulse.mk_segment() + + context.segment = s + + s.P1.update_dim(n_pulses) + for i,n in enumerate(n_pulses): + p1 = s.P1[i] + for _ in range(int(n)): + p1.add_ramp_ss(0, 100, -80, 80) + p1.reset_time() + + s.P2.add_block(0, 100, 60) + + for i in range(len(n_pulses)): + context.plot_segments([s], index=[i]) + + sequence = pulse.mk_sequence([s]) + context.add_hw_schedule(sequence) + for n in sequence.n_pulses.values: + sequence.n_pulses(n) + context.plot_awgs(sequence, ylim=(-0.100,0.100)) + + return None + +if __name__ == '__main__': + ds = test()