diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala
index 6ad88feb..512ea6cc 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala
@@ -188,7 +188,20 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
   def connectByWorkerAddress(executorId: transport.ExecutorId, workerAddress: ByteBuffer): Unit = {
     logDebug(s"Worker $this connecting back to $executorId by worker address")
     val ep = worker.newEndpoint(new UcpEndpointParams().setName(s"Server connection to $executorId")
-      .setUcpAddress(workerAddress))
+      .setUcpAddress(workerAddress)
+      // Peer error handling mode together with the handler below: when this endpoint
+      // reports an error, the failure is logged and the cached connection is dropped.
+      .setPeerErrorHandlingMode()
+      .setErrorHandler(new UcpEndpointErrorHandler() {
+        override def onError(ep: UcpEndpoint, status: Int, errorMsg: String): Unit = {
+          // Log the numeric status too; the message alone may not identify the failure.
+          logError(s"Endpoint to $executorId got an error: $errorMsg (status: $status)")
+          // NOTE(review): this only forgets the endpoint; confirm whether the failed
+          // endpoint must also be closed, and that a newer connection registered for
+          // the same executor cannot be removed here by a late callback.
+          connections.remove(executorId)
+        }
+      }))
     connections.put(executorId, ep)
   }
 