implement indexing for transfer function · python-control/python-control@97f201b

2 files changed

lines changed

Original file line numberDiff line numberDiff line change

@@ -5,8 +5,8 @@

55
66

import unittest

77

import numpy as np

8-

from control.statesp import StateSpace, _convertToStateSpace

9-

from control.xferfcn import TransferFunction, _convertToTransferFunction

8+

from control.statesp import StateSpace, _convertToStateSpace, rss

9+

from control.xferfcn import TransferFunction, _convertToTransferFunction, ss2tf

1010

from control.lti import evalfr

1111

from control.exception import slycot_check

1212

# from control.lti import isdtime

@@ -536,6 +536,25 @@ def testMIMO(self):

536536

np.testing.assert_array_almost_equal(H.num[1][0], H2.num[1][0])

537537

np.testing.assert_array_almost_equal(H.den[1][0], H2.den[1][0])

538538
539+

def testIndexing(self):

540+

tm = ss2tf(rss(5, 3, 3))

541+
542+

# scalar indexing

543+

sys01 = tm[0, 1]

544+

np.testing.assert_array_almost_equal(sys01.num[0][0], tm.num[0][1])

545+

np.testing.assert_array_almost_equal(sys01.den[0][0], tm.den[0][1])

546+
547+

# slice indexing

548+

sys = tm[:2, 1:3]

549+

np.testing.assert_array_almost_equal(sys.num[0][0], tm.num[0][1])

550+

np.testing.assert_array_almost_equal(sys.den[0][0], tm.den[0][1])

551+

np.testing.assert_array_almost_equal(sys.num[0][1], tm.num[0][2])

552+

np.testing.assert_array_almost_equal(sys.den[0][1], tm.den[0][2])

553+

np.testing.assert_array_almost_equal(sys.num[1][0], tm.num[1][1])

554+

np.testing.assert_array_almost_equal(sys.den[1][0], tm.den[1][1])

555+

np.testing.assert_array_almost_equal(sys.num[1][1], tm.num[1][2])

556+

np.testing.assert_array_almost_equal(sys.den[1][1], tm.den[1][2])

557+
539558

def testMatrixMult(self):

540559

"""MIMO transfer functions should be multiplyable by constant

541560

matrices"""

Original file line numberDiff line numberDiff line change

@@ -491,6 +491,46 @@ def __pow__(self, other):

491491

if other < 0:

492492

return (TransferFunction([1], [1]) / self) * (self**(other+1))

493493
494+

def __getitem__(self, key):

495+

key1, key2 = key

496+
497+

# pre-process

498+

if isinstance(key1, int):

499+

key1 = slice(key1, key1 + 1, 1)

500+

if isinstance(key2, int):

501+

key2 = slice(key2, key2 + 1, 1)

502+

# dim1

503+

start1, stop1, step1 = key1.start, key1.stop, key1.step

504+

if step1 is None:

505+

step1 = 1

506+

if start1 is None:

507+

start1 = 0

508+

if stop1 is None:

509+

stop1 = len(self.num)

510+

# dim1

511+

start2, stop2, step2 = key2.start, key2.stop, key2.step

512+

if step2 is None:

513+

step2 = 1

514+

if start2 is None:

515+

start2 = 0

516+

if stop2 is None:

517+

stop2 = len(self.num[0])

518+
519+

num = []

520+

den = []

521+

for i in range(start1, stop1, step1):

522+

num_i = []

523+

den_i = []

524+

for j in range(start2, stop2, step2):

525+

num_i.append(self.num[i][j])

526+

den_i.append(self.den[i][j])

527+

num.append(num_i)

528+

den.append(den_i)

529+

if self.isctime():

530+

return TransferFunction(num, den)

531+

else:

532+

return TransferFunction(num, den, self.dt)

533+
494534

def evalfr(self, omega):

495535

"""Evaluate a transfer function at a single angular frequency.

496536