Skip to content
Snippets Groups Projects
Commit e2194ff8 authored by Sander de Snoo's avatar Sander de Snoo
Browse files

Improved loop_controller

parent 2a0fb6c8
No related branches found
No related tags found
No related merge requests found
import copy
from functools import wraps
from dataclasses import dataclass
from typing import Union, List
import numpy as np
from pulse_lib.segments.utility.looping import loop_obj
from pulse_lib.segments.data_classes.data_generic import data_container
from pulse_lib.segments.utility.setpoint_mgr import setpoint
from functools import wraps
import numpy as np
import copy
def find_common_dimension(dim_1, dim_2):
'''
......@@ -167,12 +170,21 @@ def _get_new_dim_loop(current_dim, axis, shape):
return tuple(new_dim), axis
def _update_segment_dims(segment, lp, arg_index, rendering=False):
axes = list(lp.axis)
def _update_segment_dims(segment, lp, rendering=False):
data = segment.data if not rendering else segment.pulse_data_all
data_shape = data.shape
for i,a in enumerate(lp.axis):
if (data_shape[a] != lp.shape[i]
or (not lp.no_setpoints and a not in segment._setpoints)):
# update dimes / setpoints
break
else:
# nothing to update
return lp.axis
axes = list(lp.axis)
for i in range(len(lp.axis)-1,-1,-1):
data_shape = data.shape
lp_axis = lp.axis[i]
lp_length = lp.shape[i]
new_shape, axis = _get_new_dim_loop(data_shape, lp_axis, lp_length)
......@@ -194,7 +206,13 @@ def _update_segment_dims(segment, lp, arg_index, rendering=False):
else:
segment._pulse_data_all = data
return {'arg_index':arg_index, 'axes':axes}
return axes
@dataclass
class LoopInfo:
key: Union[int,str]
axes: List[int]
_in_loop = False
def loop_controller(func):
......@@ -225,13 +243,13 @@ def loop_controller(func):
for i,arg in enumerate(args):
if isinstance(arg, loop_obj):
loop_info = _update_segment_dims(obj, arg, i)
loop_info_args.append(loop_info)
axes = _update_segment_dims(obj, arg)
loop_info_args.append(LoopInfo(i, axes))
for key,kwarg in kwargs.items():
if isinstance(kwarg, loop_obj):
loop_info = _update_segment_dims(obj, kwarg, key)
loop_info_kwargs.append(loop_info)
axes = _update_segment_dims(obj, kwarg)
loop_info_kwargs.append(LoopInfo(key, axes))
data = obj.data
......@@ -270,13 +288,13 @@ def loop_controller_post_processing(func):
for i,arg in enumerate(args):
if isinstance(arg, loop_obj):
loop_info = _update_segment_dims(obj, arg, i, rendering=True)
loop_info_args.append(loop_info)
axes = _update_segment_dims(obj, arg, rendering=True)
loop_info_args.append(LoopInfo(i, axes))
for key,kwarg in kwargs.items():
if isinstance(kwarg, loop_obj):
loop_info = _update_segment_dims(obj, kwarg, key, rendering=True)
loop_info_kwargs.append(loop_info)
axes = _update_segment_dims(obj, kwarg, rendering=True)
loop_info_kwargs.append(LoopInfo(key, axes))
data = obj.pulse_data_all
end_times = obj._end_times
......@@ -317,12 +335,12 @@ def loop_over_data_lp(func, obj, data, end_times, args, args_info, kwargs, kwarg
for i in range(shape[0]):
for arg in args_info:
if n_dim-1 in arg['axes']:
index = arg['arg_index']
if n_dim-1 in arg.axes:
index = arg.key
args_cpy[index] = args[index][i]
for kwarg in kwargs_info:
if n_dim-1 in kwarg['axes']:
index = kwarg['arg_index']
if n_dim-1 in kwarg.axes:
index = kwarg.key
kwargs_cpy[index] = kwargs[index][i]
if n_dim == 1:
......
......@@ -51,6 +51,9 @@ class setpoint_mgr():
def __repr__(self):
return self.__str__()
def __contains__(self, axis):
return axis in self._setpoints
def __getitem__(self, axis):
"""
get setpoint data for a certain axis
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment