FIX: Resource cleanup on Python shutdown to avoid segfaults by bewithgaurav · Pull Request #255 · microsoft/mssql-python

Expand Up @@ -475,3 +475,212 @@ def test_mixed_cursor_cleanup_scenarios(conn_str, tmp_path): assert "All tests passed" in result.stdout # Should not have error logs assert "Exception during cursor cleanup" not in result.stderr

def test_sql_syntax_error_no_segfault_on_shutdown(conn_str): """Test that SQL syntax errors don't cause segfault during Python shutdown""" # This test reproduces the exact scenario that was causing segfaults escaped_conn_str = conn_str.replace('\\', '\\\\').replace('"', '\\"') code = f""" from mssql_python import connect
# Create connection conn = connect("{escaped_conn_str}") cursor = conn.cursor()
# Execute invalid SQL that causes syntax error - this was causing segfault cursor.execute("syntax error")
# Don't explicitly close cursor/connection - let Python shutdown handle cleanup print("Script completed, shutting down...") # This would NOT print anyways # Segfault would happen here during Python shutdown """
# Run in subprocess to catch segfaults result = subprocess.run( [sys.executable, "-c", code], capture_output=True, text=True )
# Should not segfault (exit code 139 on Unix, 134 on macOS) assert result.returncode == 1, f"Expected exit code 1 due to syntax error, but got {result.returncode}. STDERR: {result.stderr}"
def test_multiple_sql_syntax_errors_no_segfault(conn_str): """Test multiple SQL syntax errors don't cause segfault during cleanup""" escaped_conn_str = conn_str.replace('\\', '\\\\').replace('"', '\\"') code = f""" from mssql_python import connect
conn = connect("{escaped_conn_str}")
# Multiple cursors with syntax errors cursors = [] for i in range(3): cursor = conn.cursor() cursors.append(cursor) cursor.execute(f"invalid sql syntax {{i}}")
# Mix of syntax errors and valid queries cursor_valid = conn.cursor() cursor_valid.execute("SELECT 1") cursor_valid.fetchall() cursors.append(cursor_valid)
# Don't close anything - test Python shutdown cleanup print("Multiple syntax errors handled, shutting down...") """
result = subprocess.run( [sys.executable, "-c", code], capture_output=True, text=True )
assert result.returncode == 1, f"Expected exit code 1 due to syntax errors, but got {result.returncode}. STDERR: {result.stderr}"

def test_connection_close_during_active_query_no_segfault(conn_str): """Test closing connection while cursor has pending results doesn't cause segfault""" escaped_conn_str = conn_str.replace('\\', '\\\\').replace('"', '\\"') code = f""" from mssql_python import connect
# Create connection and cursor conn = connect("{escaped_conn_str}") cursor = conn.cursor()
# Execute query but don't fetch results - leave them pending cursor.execute("SELECT COUNT(*) FROM sys.objects")
# Close connection while results are still pending # This tests handle cleanup when STMT has pending results but DBC is freed conn.close()
print("Connection closed with pending cursor results") # Cursor destructor will run during normal cleanup, not shutdown """
result = subprocess.run( [sys.executable, "-c", code], capture_output=True, text=True )
# Should not segfault - should exit cleanly assert result.returncode == 0, f"Expected clean exit, but got exit code {result.returncode}. STDERR: {result.stderr}" assert "Connection closed with pending cursor results" in result.stdout

def test_concurrent_cursor_operations_no_segfault(conn_str): """Test concurrent cursor operations don't cause segfaults or race conditions""" escaped_conn_str = conn_str.replace('\\', '\\\\').replace('"', '\\"') code = f""" import threading from mssql_python import connect
conn = connect("{escaped_conn_str}") results = [] exceptions = []
def worker(thread_id): try: for i in range(15): cursor = conn.cursor() cursor.execute(f"SELECT {{thread_id * 100 + i}} as value") result = cursor.fetchone() results.append(result[0]) # Don't explicitly close cursor - test concurrent destructors except Exception as e: exceptions.append(f"Thread {{thread_id}}: {{e}}")
# Create multiple threads doing concurrent cursor operations threads = [] for i in range(4): t = threading.Thread(target=worker, args=(i,)) threads.append(t) t.start()
for t in threads: t.join()
print(f"Completed: {{len(results)}} results, {{len(exceptions)}} exceptions")
# Report any exceptions for debugging for exc in exceptions: print(f"Exception: {{exc}}")
print("Concurrent operations completed") """
result = subprocess.run( [sys.executable, "-c", code], capture_output=True, text=True )
# Should not segfault assert result.returncode == 0, f"Expected clean exit, but got exit code {result.returncode}. STDERR: {result.stderr}" assert "Concurrent operations completed" in result.stdout
# Check that most operations completed successfully # Allow for some exceptions due to threading, but shouldn't be many output_lines = result.stdout.split('\n') completed_line = [line for line in output_lines if 'Completed:' in line] if completed_line: # Extract numbers from "Completed: X results, Y exceptions" import re match = re.search(r'Completed: (\d+) results, (\d+) exceptions', completed_line[0]) if match: results_count = int(match.group(1)) exceptions_count = int(match.group(2)) # Should have completed most operations (allow some threading issues) assert results_count >= 50, f"Too few successful operations: {results_count}" assert exceptions_count <= 10, f"Too many exceptions: {exceptions_count}"

def test_aggressive_threading_abrupt_exit_no_segfault(conn_str): """Test abrupt exit with active threads and pending queries doesn't cause segfault""" escaped_conn_str = conn_str.replace('\\', '\\\\').replace('"', '\\"') code = f""" import threading import sys import time from mssql_python import connect
conn = connect("{escaped_conn_str}")
def aggressive_worker(thread_id): '''Worker that creates cursors with pending results and doesn't clean up''' for i in range(8): cursor = conn.cursor() # Execute query but don't fetch - leave results pending cursor.execute(f"SELECT COUNT(*) FROM sys.objects WHERE object_id > {{thread_id * 1000 + i}}")
# Create another cursor immediately without cleaning up the first cursor2 = conn.cursor() cursor2.execute(f"SELECT TOP 3 * FROM sys.objects WHERE object_id > {{thread_id * 1000 + i}}")
# Don't fetch results, don't close cursors - maximum chaos time.sleep(0.005) # Let other threads interleave
# Start multiple daemon threads for i in range(3): t = threading.Thread(target=aggressive_worker, args=(i,), daemon=True) t.start()
# Let them run briefly then exit abruptly time.sleep(0.3) print("Exiting abruptly with active threads and pending queries") sys.exit(0) # Abrupt exit without joining threads """
result = subprocess.run( [sys.executable, "-c", code], capture_output=True, text=True )
# Should not segfault - should exit cleanly even with abrupt exit assert result.returncode == 0, f"Expected clean exit, but got exit code {result.returncode}. STDERR: {result.stderr}" assert "Exiting abruptly with active threads and pending queries" in result.stdout