ensured, that __call__ can be inherited through multiple levels of hi… · pythonnet/pythonnet@5bb1007

11

using System;

22

using System.Collections.Generic;

33

using System.Diagnostics;

4+

using System.Linq;

45

using System.Reflection;

56

using System.Runtime.InteropServices;

67

@@ -297,44 +298,31 @@ public static IntPtr tp_call(IntPtr ob, IntPtr args, IntPtr kw)

297298298299

if (cb.type != typeof(Delegate))

299300

{

300-

IntPtr dict = Marshal.ReadIntPtr(tp, TypeOffset.tp_dict);

301-

IntPtr methodObjectHandle = Runtime.PyDict_GetItemString(dict, "__call__");

302-

if (methodObjectHandle == IntPtr.Zero || methodObjectHandle == Runtime.PyNone)

303-

{

304-

Exceptions.SetError(Exceptions.TypeError, "object is not callable");

305-

return IntPtr.Zero;

306-

}

307-308-

if (GetManagedObject(methodObjectHandle) is MethodObject methodObject)

309-

{

310-

return methodObject.Invoke(ob, args, kw);

301+

var calls = cb.type.GetMethods(BindingFlags.Public | BindingFlags.Instance)

302+

.Where(m => m.Name == "__call__")

303+

.ToList();

304+

if (calls.Count > 0) {

305+

var callBinder = new MethodBinder();

306+

foreach (MethodInfo call in calls) {

307+

callBinder.AddMethod(call);

308+

}

309+

return callBinder.Invoke(ob, args, kw);

311310

}

312311313-

methodObjectHandle = IntPtr.Zero;

314-315-

foreach (IntPtr pythonBase in GetPythonBases(tp)) {

316-

dict = Marshal.ReadIntPtr(pythonBase, TypeOffset.tp_dict);

312+

using var super = new PyObject(Runtime.SelfIncRef(Runtime.PySuper));

313+

using var self = new PyObject(Runtime.SelfIncRef(ob));

314+

using var none = new PyObject(Runtime.SelfIncRef(Runtime.PyNone));

315+

foreach (IntPtr managedTypeDerivingFromPython in GetTypesWithPythonBasesInHierarchy(tp)) {

316+

using var @base = super.Invoke(new PyObject(managedTypeDerivingFromPython), self);

317+

using var call = @base.GetAttrOrElse("__call__", none);

317318318-

methodObjectHandle = Runtime.PyDict_GetItemString(dict, "__call__");

319-

if (methodObjectHandle != IntPtr.Zero && methodObjectHandle != Runtime.PyNone) break;

320-

}

319+

if (call.Handle == Runtime.PyNone) continue;

321320322-

if (methodObjectHandle == IntPtr.Zero || methodObjectHandle == Runtime.PyNone) {

323-

Exceptions.SetError(Exceptions.TypeError, "object is not callable");

324-

return IntPtr.Zero;

321+

return Runtime.PyObject_Call(call.Handle, args, kw);

325322

}

326323327-

var boundMethod = Runtime.PyMethod_New(methodObjectHandle, ob);

328-

if (boundMethod == IntPtr.Zero) { return IntPtr.Zero; }

329-330-

try

331-

{

332-

return Runtime.PyObject_Call(boundMethod, args, kw);

333-

}

334-

finally

335-

{

336-

Runtime.XDecref(boundMethod);

337-

}

324+

Exceptions.SetError(Exceptions.TypeError, "object is not callable");

325+

return IntPtr.Zero;

338326

}

339327340328

var co = (CLRObject)GetManagedObject(ob);

@@ -377,6 +365,38 @@ internal static IEnumerable<IntPtr> GetPythonBases(IntPtr tp) {

377365

yield return tp;

378366

}

379367368+

internal static IEnumerable<IntPtr> GetTypesWithPythonBasesInHierarchy(IntPtr tp) {

369+

Debug.Assert(IsManagedType(tp));

370+371+

var candidateQueue = new Queue<IntPtr>();

372+

candidateQueue.Enqueue(tp);

373+

while (candidateQueue.Count > 0) {

374+

tp = candidateQueue.Dequeue();

375+

IntPtr bases = Marshal.ReadIntPtr(tp, TypeOffset.tp_bases);

376+

if (bases != IntPtr.Zero) {

377+

long baseCount = Runtime.PyTuple_Size(bases);

378+

bool hasPythonBase = false;

379+

for (long baseIndex = 0; baseIndex < baseCount; baseIndex++) {

380+

IntPtr @base = Runtime.PyTuple_GetItem(bases, baseIndex);

381+

if (IsManagedType(@base)) {

382+

candidateQueue.Enqueue(@base);

383+

} else {

384+

hasPythonBase = true;

385+

}

386+

}

387+388+

if (hasPythonBase) yield return tp;

389+

} else {

390+

tp = Marshal.ReadIntPtr(tp, TypeOffset.tp_base);

391+

if (tp != IntPtr.Zero && IsManagedType(tp))

392+

candidateQueue.Enqueue(tp);

393+

}

394+

}

395+

}

396+397+

/// <summary>

398+

/// Checks if specified type is a CLR type

399+

/// </summary>

380400

internal static bool IsManagedType(IntPtr tp)

381401

{

382402

var flags = Util.ReadCLong(tp, TypeOffset.tp_flags);