Skip to content
Snippets Groups Projects
Commit 5dba23d9 authored by Sander de Snoo's avatar Sander de Snoo
Browse files

Improved performance of loop_controller and find_common_dimension using shortcut for shape = (1,)

parent 6797a173
No related branches found
No related tags found
No related merge requests found
......@@ -17,6 +17,8 @@ def find_common_dimension(dim_1, dim_2):
Will raise error is dimensions are not compatible
'''
if dim_2 == (1,):
return dim_1
dim_1 = list(dim_1)[::-1]
dim_2 = list(dim_2)[::-1]
dim_comb = []
......@@ -178,38 +180,44 @@ def loop_controller(func):
if isinstance(kwarg, loop_obj):
loop_info_kwargs.append(_get_loop_info(kwarg, key))
orig_data = obj.data
for lp in loop_info_args:
for i in range(len(lp['axis'])-1,-1,-1):
data_shape = obj.data.shape
lp_axis = lp['axis'][i]
lp_length = lp['shape'][i]
new_dim, axis = get_new_dim_loop(data_shape, lp_axis, lp_length)
lp['axis'][i] = axis
obj.data = update_dimension(obj.data, new_dim)
if lp['setpnt'] is not None:
lp['setpnt'][i].axis = axis
obj._setpoints += lp['setpnt'][i]
for lp in loop_info_kwargs:
for i in range(len(lp['axis'])-1,-1,-1):
new_dim, axis = get_new_dim_loop(obj.data.shape, lp['axis'][i], lp['shape'][i])
lp['axis'][i] = axis
obj.data = update_dimension(obj.data, new_dim)
if len(loop_info_args) == 0 and len(loop_info_kwargs) == 0:
data = obj.data
if data.shape != (1,):
loop_over_data(func, data, args, kwargs)
else:
obj.data_tmp = data[0]
data[0] = func(*args, **kwargs)
if lp['setpnt'] is not None:
lp['setpnt'][i].axis = axis
obj._setpoints += lp['setpnt'][i]
else:
if orig_data is not obj.data:
print(f'data {obj.name} change {orig_data.shape} -> {obj.data.shape}')
orig_data = obj.data
for lp in loop_info_args:
for i in range(len(lp['axis'])-1,-1,-1):
data_shape = obj.data.shape
lp_axis = lp['axis'][i]
lp_length = lp['shape'][i]
new_dim, axis = get_new_dim_loop(data_shape, lp_axis, lp_length)
lp['axis'][i] = axis
obj.data = update_dimension(obj.data, new_dim)
if lp['setpnt'] is not None:
lp['setpnt'][i].axis = axis
obj._setpoints += lp['setpnt'][i]
for lp in loop_info_kwargs:
for i in range(len(lp['axis'])-1,-1,-1):
new_dim, axis = get_new_dim_loop(obj.data.shape, lp['axis'][i], lp['shape'][i])
lp['axis'][i] = axis
obj.data = update_dimension(obj.data, new_dim)
if lp['setpnt'] is not None:
lp['setpnt'][i].axis = axis
obj._setpoints += lp['setpnt'][i]
if orig_data is not obj.data:
print(f'data {obj.name} change {orig_data.shape} -> {obj.data.shape}')
obj_data = obj.data
if len(loop_info_args) > 0 or len(loop_info_kwargs) > 0:
loop_over_data_lp(func, obj_data, args, loop_info_args, kwargs, loop_info_kwargs)
else:
loop_over_data(func, obj_data, args, kwargs)
return wrapper
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment