update coeff handling to allow multi-variable basis · python-control/python-control@47262f5
@@ -268,17 +268,16 @@ def _cost_function(self, coeffs):
268268start_time = time.process_time()
269269logging.info("_cost_function called at: %g", start_time)
270270271-# Retrieve the initial state and reshape the input vector
271+# Retrieve the saved initial state
272272x = self.x
273-coeffs = coeffs.reshape((self.system.ninputs, -1))
274273275-# Compute time points (if basis present)
274+# Compute inputs
276275if self.basis:
277276if self.log:
278277logging.debug("coefficients = " + str(coeffs))
279278inputs = self._coeffs_to_inputs(coeffs)
280279else:
281-inputs = coeffs
280+inputs = coeffs.reshape((self.system.ninputs, -1))
282281283282# See if we already have a simulation for this condition
284283if np.array_equal(coeffs, self.last_coeffs) and \
@@ -391,15 +390,14 @@ def _constraint_function(self, coeffs):
391390start_time = time.process_time()
392391logging.info("_constraint_function called at: %g", start_time)
393392394-# Retrieve the initial state and reshape the input vector
393+# Retrieve the initial state
395394x = self.x
396-coeffs = coeffs.reshape((self.system.ninputs, -1))
397395398-# Compute time points (if basis present)
396+# Compute input at time points
399397if self.basis:
400398inputs = self._coeffs_to_inputs(coeffs)
401399else:
402-inputs = coeffs
400+inputs = coeffs.reshape((self.system.ninputs, -1))
403401404402# See if we already have a simulation for this condition
405403if np.array_equal(coeffs, self.last_coeffs) \
@@ -473,15 +471,14 @@ def _eqconst_function(self, coeffs):
473471start_time = time.process_time()
474472logging.info("_eqconst_function called at: %g", start_time)
475473476-# Retrieve the initial state and reshape the input vector
474+# Retrieve the initial state
477475x = self.x
478-coeffs = coeffs.reshape((self.system.ninputs, -1))
479476480-# Compute time points (if basis present)
477+# Compute input at time points
481478if self.basis:
482479inputs = self._coeffs_to_inputs(coeffs)
483480else:
484-inputs = coeffs
481+inputs = coeffs.reshape((self.system.ninputs, -1))
485482486483# See if we already have a simulation for this condition
487484if np.array_equal(coeffs, self.last_coeffs) and \
@@ -609,34 +606,36 @@ def _inputs_to_coeffs(self, inputs):
609606return inputs
610607611608# Solve least squares problems (M x = b) for coeffs on each input
612-coeffs = np.zeros((self.system.ninputs, self.basis.N))
609+coeffs = []
613610for i in range(self.system.ninputs):
614611# Set up the matrices to get inputs
615-M = np.zeros((self.timepts.size, self.basis.N))
612+M = np.zeros((self.timepts.size, self.basis.var_ncoefs(i)))
616613b = np.zeros(self.timepts.size)
617614618615# Evaluate at each time point and for each basis function
619616# TODO: vectorize
620617for j, t in enumerate(self.timepts):
621-for k in range(self.basis.N):
618+for k in range(self.basis.var_ncoefs(i)):
622619M[j, k] = self.basis(k, t)
623- b[j] = inputs[i, j]
620+b[j] = inputs[i, j]
624621625622# Solve a least squares problem for the coefficients
626623alpha, residuals, rank, s = np.linalg.lstsq(M, b, rcond=None)
627-coeffs[i, :] = alpha
624+coeffs.append(alpha)
628625629-return coeffs
626+return np.hstack(coeffs)
630627631628# Utility function to convert coefficient vector to input vector
632629def _coeffs_to_inputs(self, coeffs):
633630# TODO: vectorize
634631inputs = np.zeros((self.system.ninputs, self.timepts.size))
635-for i, t in enumerate(self.timepts):
636-for k in range(self.basis.N):
637-phi_k = self.basis(k, t)
638-for inp in range(self.system.ninputs):
639-inputs[inp, i] += coeffs[inp, k] * phi_k
632+offset = 0
633+for i in range(self.system.ninputs):
634+length = self.basis.var_ncoefs(i)
635+for j, t in enumerate(self.timepts):
636+for k in range(length):
637+inputs[i, j] += coeffs[offset + k] * self.basis(k, t)
638+offset += length
640639return inputs
641640642641#
@@ -680,7 +679,7 @@ def _print_statistics(self, reset=True):
680679681680# Compute the optimal trajectory from the current state
682681def compute_trajectory(
683-self, x, squeeze=None, transpose=None, return_states=None,
682+self, x, squeeze=None, transpose=None, return_states=True,
684683initial_guess=None, print_summary=True, **kwargs):
685684"""Compute the optimal input at state x
686685@@ -689,8 +688,7 @@ def compute_trajectory(
689688 x : array-like or number, optional
690689 Initial state for the system.
691690 return_states : bool, optional
692- If True, return the values of the state at each time (default =
693- False).
691+ If True (default), return the values of the state at each time.
694692 squeeze : bool, optional
695693 If True and if the system has a single output, return the system
696694 output as a 1D array rather than a 2D array. If False, return the
@@ -837,7 +835,7 @@ class OptimalControlResult(sp.optimize.OptimizeResult):
837835838836 """
839837def __init__(
840-self, ocp, res, return_states=False, print_summary=False,
838+self, ocp, res, return_states=True, print_summary=False,
841839transpose=None, squeeze=None):
842840"""Create a OptimalControlResult object"""
843841@@ -848,14 +846,11 @@ def __init__(
848846# Remember the optimal control problem that we solved
849847self.problem = ocp
850848851-# Reshape and process the input vector
852-coeffs = res.x.reshape((ocp.system.ninputs, -1))
853-854-# Compute time points (if basis present)
849+# Compute input at time points
855850if ocp.basis:
856-inputs = ocp._coeffs_to_inputs(coeffs)
851+inputs = ocp._coeffs_to_inputs(res.x)
857852else:
858-inputs = coeffs
853+inputs = res.x.reshape((ocp.system.ninputs, -1))
859854860855# See if we got an answer
861856if not res.success:
@@ -894,7 +889,7 @@ def __init__(
894889def solve_ocp(
895890sys, horizon, X0, cost, trajectory_constraints=None, terminal_cost=None,
896891terminal_constraints=[], initial_guess=None, basis=None, squeeze=None,
897-transpose=None, return_states=False, log=False, **kwargs):
892+transpose=None, return_states=True, log=False, **kwargs):
898893899894"""Compute the solution to an optimal control problem
900895@@ -949,7 +944,7 @@ def solve_ocp(
949944 If `True`, turn on logging messages (using Python logging module).
950945951946 return_states : bool, optional
952- If True, return the values of the state at each time (default = False).
947+ If True, return the values of the state at each time (default = True).
953948954949 squeeze : bool, optional
955950 If True and if the system has a single output, return the system