Skip to content
Snippets Groups Projects
Commit 439c74e8 authored by Sander Snoo's avatar Sander Snoo
Browse files

Reduced amount of copying.

parent cdfc746e
No related branches found
No related tags found
No related merge requests found
...@@ -45,6 +45,9 @@ class loop_obj(): ...@@ -45,6 +45,9 @@ class loop_obj():
raise ValueError("Axis must be defined in descending order, e.g. [1,0]") raise ValueError("Axis must be defined in descending order, e.g. [1,0]")
self.axis = axis self.axis = axis
if self.no_setpoints:
return
if names is None: if names is None:
names = labels names = labels
elif labels is None: elif labels is None:
...@@ -77,31 +80,30 @@ class loop_obj(): ...@@ -77,31 +80,30 @@ class loop_obj():
raise ValueError("Provided incorrect units.") raise ValueError("Provided incorrect units.")
self.units = units self.units = units
if not self.no_setpoints: if setvals is None:
if setvals is None: if len(self.data.shape) == 1:
if len(self.data.shape) == 1: self.setvals = (self.data, )
self.setvals = (self.data, )
else:
raise ValueError ('Multidimensional setpoints cannot be inferred from input.')
else: else:
self.setvals = tuple() raise ValueError ('Multidimensional setpoints cannot be inferred from input.')
if isinstance(setvals,list) or isinstance(setvals, np.ndarray): else:
setvals = np.asarray(setvals) self.setvals = tuple()
if isinstance(setvals,list) or isinstance(setvals, np.ndarray):
setvals = np.asarray(setvals)
if self.shape != setvals.shape: if self.shape != setvals.shape:
raise ValueError("setvals should have the same dimensions as the data dimensions.")
setvals = (setvals, )
else:
setvals = list(setvals)
for setval_idx in range(len(setvals)):
setvals[setval_idx] = np.asarray(setvals[setval_idx])
if self.shape[setval_idx] != len(setvals[setval_idx]):
raise ValueError("setvals should have the same dimensions as the data dimensions.") raise ValueError("setvals should have the same dimensions as the data dimensions.")
setvals = (setvals, )
else:
setvals = list(setvals)
for setval_idx in range(len(setvals)):
setvals[setval_idx] = np.asarray(setvals[setval_idx])
if self.shape[setval_idx] != len(setvals[setval_idx]):
raise ValueError("setvals should have the same dimensions as the data dimensions.")
setvals = tuple(setvals) setvals = tuple(setvals)
self.setvals += setvals self.setvals += setvals
self.setvals_set = True self.setvals_set = True
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
...@@ -248,7 +250,8 @@ class loop_obj(): ...@@ -248,7 +250,8 @@ class loop_obj():
cpy.names = copy.copy(self.names) cpy.names = copy.copy(self.names)
cpy.labels = copy.copy(self.labels) cpy.labels = copy.copy(self.labels)
cpy.setvals = copy.copy(self.setvals) cpy.setvals = copy.copy(self.setvals)
cpy.setvals_set = copy.copy(self.setvals_set) cpy.setvals_set = self.setvals_set
cpy.no_setpoints = self.no_setpoints
cpy.units = copy.copy(self.units) cpy.units = copy.copy(self.units)
cpy.axis = copy.copy(self.axis) cpy.axis = copy.copy(self.axis)
cpy.dtype = copy.copy(self.dtype) cpy.dtype = copy.copy(self.dtype)
...@@ -287,6 +290,9 @@ class loop_obj(): ...@@ -287,6 +290,9 @@ class loop_obj():
@staticmethod @staticmethod
def __combine_axis(this, other): def __combine_axis(this, other):
if this.axis == other.axis and this.shape == other.shape:
return this.data, other.data
new_axis = sorted(set(this.axis) | set(other.axis)) new_axis = sorted(set(this.axis) | set(other.axis))
new_axis.reverse() new_axis.reverse()
...@@ -300,10 +306,11 @@ class loop_obj(): ...@@ -300,10 +306,11 @@ class loop_obj():
try: try:
ithis = this.axis.index(axis) ithis = this.axis.index(axis)
sel_this.append(slice(None)) sel_this.append(slice(None))
new_names.append(this.names[ithis]) if not this.no_setpoints:
new_labels.append(this.labels[ithis]) new_names.append(this.names[ithis])
new_units.append(this.units[ithis]) new_labels.append(this.labels[ithis])
new_setvals.append(this.setvals[ithis]) new_units.append(this.units[ithis])
new_setvals.append(this.setvals[ithis])
try: try:
# check equality of shapes # check equality of shapes
iother = other.axis.index(axis) iother = other.axis.index(axis)
...@@ -316,11 +323,11 @@ class loop_obj(): ...@@ -316,11 +323,11 @@ class loop_obj():
# add new axis # add new axis
sel_this.append(np.newaxis) sel_this.append(np.newaxis)
iother = other.axis.index(axis) iother = other.axis.index(axis)
new_names.append(other.names[iother]) if not this.no_setpoints:
new_labels.append(other.labels[iother]) new_names.append(other.names[iother])
new_units.append(other.units[iother]) new_labels.append(other.labels[iother])
new_setvals.append(other.setvals[iother]) new_units.append(other.units[iother])
new_setvals.append(other.setvals[iother])
try: try:
iother = other.axis.index(axis) iother = other.axis.index(axis)
sel_other.append(slice(None)) sel_other.append(slice(None))
...@@ -328,10 +335,11 @@ class loop_obj(): ...@@ -328,10 +335,11 @@ class loop_obj():
sel_other.append(np.newaxis) sel_other.append(np.newaxis)
this.axis = new_axis this.axis = new_axis
this.names = new_names if not this.no_setpoints:
this.labels = new_labels this.names = tuple(new_names)
this.units = new_units this.labels = tuple(new_labels)
this.setvals = new_setvals this.units = tuple(new_units)
this.setvals = tuple(new_setvals)
return this.data[tuple(sel_this)], other.data[tuple(sel_other)] return this.data[tuple(sel_this)], other.data[tuple(sel_other)]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment