Skip to content
Snippets Groups Projects
Commit 4bac8f11 authored by Sander de Snoo's avatar Sander de Snoo
Browse files

Fix for loop with length 1

parent 7bf05812
No related branches found
No related tags found
No related merge requests found
import copy import copy
from functools import wraps from functools import wraps
from dataclasses import dataclass from dataclasses import dataclass
from typing import Union, List
import numpy as np import numpy as np
from pulse_lib.segments.utility.looping import loop_obj from pulse_lib.segments.utility.looping import loop_obj
from pulse_lib.segments.data_classes.data_generic import data_container from pulse_lib.segments.data_classes.data_generic import data_container
from pulse_lib.segments.utility.setpoint_mgr import setpoint from pulse_lib.segments.utility.setpoint_mgr import setpoint
use_end_time_cache = True use_end_time_cache = True
def find_common_dimension(dim_1, dim_2): def find_common_dimension(dim_1, dim_2):
''' '''
finds the union of two dimensions finds the union of two dimensions
...@@ -174,7 +175,7 @@ def _get_new_dim_loop(current_dim, axis, shape): ...@@ -174,7 +175,7 @@ def _get_new_dim_loop(current_dim, axis, shape):
def _update_segment_dims(segment, lp, rendering=False): def _update_segment_dims(segment, lp, rendering=False):
data = segment.data if not rendering else segment.pulse_data_all data = segment.data if not rendering else segment.pulse_data_all
for i,a in enumerate(lp.axis): for i, a in enumerate(lp.axis):
if (a >= len(data.shape) if (a >= len(data.shape)
or data.shape[a] != lp.shape[i] or data.shape[a] != lp.shape[i]
or (not lp.no_setpoints and a not in segment._setpoints)): or (not lp.no_setpoints and a not in segment._setpoints)):
...@@ -193,7 +194,7 @@ def _update_segment_dims(segment, lp, rendering=False): ...@@ -193,7 +194,7 @@ def _update_segment_dims(segment, lp, rendering=False):
if new_shape != data.shape: if new_shape != data.shape:
if segment.is_slice: if segment.is_slice:
# TODO: Fix this with refactored indexing. # TODO: Fix this with refactored indexing.
raise Exception(f'Cannot resize data in slice (Indexing). ' raise Exception('Cannot resize data in slice (Indexing). '
'All loop axes must be added before indexing segment.') 'All loop axes must be added before indexing segment.')
data = update_dimension(data, new_shape) data = update_dimension(data, new_shape)
if use_end_time_cache: if use_end_time_cache:
...@@ -214,8 +215,9 @@ def _update_segment_dims(segment, lp, rendering=False): ...@@ -214,8 +215,9 @@ def _update_segment_dims(segment, lp, rendering=False):
@dataclass @dataclass
class LoopInfo: class LoopInfo:
key: Union[int,str] key: int | str
axes: List[int] axes: list[int]
_in_loop = False _in_loop = False
def loop_controller(func): def loop_controller(func):
...@@ -241,25 +243,26 @@ def loop_controller(func): ...@@ -241,25 +243,26 @@ def loop_controller(func):
loop_info_args = [] loop_info_args = []
loop_info_kwargs = [] loop_info_kwargs = []
for i,arg in enumerate(args): for i, arg in enumerate(args):
if isinstance(arg, loop_obj): if isinstance(arg, loop_obj):
axes = _update_segment_dims(obj, arg) axes = _update_segment_dims(obj, arg)
loop_info_args.append(LoopInfo(i, axes)) loop_info_args.append(LoopInfo(i, axes))
for key,kwarg in kwargs.items(): for key, kwarg in kwargs.items():
if isinstance(kwarg, loop_obj): if isinstance(kwarg, loop_obj):
axes = _update_segment_dims(obj, kwarg) axes = _update_segment_dims(obj, kwarg)
loop_info_kwargs.append(LoopInfo(key, axes)) loop_info_kwargs.append(LoopInfo(key, axes))
data = obj.data data = obj.data
if data.shape == (1,): if len(loop_info_args) == 0 and len(loop_info_kwargs) == 0:
obj.data_tmp = data[0] if data.shape == (1,):
data[0] = func(obj, *args, **kwargs) obj.data_tmp = data[0]
if use_end_time_cache: data[0] = func(obj, *args, **kwargs)
obj._end_times[0] = data[0].end_time if use_end_time_cache:
elif len(loop_info_args) == 0 and len(loop_info_kwargs) == 0: obj._end_times[0] = data[0].end_time
loop_over_data(func, obj, data, obj._end_times, args, kwargs) else:
loop_over_data(func, obj, data, obj._end_times, args, kwargs)
else: else:
args_cpy = list(args) args_cpy = list(args)
kwargs_cpy = kwargs.copy() kwargs_cpy = kwargs.copy()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment