// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc v5.29.3
// source: api/v1/pyproc.proto
package pyprocv1
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
// Compile-time assertions pinning the generated-code version window: the
// build fails if the linked protobuf runtime is older or newer than what
// this generated file supports. Regenerate with protoc rather than editing.
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// CallRequest represents a generic method call request
//
// NOTE: this message type is generated by protoc-gen-go from
// api/v1/pyproc.proto; change the .proto and regenerate, do not hand-edit.
type CallRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Unique request ID for correlation
Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"`
// Method name to invoke on the worker
Method string `protobuf:"bytes,2,opt,name=method,proto3" json:"method,omitempty"`
// JSON-encoded input data
Input []byte `protobuf:"bytes,3,opt,name=input,proto3" json:"input,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}

// Reset zeroes the message and re-attaches the generated type metadata.
func (x *CallRequest) Reset() {
*x = CallRequest{}
mi := &file_api_v1_pyproc_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}

// String renders the message in the protobuf text format.
func (x *CallRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks the type as a protobuf message.
func (*CallRequest) ProtoMessage() {}

// ProtoReflect exposes the message via the protobuf reflection API,
// lazily attaching type metadata on first use.
func (x *CallRequest) ProtoReflect() protoreflect.Message {
mi := &file_api_v1_pyproc_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}

// Deprecated: Use CallRequest.ProtoReflect.Descriptor instead.
func (*CallRequest) Descriptor() ([]byte, []int) {
return file_api_v1_pyproc_proto_rawDescGZIP(), []int{0}
}

// GetId returns the request ID, or 0 for a nil receiver.
func (x *CallRequest) GetId() uint64 {
if x != nil {
return x.Id
}
return 0
}

// GetMethod returns the method name, or "" for a nil receiver.
func (x *CallRequest) GetMethod() string {
if x != nil {
return x.Method
}
return ""
}

// GetInput returns the JSON-encoded input, or nil for a nil receiver.
func (x *CallRequest) GetInput() []byte {
if x != nil {
return x.Input
}
return nil
}
// CallResponse represents a generic method call response
//
// NOTE: generated by protoc-gen-go; regenerate from api/v1/pyproc.proto
// instead of hand-editing.
type CallResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Request ID for correlation
Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"`
// Success status
Ok bool `protobuf:"varint,2,opt,name=ok,proto3" json:"ok,omitempty"`
// JSON-encoded response body
Body []byte `protobuf:"bytes,3,opt,name=body,proto3" json:"body,omitempty"`
// Error message if ok is false
ErrorMessage string `protobuf:"bytes,4,opt,name=error_message,json=errorMessage,proto3" json:"error_message,omitempty"`
// Error type for structured error handling
ErrorType string `protobuf:"bytes,5,opt,name=error_type,json=errorType,proto3" json:"error_type,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}

// Reset zeroes the message and re-attaches the generated type metadata.
func (x *CallResponse) Reset() {
*x = CallResponse{}
mi := &file_api_v1_pyproc_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}

// String renders the message in the protobuf text format.
func (x *CallResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks the type as a protobuf message.
func (*CallResponse) ProtoMessage() {}

// ProtoReflect exposes the message via the protobuf reflection API.
func (x *CallResponse) ProtoReflect() protoreflect.Message {
mi := &file_api_v1_pyproc_proto_msgTypes[1]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}

// Deprecated: Use CallResponse.ProtoReflect.Descriptor instead.
func (*CallResponse) Descriptor() ([]byte, []int) {
return file_api_v1_pyproc_proto_rawDescGZIP(), []int{1}
}

// GetId returns the correlation ID, or 0 for a nil receiver.
func (x *CallResponse) GetId() uint64 {
if x != nil {
return x.Id
}
return 0
}

// GetOk reports the success flag, or false for a nil receiver.
func (x *CallResponse) GetOk() bool {
if x != nil {
return x.Ok
}
return false
}

// GetBody returns the JSON-encoded body, or nil for a nil receiver.
func (x *CallResponse) GetBody() []byte {
if x != nil {
return x.Body
}
return nil
}

// GetErrorMessage returns the error text, or "" for a nil receiver.
func (x *CallResponse) GetErrorMessage() string {
if x != nil {
return x.ErrorMessage
}
return ""
}

// GetErrorType returns the structured error type, or "" for a nil receiver.
func (x *CallResponse) GetErrorType() string {
if x != nil {
return x.ErrorType
}
return ""
}
// HealthCheckRequest for worker health monitoring
//
// The request carries no fields; it exists only to trigger a health probe.
// NOTE: generated by protoc-gen-go; do not hand-edit.
type HealthCheckRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}

// Reset zeroes the message and re-attaches the generated type metadata.
func (x *HealthCheckRequest) Reset() {
*x = HealthCheckRequest{}
mi := &file_api_v1_pyproc_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}

// String renders the message in the protobuf text format.
func (x *HealthCheckRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks the type as a protobuf message.
func (*HealthCheckRequest) ProtoMessage() {}

// ProtoReflect exposes the message via the protobuf reflection API.
func (x *HealthCheckRequest) ProtoReflect() protoreflect.Message {
mi := &file_api_v1_pyproc_proto_msgTypes[2]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}

// Deprecated: Use HealthCheckRequest.ProtoReflect.Descriptor instead.
func (*HealthCheckRequest) Descriptor() ([]byte, []int) {
return file_api_v1_pyproc_proto_rawDescGZIP(), []int{2}
}
// HealthCheckResponse contains worker health information
//
// NOTE: generated by protoc-gen-go; do not hand-edit.
type HealthCheckResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Overall health status
Healthy bool `protobuf:"varint,1,opt,name=healthy,proto3" json:"healthy,omitempty"`
// Worker uptime in seconds
UptimeSeconds int64 `protobuf:"varint,2,opt,name=uptime_seconds,json=uptimeSeconds,proto3" json:"uptime_seconds,omitempty"`
// Additional metadata about the worker
Metadata map[string]string `protobuf:"bytes,3,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}

// Reset zeroes the message and re-attaches the generated type metadata.
func (x *HealthCheckResponse) Reset() {
*x = HealthCheckResponse{}
mi := &file_api_v1_pyproc_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}

// String renders the message in the protobuf text format.
func (x *HealthCheckResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks the type as a protobuf message.
func (*HealthCheckResponse) ProtoMessage() {}

// ProtoReflect exposes the message via the protobuf reflection API.
func (x *HealthCheckResponse) ProtoReflect() protoreflect.Message {
mi := &file_api_v1_pyproc_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}

// Deprecated: Use HealthCheckResponse.ProtoReflect.Descriptor instead.
func (*HealthCheckResponse) Descriptor() ([]byte, []int) {
return file_api_v1_pyproc_proto_rawDescGZIP(), []int{3}
}

// GetHealthy reports the health flag, or false for a nil receiver.
func (x *HealthCheckResponse) GetHealthy() bool {
if x != nil {
return x.Healthy
}
return false
}

// GetUptimeSeconds returns the uptime, or 0 for a nil receiver.
func (x *HealthCheckResponse) GetUptimeSeconds() int64 {
if x != nil {
return x.UptimeSeconds
}
return 0
}

// GetMetadata returns the metadata map, or nil for a nil receiver.
func (x *HealthCheckResponse) GetMetadata() map[string]string {
if x != nil {
return x.Metadata
}
return nil
}
// File_api_v1_pyproc_proto is the compiled descriptor for
// api/v1/pyproc.proto; it is populated by file_api_v1_pyproc_proto_init.
var File_api_v1_pyproc_proto protoreflect.FileDescriptor

// file_api_v1_pyproc_proto_rawDesc is the wire-encoded FileDescriptorProto.
// The bytes are machine-generated — never edit them by hand.
const file_api_v1_pyproc_proto_rawDesc = "" +
"\n" +
"\x13api/v1/pyproc.proto\x12\tpyproc.v1\"K\n" +
"\vCallRequest\x12\x0e\n" +
"\x02id\x18\x01 \x01(\x04R\x02id\x12\x16\n" +
"\x06method\x18\x02 \x01(\tR\x06method\x12\x14\n" +
"\x05input\x18\x03 \x01(\fR\x05input\"\x86\x01\n" +
"\fCallResponse\x12\x0e\n" +
"\x02id\x18\x01 \x01(\x04R\x02id\x12\x0e\n" +
"\x02ok\x18\x02 \x01(\bR\x02ok\x12\x12\n" +
"\x04body\x18\x03 \x01(\fR\x04body\x12#\n" +
"\rerror_message\x18\x04 \x01(\tR\ferrorMessage\x12\x1d\n" +
"\n" +
"error_type\x18\x05 \x01(\tR\terrorType\"\x14\n" +
"\x12HealthCheckRequest\"\xdd\x01\n" +
"\x13HealthCheckResponse\x12\x18\n" +
"\ahealthy\x18\x01 \x01(\bR\ahealthy\x12%\n" +
"\x0euptime_seconds\x18\x02 \x01(\x03R\ruptimeSeconds\x12H\n" +
"\bmetadata\x18\x03 \x03(\v2,.pyproc.v1.HealthCheckResponse.MetadataEntryR\bmetadata\x1a;\n" +
"\rMetadataEntry\x12\x10\n" +
"\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" +
"\x05value\x18\x02 \x01(\tR\x05value:\x028\x012\x96\x01\n" +
"\rPyProcService\x127\n" +
"\x04Call\x12\x16.pyproc.v1.CallRequest\x1a\x17.pyproc.v1.CallResponse\x12L\n" +
"\vHealthCheck\x12\x1d.pyproc.v1.HealthCheckRequest\x1a\x1e.pyproc.v1.HealthCheckResponseB2Z0github.com/YuminosukeSato/pyproc/api/v1;pyprocv1b\x06proto3"

var (
file_api_v1_pyproc_proto_rawDescOnce sync.Once
file_api_v1_pyproc_proto_rawDescData []byte
)

// file_api_v1_pyproc_proto_rawDescGZIP lazily gzip-compresses the raw
// descriptor exactly once and caches the result (used by the deprecated
// Descriptor methods above).
func file_api_v1_pyproc_proto_rawDescGZIP() []byte {
file_api_v1_pyproc_proto_rawDescOnce.Do(func() {
file_api_v1_pyproc_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_api_v1_pyproc_proto_rawDesc), len(file_api_v1_pyproc_proto_rawDesc)))
})
return file_api_v1_pyproc_proto_rawDescData
}
// One MessageInfo slot per message type: 4 top-level messages plus the
// synthesized MetadataEntry map-entry type.
var file_api_v1_pyproc_proto_msgTypes = make([]protoimpl.MessageInfo, 5)

// Maps descriptor indexes to their Go types.
var file_api_v1_pyproc_proto_goTypes = []any{
(*CallRequest)(nil), // 0: pyproc.v1.CallRequest
(*CallResponse)(nil), // 1: pyproc.v1.CallResponse
(*HealthCheckRequest)(nil), // 2: pyproc.v1.HealthCheckRequest
(*HealthCheckResponse)(nil), // 3: pyproc.v1.HealthCheckResponse
nil, // 4: pyproc.v1.HealthCheckResponse.MetadataEntry
}

// Cross-reference table between the types above; the trailing entries
// delimit the sub-lists as described in the per-line comments.
var file_api_v1_pyproc_proto_depIdxs = []int32{
4, // 0: pyproc.v1.HealthCheckResponse.metadata:type_name -> pyproc.v1.HealthCheckResponse.MetadataEntry
0, // 1: pyproc.v1.PyProcService.Call:input_type -> pyproc.v1.CallRequest
2, // 2: pyproc.v1.PyProcService.HealthCheck:input_type -> pyproc.v1.HealthCheckRequest
1, // 3: pyproc.v1.PyProcService.Call:output_type -> pyproc.v1.CallResponse
3, // 4: pyproc.v1.PyProcService.HealthCheck:output_type -> pyproc.v1.HealthCheckResponse
3, // [3:5] is the sub-list for method output_type
1, // [1:3] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_api_v1_pyproc_proto_init() }

// file_api_v1_pyproc_proto_init builds and registers the file descriptor.
// It is idempotent: subsequent calls return immediately once the File
// variable has been populated.
func file_api_v1_pyproc_proto_init() {
if File_api_v1_pyproc_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_api_v1_pyproc_proto_rawDesc), len(file_api_v1_pyproc_proto_rawDesc)),
NumEnums: 0,
NumMessages: 5,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_api_v1_pyproc_proto_goTypes,
DependencyIndexes: file_api_v1_pyproc_proto_depIdxs,
MessageInfos: file_api_v1_pyproc_proto_msgTypes,
}.Build()
File_api_v1_pyproc_proto = out.File
// Release the construction-only tables so they can be garbage-collected.
file_api_v1_pyproc_proto_goTypes = nil
file_api_v1_pyproc_proto_depIdxs = nil
}
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v5.29.3
// source: api/v1/pyproc.proto
package pyprocv1
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9

// Fully-qualified method names, used when invoking RPCs and for labelling
// interceptors, metrics and tracing spans.
const (
PyProcService_Call_FullMethodName = "/pyproc.v1.PyProcService/Call"
PyProcService_HealthCheck_FullMethodName = "/pyproc.v1.PyProcService/HealthCheck"
)
// PyProcServiceClient is the client API for PyProcService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
//
// PyProcService defines the gRPC service for Python worker communication
//
// NOTE: generated by protoc-gen-go-grpc; regenerate instead of hand-editing.
type PyProcServiceClient interface {
// Generic call for backward compatibility with existing JSON protocol
Call(ctx context.Context, in *CallRequest, opts ...grpc.CallOption) (*CallResponse, error)
// Health check RPC for worker health monitoring
HealthCheck(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error)
}

// pyProcServiceClient is the generated concrete client backed by a
// grpc.ClientConnInterface.
type pyProcServiceClient struct {
cc grpc.ClientConnInterface
}

// NewPyProcServiceClient wraps an existing client connection in the
// generated PyProcService client.
func NewPyProcServiceClient(cc grpc.ClientConnInterface) PyProcServiceClient {
return &pyProcServiceClient{cc}
}
// Call performs the unary Call RPC over the underlying connection.
func (c *pyProcServiceClient) Call(ctx context.Context, in *CallRequest, opts ...grpc.CallOption) (*CallResponse, error) {
// StaticMethod tells the transport the method name is a compile-time
// constant, enabling observability optimizations.
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(CallResponse)
err := c.cc.Invoke(ctx, PyProcService_Call_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}

// HealthCheck performs the unary HealthCheck RPC over the underlying connection.
func (c *pyProcServiceClient) HealthCheck(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(HealthCheckResponse)
err := c.cc.Invoke(ctx, PyProcService_HealthCheck_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// PyProcServiceServer is the server API for PyProcService service.
// All implementations must embed UnimplementedPyProcServiceServer
// for forward compatibility.
//
// PyProcService defines the gRPC service for Python worker communication
type PyProcServiceServer interface {
// Generic call for backward compatibility with existing JSON protocol
Call(context.Context, *CallRequest) (*CallResponse, error)
// Health check RPC for worker health monitoring
HealthCheck(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error)
mustEmbedUnimplementedPyProcServiceServer()
}

// UnimplementedPyProcServiceServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedPyProcServiceServer struct{}

// Call returns codes.Unimplemented; override in a real implementation.
func (UnimplementedPyProcServiceServer) Call(context.Context, *CallRequest) (*CallResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Call not implemented")
}

// HealthCheck returns codes.Unimplemented; override in a real implementation.
func (UnimplementedPyProcServiceServer) HealthCheck(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method HealthCheck not implemented")
}
func (UnimplementedPyProcServiceServer) mustEmbedUnimplementedPyProcServiceServer() {}

// testEmbeddedByValue is probed by RegisterPyProcServiceServer to detect a
// nil *UnimplementedPyProcServiceServer embedded by pointer.
func (UnimplementedPyProcServiceServer) testEmbeddedByValue() {}

// UnsafePyProcServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to PyProcServiceServer will
// result in compilation errors.
type UnsafePyProcServiceServer interface {
mustEmbedUnimplementedPyProcServiceServer()
}
// RegisterPyProcServiceServer registers the service implementation with the
// given registrar (typically a *grpc.Server).
func RegisterPyProcServiceServer(s grpc.ServiceRegistrar, srv PyProcServiceServer) {
// If the following call panics, it indicates UnimplementedPyProcServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&PyProcService_ServiceDesc, srv)
}
// _PyProcService_Call_Handler decodes the request, then dispatches to the
// server implementation either directly or through the unary interceptor.
func _PyProcService_Call_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CallRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(PyProcServiceServer).Call(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: PyProcService_Call_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(PyProcServiceServer).Call(ctx, req.(*CallRequest))
}
return interceptor(ctx, in, info, handler)
}

// _PyProcService_HealthCheck_Handler mirrors the Call handler for the
// HealthCheck RPC.
func _PyProcService_HealthCheck_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(HealthCheckRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(PyProcServiceServer).HealthCheck(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: PyProcService_HealthCheck_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(PyProcServiceServer).HealthCheck(ctx, req.(*HealthCheckRequest))
}
return interceptor(ctx, in, info, handler)
}
// PyProcService_ServiceDesc is the grpc.ServiceDesc for PyProcService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var PyProcService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "pyproc.v1.PyProcService",
HandlerType: (*PyProcServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Call",
Handler: _PyProcService_Call_Handler,
},
{
MethodName: "HealthCheck",
Handler: _PyProcService_HealthCheck_Handler,
},
},
// The service defines no streaming RPCs.
Streams: []grpc.StreamDesc{},
Metadata: "api/v1/pyproc.proto",
}
// Package rpc_clients provides a common interface for different RPC protocols
// to enable fair performance comparison over Unix Domain Sockets.
package rpc_clients
import (
"context"
"time"
)
// RPCClient defines the common interface for all RPC protocol implementations.
// Each implementation must provide connection management and call functionality.
//
// Implementations in this package: JSONRPCClient, MsgpackRPCClient,
// XMLRPCClient and PyprocClient.
type RPCClient interface {
// Connect establishes a connection to the RPC server via Unix Domain Socket
Connect(udsPath string) error
// Call invokes a remote method with given arguments and stores the reply
Call(ctx context.Context, method string, args interface{}, reply interface{}) error
// Close terminates the connection and cleans up resources
Close() error
// Name returns the protocol name for identification in benchmarks
Name() string
}
// BenchmarkConfig holds configuration for benchmark testing
type BenchmarkConfig struct {
UDSPath string // Unix Domain Socket path
Timeout time.Duration // Request timeout
MaxRetries int // Maximum retry attempts
WarmupCalls int // Number of warmup calls before measurement
}
// TestPayload represents different payload sizes for benchmarking.
type TestPayload struct {
	Size   string      // "small", "medium", "large"
	Method string      // RPC method to call
	Data   interface{} // Actual payload data
}

// SmallPayload builds a ~64 byte payload exercising a simple operation.
func SmallPayload() TestPayload {
	return TestPayload{
		Size:   "small",
		Method: "predict",
		Data:   map[string]interface{}{"value": 42},
	}
}

// MediumPayload builds a ~2KB payload resembling a typical API request:
// a batch of 100 sequential values plus request metadata.
func MediumPayload() TestPayload {
	values := make([]int, 100)
	for i := 0; i < len(values); i++ {
		values[i] = i + 1
	}
	meta := map[string]interface{}{
		"user_id":   "test-user-123",
		"timestamp": time.Now().Unix(),
		"version":   "1.0.0",
	}
	return TestPayload{
		Size:   "medium",
		Method: "process_batch",
		Data: map[string]interface{}{
			"values":   values,
			"metadata": meta,
		},
	}
}

// LargePayload builds a ~1MB payload for a bulk data-transfer scenario:
// 100k integers cycling through 0..999 plus statistics options.
func LargePayload() TestPayload {
	numbers := make([]int, 100000)
	for i := 0; i < len(numbers); i++ {
		numbers[i] = i % 1000
	}
	opts := map[string]interface{}{
		"compute_variance": true,
		"compute_std_dev":  true,
		"compute_median":   true,
	}
	return TestPayload{
		Size:   "large",
		Method: "compute_stats",
		Data: map[string]interface{}{
			"numbers": numbers,
			"options": opts,
		},
	}
}
// BenchmarkResult stores the results of a benchmark run
type BenchmarkResult struct {
Protocol string // Protocol name
PayloadSize string // Payload size category
Latency time.Duration // Average latency
Throughput float64 // Requests per second
ErrorRate float64 // Percentage of failed requests
CPUUsage float64 // Average CPU usage percentage
MemoryUsage int64 // Memory usage in bytes
}
package rpc_clients
import (
"context"
"encoding/json"
"fmt"
"net"
"sync"
"sync/atomic"
)
// JSONRPCClient implements JSON-RPC 2.0 protocol over Unix Domain Socket
type JSONRPCClient struct {
conn net.Conn
udsPath string
requestID uint64
mu sync.Mutex
}
// JSONRPCRequest represents a JSON-RPC 2.0 request
type JSONRPCRequest struct {
JSONRPC string `json:"jsonrpc"`
Method string `json:"method"`
Params interface{} `json:"params"`
ID uint64 `json:"id"`
}
// JSONRPCResponse represents a JSON-RPC 2.0 response
type JSONRPCResponse struct {
JSONRPC string `json:"jsonrpc"`
Result json.RawMessage `json:"result,omitempty"`
Error *JSONRPCError `json:"error,omitempty"`
ID uint64 `json:"id"`
}
// JSONRPCError represents a JSON-RPC 2.0 error
type JSONRPCError struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// NewJSONRPCClient creates a new JSON-RPC client
func NewJSONRPCClient() *JSONRPCClient {
return &JSONRPCClient{}
}
// Connect establishes connection to JSON-RPC server via UDS
func (c *JSONRPCClient) Connect(udsPath string) error {
conn, err := net.Dial("unix", udsPath)
if err != nil {
return fmt.Errorf("failed to connect to JSON-RPC server: %w", err)
}
c.conn = conn
c.udsPath = udsPath
return nil
}
// Call invokes a JSON-RPC method
func (c *JSONRPCClient) Call(ctx context.Context, method string, args interface{}, reply interface{}) error {
if c.conn == nil {
return fmt.Errorf("not connected")
}
// Generate unique request ID
id := atomic.AddUint64(&c.requestID, 1)
// Create JSON-RPC request
request := JSONRPCRequest{
JSONRPC: "2.0",
Method: method,
Params: args,
ID: id,
}
// Send request
c.mu.Lock()
encoder := json.NewEncoder(c.conn)
if err := encoder.Encode(request); err != nil {
c.mu.Unlock()
return fmt.Errorf("failed to send request: %w", err)
}
// Receive response
decoder := json.NewDecoder(c.conn)
var response JSONRPCResponse
if err := decoder.Decode(&response); err != nil {
c.mu.Unlock()
return fmt.Errorf("failed to receive response: %w", err)
}
c.mu.Unlock()
// Check for error
if response.Error != nil {
return fmt.Errorf("JSON-RPC error %d: %s", response.Error.Code, response.Error.Message)
}
// Unmarshal result
if reply != nil && response.Result != nil {
if err := json.Unmarshal(response.Result, reply); err != nil {
return fmt.Errorf("failed to unmarshal result: %w", err)
}
}
return nil
}
// Close terminates the connection
func (c *JSONRPCClient) Close() error {
if c.conn != nil {
return c.conn.Close()
}
return nil
}
// Name returns the protocol identifier
func (c *JSONRPCClient) Name() string {
return "json-rpc"
}
package rpc_clients
import (
	"bytes"
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"sync"
	"sync/atomic"

	"github.com/vmihailenco/msgpack/v5"
)
// MsgpackRPCClient implements MessagePack-RPC protocol over Unix Domain Socket
type MsgpackRPCClient struct {
conn net.Conn
udsPath string
requestID uint32
mu sync.Mutex
}
// MsgpackRequest represents a MessagePack-RPC request
// Format: [type, msgid, method, params]
type MsgpackRequest struct {
Type uint8 // 0 for request
MsgID uint32 // Message ID
Method string // Method name
Params interface{} // Parameters
}
// MsgpackResponse represents a MessagePack-RPC response
// Format: [type, msgid, error, result]
type MsgpackResponse struct {
Type uint8 // 1 for response
MsgID uint32 // Message ID
Error interface{} // Error if any
Result interface{} // Result value
}
// NewMsgpackRPCClient creates a new MessagePack-RPC client
func NewMsgpackRPCClient() *MsgpackRPCClient {
return &MsgpackRPCClient{}
}
// Connect establishes connection to MessagePack-RPC server via UDS
func (c *MsgpackRPCClient) Connect(udsPath string) error {
conn, err := net.Dial("unix", udsPath)
if err != nil {
return fmt.Errorf("failed to connect to MessagePack-RPC server: %w", err)
}
c.conn = conn
c.udsPath = udsPath
return nil
}
// Call invokes a MessagePack-RPC method and stores the decoded result in
// reply (only *map[string]interface{} replies are populated — simplified
// for benchmark purposes).
//
// Wire format: a 4-byte big-endian length prefix followed by the msgpack
// encoding of [type, msgid, method, params]; responses are framed the same
// way as [type, msgid, error, result].
//
// Fixes over the original implementation:
//   - the length prefix and body are read with io.ReadFull — a bare
//     conn.Read may legally return fewer bytes than requested, which would
//     silently desynchronize the framing;
//   - the mutex is released via defer instead of by hand on every branch;
//   - the response type/ID are compared numerically via msgpackUint64,
//     because the decoder chooses the integer width for interface{} values
//     (the original's hard uint8/uint32 assertions are fragile).
func (c *MsgpackRPCClient) Call(ctx context.Context, method string, args interface{}, reply interface{}) error {
	if c.conn == nil {
		return fmt.Errorf("not connected")
	}
	// Generate unique request ID for response correlation.
	msgID := atomic.AddUint32(&c.requestID, 1)
	// MessagePack-RPC request array: [type, msgid, method, params].
	request := []interface{}{
		uint8(0), // Request type
		msgID,    // Message ID
		method,   // Method name
		args,     // Parameters
	}
	var buf bytes.Buffer
	if err := msgpack.NewEncoder(&buf).Encode(request); err != nil {
		return fmt.Errorf("failed to encode request: %w", err)
	}
	reqData := buf.Bytes()
	lenBuf := make([]byte, 4)
	binary.BigEndian.PutUint32(lenBuf, uint32(len(reqData)))
	// The whole write/read exchange happens under the lock.
	respData, err := func() ([]byte, error) {
		c.mu.Lock()
		defer c.mu.Unlock()
		if _, err := c.conn.Write(lenBuf); err != nil {
			return nil, fmt.Errorf("failed to send length prefix: %w", err)
		}
		if _, err := c.conn.Write(reqData); err != nil {
			return nil, fmt.Errorf("failed to send request: %w", err)
		}
		if _, err := io.ReadFull(c.conn, lenBuf); err != nil {
			return nil, fmt.Errorf("failed to read response length: %w", err)
		}
		data := make([]byte, binary.BigEndian.Uint32(lenBuf))
		if _, err := io.ReadFull(c.conn, data); err != nil {
			return nil, fmt.Errorf("failed to read response: %w", err)
		}
		return data, nil
	}()
	if err != nil {
		return err
	}
	var response []interface{}
	if err := msgpack.NewDecoder(bytes.NewReader(respData)).Decode(&response); err != nil {
		return fmt.Errorf("failed to decode response: %w", err)
	}
	// Validate response shape: [type, msgid, error, result].
	if len(response) != 4 {
		return fmt.Errorf("invalid response format")
	}
	// Message type must be 1 (response).
	if respType, ok := msgpackUint64(response[0]); !ok || respType != 1 {
		return fmt.Errorf("invalid response type")
	}
	if respID, ok := msgpackUint64(response[1]); !ok || respID != uint64(msgID) {
		return fmt.Errorf("message ID mismatch")
	}
	if response[2] != nil {
		return fmt.Errorf("MessagePack-RPC error: %v", response[2])
	}
	if reply != nil && response[3] != nil {
		// Simplified result extraction for benchmark purposes.
		if m, ok := reply.(*map[string]interface{}); ok {
			if result, ok := response[3].(map[string]interface{}); ok {
				*m = result
			} else {
				*m = map[string]interface{}{
					"result": response[3],
				}
			}
		}
	}
	return nil
}

// msgpackUint64 normalizes the various integer widths the msgpack decoder
// may produce for a non-negative wire integer into a uint64.
// NOTE(review): the exact widths chosen for interface{} targets depend on
// the msgpack library version — this normalization is deliberately tolerant.
func msgpackUint64(v interface{}) (uint64, bool) {
	switch n := v.(type) {
	case int8:
		if n >= 0 {
			return uint64(n), true
		}
	case int16:
		if n >= 0 {
			return uint64(n), true
		}
	case int32:
		if n >= 0 {
			return uint64(n), true
		}
	case int64:
		if n >= 0 {
			return uint64(n), true
		}
	case int:
		if n >= 0 {
			return uint64(n), true
		}
	case uint8:
		return uint64(n), true
	case uint16:
		return uint64(n), true
	case uint32:
		return uint64(n), true
	case uint64:
		return n, true
	case uint:
		return uint64(n), true
	}
	return 0, false
}
// Close shuts down the connection; calling it on a never-connected client
// is a harmless no-op.
func (c *MsgpackRPCClient) Close() error {
	if c.conn == nil {
		return nil
	}
	return c.conn.Close()
}
// Name returns the protocol identifier used to label benchmark results.
func (c *MsgpackRPCClient) Name() string {
return "msgpack-rpc"
}
package rpc_clients
import (
"context"
"fmt"
"time"
"github.com/YuminosukeSato/pyproc/pkg/pyproc"
)
// PyprocClient wraps the existing pyproc Pool to implement RPCClient interface
type PyprocClient struct {
// pool is created lazily by Connect; nil until then.
pool *pyproc.Pool
// config is the worker template; Connect fills in SocketPath and ID.
config pyproc.WorkerConfig
}

// NewPyprocClient creates a new pyproc client instance
func NewPyprocClient(pythonExec, workerScript string) *PyprocClient {
return &PyprocClient{
config: pyproc.WorkerConfig{
PythonExec: pythonExec,
WorkerScript: workerScript,
StartTimeout: 5 * time.Second,
},
}
}
// Connect starts a single-worker pyproc pool listening on udsPath.
//
// The pool uses one worker and a bounded in-flight count so the comparison
// with the other single-connection clients in this package stays fair.
//
// Fix over the original: if Start fails, the freshly created pool is shut
// down instead of being dropped, so a partially started worker process is
// not leaked.
func (c *PyprocClient) Connect(udsPath string) error {
	c.config.SocketPath = udsPath
	c.config.ID = fmt.Sprintf("pyproc-bench-%d", time.Now().Unix())
	// Create a pool with a single worker for fair comparison.
	pool, err := pyproc.NewPool(pyproc.PoolOptions{
		Config: pyproc.PoolConfig{
			Workers:     1,
			MaxInFlight: 10,
		},
		WorkerConfig: c.config,
	}, nil)
	if err != nil {
		return fmt.Errorf("failed to create pyproc pool: %w", err)
	}
	ctx := context.Background()
	if err := pool.Start(ctx); err != nil {
		// Best-effort cleanup so a half-started worker is not leaked.
		// NOTE(review): assumes Shutdown tolerates a pool whose Start
		// failed — confirm against the pyproc package.
		_ = pool.Shutdown(ctx)
		return fmt.Errorf("failed to start pyproc pool: %w", err)
	}
	// Give the worker time to initialize before the first call.
	time.Sleep(100 * time.Millisecond)
	c.pool = pool
	return nil
}
// Call forwards a method invocation to the underlying pyproc pool.
func (c *PyprocClient) Call(ctx context.Context, method string, args interface{}, reply interface{}) error {
	if c.pool != nil {
		return c.pool.Call(ctx, method, args, reply)
	}
	return fmt.Errorf("pool not connected")
}

// Close shuts down the pyproc worker pool, if one was started.
func (c *PyprocClient) Close() error {
	if c.pool == nil {
		return nil
	}
	return c.pool.Shutdown(context.Background())
}

// Name returns the protocol identifier used in benchmark reports.
func (c *PyprocClient) Name() string {
	return "pyproc"
}
package rpc_clients
import (
"bytes"
"context"
"encoding/xml"
"fmt"
"io"
"net"
"net/http"
"sync"
)
// XMLRPCClient implements XML-RPC protocol over Unix Domain Socket
type XMLRPCClient struct {
httpClient *http.Client
udsPath string
mu sync.Mutex
}
// NewXMLRPCClient creates a new XML-RPC client
func NewXMLRPCClient() *XMLRPCClient {
return &XMLRPCClient{}
}
// Connect establishes connection to XML-RPC server via UDS
func (c *XMLRPCClient) Connect(udsPath string) error {
// Create HTTP client with Unix Domain Socket transport
c.httpClient = &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", udsPath)
},
},
}
c.udsPath = udsPath
return nil
}
// Call invokes an XML-RPC method
func (c *XMLRPCClient) Call(ctx context.Context, method string, args interface{}, reply interface{}) error {
if c.httpClient == nil {
return fmt.Errorf("not connected")
}
// Create XML-RPC request
request, err := c.encodeRequest(method, args)
if err != nil {
return fmt.Errorf("failed to encode request: %w", err)
}
// Send HTTP POST request
req, err := http.NewRequestWithContext(ctx, "POST", "http://unix/RPC2", bytes.NewReader(request))
if err != nil {
return fmt.Errorf("failed to create HTTP request: %w", err)
}
req.Header.Set("Content-Type", "text/xml")
c.mu.Lock()
resp, err := c.httpClient.Do(req)
c.mu.Unlock()
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer func() { _ = resp.Body.Close() }()
// Read response
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
// Decode XML-RPC response
return c.decodeResponse(body, reply)
}
// encodeRequest encodes method call to XML-RPC format
func (c *XMLRPCClient) encodeRequest(method string, args interface{}) ([]byte, error) {
var buf bytes.Buffer
// Write XML header
buf.WriteString(`<?xml version="1.0"?>`)
buf.WriteString(`<methodCall>`)
buf.WriteString(`<methodName>` + method + `</methodName>`)
buf.WriteString(`<params>`)
// Encode parameters
if args != nil {
buf.WriteString(`<param>`)
if err := c.encodeValue(&buf, args); err != nil {
return nil, err
}
buf.WriteString(`</param>`)
}
buf.WriteString(`</params>`)
buf.WriteString(`</methodCall>`)
return buf.Bytes(), nil
}
// encodeValue encodes a value to XML-RPC format
func (c *XMLRPCClient) encodeValue(buf *bytes.Buffer, v interface{}) error {
buf.WriteString(`<value>`)
switch val := v.(type) {
case int:
buf.WriteString(fmt.Sprintf(`<int>%d</int>`, val))
case string:
buf.WriteString(`<string>`)
_ = xml.EscapeText(buf, []byte(val))
buf.WriteString(`</string>`)
case map[string]interface{}:
buf.WriteString(`<struct>`)
for k, v := range val {
buf.WriteString(`<member>`)
buf.WriteString(`<name>` + k + `</name>`)
if err := c.encodeValue(buf, v); err != nil {
return err
}
buf.WriteString(`</member>`)
}
buf.WriteString(`</struct>`)
case []interface{}:
buf.WriteString(`<array><data>`)
for _, item := range val {
if err := c.encodeValue(buf, item); err != nil {
return err
}
}
buf.WriteString(`</data></array>`)
default:
// Simplified encoding for benchmark purposes
buf.WriteString(fmt.Sprintf(`<string>%v</string>`, val))
}
buf.WriteString(`</value>`)
return nil
}
// decodeResponse decodes XML-RPC response
func (c *XMLRPCClient) decodeResponse(data []byte, reply interface{}) error {
// Simplified XML parsing for benchmark purposes
// In production, use proper XML-RPC library
// Check for fault
if bytes.Contains(data, []byte("<fault>")) {
return fmt.Errorf("XML-RPC fault in response")
}
// Extract result value (simplified)
// In real implementation, properly parse XML structure
if reply != nil {
// For benchmark purposes, we'll just set a simple result
if m, ok := reply.(*map[string]interface{}); ok {
*m = map[string]interface{}{
"result": "processed",
}
}
}
return nil
}
// Close terminates the connection
func (c *XMLRPCClient) Close() error {
if c.httpClient != nil {
c.httpClient.CloseIdleConnections()
}
return nil
}
// Name returns the protocol identifier
func (c *XMLRPCClient) Name() string {
return "xml-rpc"
}
// Package main provides the pyproc CLI for scaffolding Python worker projects.
package main
import (
"embed"
"fmt"
"os"
"path/filepath"
"text/template"
"github.com/spf13/cobra"
)
// templates embeds every file under templates/ so the CLI can scaffold
// projects without shipping loose template files alongside the binary.
//go:embed templates/*
var templates embed.FS
// rootCmd is the top-level "pyproc" command; subcommands are attached in
// init and Execute is invoked from main.
var rootCmd = &cobra.Command{
	Use:   "pyproc",
	Short: "PyProc - Call Python from Go without CGO",
	Long: `PyProc is a high-performance IPC library for Go and Python integration.
It uses Unix domain sockets for fast, secure communication between Go and Python processes.`,
	Version: "0.1.0",
}

// initCmd scaffolds a complete new project; the heavy lifting is in runInit.
var initCmd = &cobra.Command{
	Use:   "init [project-name]",
	Short: "Initialize a new PyProc project",
	Long:  `Creates a new PyProc project with Go and Python scaffolding.`,
	Args:  cobra.MaximumNArgs(1),
	RunE:  runInit,
}

// scaffoldCmd generates a single worker scaffold file; see runScaffold.
var scaffoldCmd = &cobra.Command{
	Use:   "scaffold [type]",
	Short: "Generate scaffold code",
	Long:  `Generate scaffold code for Go or Python workers.`,
	Args:  cobra.ExactArgs(1),
	RunE:  runScaffold,
}
// init registers the subcommands and declares their flags before main calls
// rootCmd.Execute.
func init() {
	rootCmd.AddCommand(initCmd)
	rootCmd.AddCommand(scaffoldCmd)
	initCmd.Flags().String("go-module", "", "Go module name (e.g., github.com/user/project)")
	initCmd.Flags().Bool("with-docker", false, "Include Docker Compose configuration")
	initCmd.Flags().Bool("with-k8s", false, "Include Kubernetes manifests")
	scaffoldCmd.Flags().String("name", "worker", "Name of the worker")
	scaffoldCmd.Flags().String("output", ".", "Output directory")
}
// main runs the root command and exits non-zero on any error.
func main() {
	err := rootCmd.Execute()
	if err == nil {
		return
	}
	fmt.Fprintln(os.Stderr, err)
	os.Exit(1)
}
// runInit implements the "init" subcommand: it creates the project directory
// layout and renders every template for a new PyProc project. The project
// name defaults to "pyproc-app" and the Go module name is derived from it
// when --go-module is not given.
func runInit(cmd *cobra.Command, args []string) error {
	name := "pyproc-app"
	if len(args) > 0 {
		name = args[0]
	}
	module, _ := cmd.Flags().GetString("go-module")
	dockerEnabled, _ := cmd.Flags().GetBool("with-docker")
	k8sEnabled, _ := cmd.Flags().GetBool("with-k8s")
	if module == "" {
		module = fmt.Sprintf("github.com/example/%s", name)
	}
	// Create the project root, then the standard subdirectories.
	if err := os.MkdirAll(name, 0755); err != nil {
		return fmt.Errorf("failed to create project directory: %w", err)
	}
	for _, dir := range []string{
		filepath.Join(name, "cmd", "app"),
		filepath.Join(name, "worker", "python"),
		filepath.Join(name, "api"),
	} {
		if err := os.MkdirAll(dir, 0755); err != nil {
			return fmt.Errorf("failed to create directory %s: %w", dir, err)
		}
	}
	// Template data shared by every rendered file.
	data := struct {
		ProjectName string
		GoModule    string
	}{
		ProjectName: name,
		GoModule:    module,
	}
	// Map of template path -> destination path; optional features append to it.
	files := map[string]string{
		"templates/go.mod.tmpl":           filepath.Join(name, "go.mod"),
		"templates/main.go.tmpl":          filepath.Join(name, "cmd", "app", "main.go"),
		"templates/worker.py.tmpl":        filepath.Join(name, "worker", "python", "worker.py"),
		"templates/requirements.txt.tmpl": filepath.Join(name, "worker", "python", "requirements.txt"),
		"templates/README.md.tmpl":        filepath.Join(name, "README.md"),
		"templates/Makefile.tmpl":         filepath.Join(name, "Makefile"),
	}
	if dockerEnabled {
		files["templates/docker-compose.yml.tmpl"] = filepath.Join(name, "docker-compose.yml")
		files["templates/Dockerfile.go.tmpl"] = filepath.Join(name, "Dockerfile.go")
		files["templates/Dockerfile.python.tmpl"] = filepath.Join(name, "Dockerfile.python")
	}
	if k8sEnabled {
		k8sDir := filepath.Join(name, "k8s")
		if err := os.MkdirAll(k8sDir, 0755); err != nil {
			return fmt.Errorf("failed to create k8s directory: %w", err)
		}
		files["templates/k8s-deployment.yaml.tmpl"] = filepath.Join(k8sDir, "deployment.yaml")
		files["templates/k8s-service.yaml.tmpl"] = filepath.Join(k8sDir, "service.yaml")
	}
	for tmplPath, outPath := range files {
		if err := generateFromTemplate(tmplPath, outPath, data); err != nil {
			return fmt.Errorf("failed to generate %s: %w", outPath, err)
		}
	}
	fmt.Printf("✅ Created PyProc project: %s\n", name)
	fmt.Printf("\nNext steps:\n")
	fmt.Printf(" cd %s\n", name)
	fmt.Printf(" go mod tidy\n")
	fmt.Printf(" pip install -r worker/python/requirements.txt\n")
	fmt.Printf(" make run\n")
	return nil
}
// runScaffold implements the "scaffold" subcommand: it renders a single Go
// client or Python worker scaffold file into the --output directory, named
// after the --name flag.
func runScaffold(cmd *cobra.Command, args []string) error {
	kind := args[0]
	name, _ := cmd.Flags().GetString("name")
	output, _ := cmd.Flags().GetString("output")
	data := struct {
		Name string
	}{Name: name}
	var tmplPath, fileName string
	switch kind {
	case "go", "golang":
		tmplPath = "templates/scaffold_go.tmpl"
		fileName = fmt.Sprintf("%s_client.go", name)
	case "python", "py":
		tmplPath = "templates/scaffold_python.tmpl"
		fileName = fmt.Sprintf("%s_worker.py", name)
	default:
		return fmt.Errorf("unknown scaffold type: %s (use 'go' or 'python')", kind)
	}
	return generateFromTemplate(tmplPath, filepath.Join(output, fileName), data)
}
// generateFromTemplate renders the embedded template at tmplPath with data
// and writes the result to outPath, printing the generated path on success.
//
// The output file's Close error is now checked: for a freshly written file a
// failed Close can mean buffered data was never flushed, so the previous
// deferred, ignored close could report success for a truncated file.
func generateFromTemplate(tmplPath, outPath string, data interface{}) error {
	// Read template from embedded files
	tmplContent, err := templates.ReadFile(tmplPath)
	if err != nil {
		return fmt.Errorf("failed to read template: %w", err)
	}
	// Parse and execute template
	tmpl, err := template.New(filepath.Base(tmplPath)).Parse(string(tmplContent))
	if err != nil {
		return fmt.Errorf("failed to parse template: %w", err)
	}
	// Create output file
	outFile, err := os.Create(outPath)
	if err != nil {
		return fmt.Errorf("failed to create output file: %w", err)
	}
	// Execute template
	if err := tmpl.Execute(outFile, data); err != nil {
		// Best-effort close; the execute error is the primary failure.
		_ = outFile.Close()
		return fmt.Errorf("failed to execute template: %w", err)
	}
	if err := outFile.Close(); err != nil {
		return fmt.Errorf("failed to close output file: %w", err)
	}
	fmt.Printf("Generated: %s\n", outPath)
	return nil
}
// Package main provides a basic example of using pyproc.
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"path/filepath"
"time"
"github.com/YuminosukeSato/pyproc/internal/framing"
"github.com/YuminosukeSato/pyproc/internal/protocol"
"github.com/YuminosukeSato/pyproc/pkg/pyproc"
)
// main demonstrates the low-level pyproc API end to end: it starts a Python
// worker process, connects to its Unix domain socket, and exercises several
// worker methods (predict, batch processing, text transform, statistics,
// health) over the framed JSON protocol. Failures of individual calls are
// logged but do not abort the demo; only startup/connection failures are
// fatal.
func main() {
	// Create logger
	logger := pyproc.NewLogger(pyproc.LoggingConfig{
		Level:        "info",
		Format:       "text",
		TraceEnabled: true,
	})
	// Get the worker script path relative to the repo root
	workerScript, err := filepath.Abs(filepath.Join("examples", "basic", "worker.py"))
	if err != nil {
		log.Fatal(err)
	}
	// Socket path shared between the Go side and the spawned Python worker.
	socketPath := "/tmp/pyproc-example.sock"
	// Create worker configuration
	cfg := pyproc.WorkerConfig{
		ID:           "example-worker",
		SocketPath:   socketPath,
		PythonExec:   "python3",
		WorkerScript: workerScript,
		StartTimeout: 10 * time.Second,
	}
	// Create and start worker
	ctx := context.Background()
	worker := pyproc.NewWorker(cfg, logger)
	fmt.Println("Starting Python worker...")
	if err := worker.Start(ctx); err != nil {
		log.Fatalf("Failed to start worker: %v", err)
	}
	defer func() {
		fmt.Println("\nStopping worker...")
		// Best-effort shutdown on exit; the process is ending anyway.
		_ = worker.Stop()
	}()
	fmt.Printf("Worker started (PID: %d)\n\n", worker.GetPID())
	// Connect to the worker's socket, retrying for up to 5 seconds.
	conn, err := pyproc.ConnectToWorker(socketPath, 5*time.Second)
	if err != nil {
		log.Fatalf("Failed to connect to worker: %v", err)
	}
	defer func() { _ = conn.Close() }()
	framer := framing.NewFramer(conn)
	// Example 1: Simple prediction
	fmt.Println("=== Example 1: Simple Prediction ===")
	if err := callPredict(framer, 42); err != nil {
		log.Printf("Predict failed: %v", err)
	}
	// Example 2: Batch processing
	fmt.Println("\n=== Example 2: Batch Processing ===")
	if err := callBatchProcess(framer, []int{1, 2, 3, 4, 5}); err != nil {
		log.Printf("Batch process failed: %v", err)
	}
	// Example 3: Text transformation
	fmt.Println("\n=== Example 3: Text Transformation ===")
	if err := callTextTransform(framer, "Hello PyProc!", "upper"); err != nil {
		log.Printf("Text transform failed: %v", err)
	}
	if err := callTextTransform(framer, "Hello PyProc!", "reverse"); err != nil {
		log.Printf("Text transform failed: %v", err)
	}
	// Example 4: Compute statistics
	fmt.Println("\n=== Example 4: Compute Statistics ===")
	if err := callComputeStats(framer, []float64{1.5, 2.5, 3.5, 4.5, 5.5}); err != nil {
		log.Printf("Compute stats failed: %v", err)
	}
	// Example 5: Health check
	fmt.Println("\n=== Example 5: Health Check ===")
	if err := callHealth(framer); err != nil {
		log.Printf("Health check failed: %v", err)
	}
	fmt.Println("\n✅ All examples completed successfully!")
}
// callPredict sends a "predict" request with the given value over the framed
// connection and prints the model's prediction.
//
// The request marshal error is now checked instead of being discarded; a
// marshal failure previously caused a nil payload to be framed and sent.
func callPredict(framer *framing.Framer, value int) error {
	req, err := protocol.NewRequest(1, "predict", map[string]interface{}{
		"value": value,
	})
	if err != nil {
		return err
	}
	// Send request
	reqData, err := req.Marshal()
	if err != nil {
		return err
	}
	if err := framer.WriteMessage(reqData); err != nil {
		return err
	}
	// Read response
	respData, err := framer.ReadMessage()
	if err != nil {
		return err
	}
	var resp protocol.Response
	if err := resp.Unmarshal(respData); err != nil {
		return err
	}
	if !resp.OK {
		return fmt.Errorf("error from Python: %s", resp.ErrorMsg)
	}
	var result map[string]interface{}
	if err := resp.UnmarshalBody(&result); err != nil {
		return err
	}
	fmt.Printf("Input: %d\n", value)
	fmt.Printf("Result: %.0f (model: %s, confidence: %.2f)\n",
		result["result"], result["model"], result["confidence"])
	return nil
}
// callBatchProcess sends a "process_batch" request with the given values and
// prints the processed results plus count and sum.
//
// The request marshal error is now checked instead of being discarded.
func callBatchProcess(framer *framing.Framer, values []int) error {
	req, err := protocol.NewRequest(2, "process_batch", map[string]interface{}{
		"values": values,
	})
	if err != nil {
		return err
	}
	// Send request
	reqData, err := req.Marshal()
	if err != nil {
		return err
	}
	if err := framer.WriteMessage(reqData); err != nil {
		return err
	}
	// Read response
	respData, err := framer.ReadMessage()
	if err != nil {
		return err
	}
	var resp protocol.Response
	if err := resp.Unmarshal(respData); err != nil {
		return err
	}
	if !resp.OK {
		return fmt.Errorf("error from Python: %s", resp.ErrorMsg)
	}
	var result map[string]interface{}
	if err := resp.UnmarshalBody(&result); err != nil {
		return err
	}
	fmt.Printf("Input values: %v\n", values)
	fmt.Printf("Processed results: %v\n", result["results"])
	fmt.Printf("Count: %.0f, Sum: %.0f\n", result["count"], result["sum"])
	return nil
}
// callTextTransform sends a "transform_text" request applying the named
// operation (e.g. "upper", "reverse") to text and prints the result.
//
// The request marshal error is now checked instead of being discarded.
func callTextTransform(framer *framing.Framer, text, operation string) error {
	req, err := protocol.NewRequest(3, "transform_text", map[string]interface{}{
		"text":      text,
		"operation": operation,
	})
	if err != nil {
		return err
	}
	// Send request
	reqData, err := req.Marshal()
	if err != nil {
		return err
	}
	if err := framer.WriteMessage(reqData); err != nil {
		return err
	}
	// Read response
	respData, err := framer.ReadMessage()
	if err != nil {
		return err
	}
	var resp protocol.Response
	if err := resp.Unmarshal(respData); err != nil {
		return err
	}
	if !resp.OK {
		return fmt.Errorf("error from Python: %s", resp.ErrorMsg)
	}
	var result map[string]interface{}
	if err := resp.UnmarshalBody(&result); err != nil {
		return err
	}
	fmt.Printf("Operation: %s\n", operation)
	fmt.Printf("Original: %s\n", result["original"])
	fmt.Printf("Transformed: %s\n", result["transformed"])
	return nil
}
// callComputeStats sends a "compute_stats" request for the given numbers and
// prints the returned count, mean, min, max, and sum.
//
// The request marshal error is now checked instead of being discarded.
func callComputeStats(framer *framing.Framer, numbers []float64) error {
	req, err := protocol.NewRequest(4, "compute_stats", map[string]interface{}{
		"numbers": numbers,
	})
	if err != nil {
		return err
	}
	// Send request
	reqData, err := req.Marshal()
	if err != nil {
		return err
	}
	if err := framer.WriteMessage(reqData); err != nil {
		return err
	}
	// Read response
	respData, err := framer.ReadMessage()
	if err != nil {
		return err
	}
	var resp protocol.Response
	if err := resp.Unmarshal(respData); err != nil {
		return err
	}
	if !resp.OK {
		return fmt.Errorf("error from Python: %s", resp.ErrorMsg)
	}
	var result map[string]interface{}
	if err := resp.UnmarshalBody(&result); err != nil {
		return err
	}
	fmt.Printf("Numbers: %v\n", numbers)
	fmt.Printf("Statistics:\n")
	fmt.Printf(" Count: %.0f\n", result["count"])
	fmt.Printf(" Mean: %.2f\n", result["mean"])
	fmt.Printf(" Min: %.2f\n", result["min"])
	fmt.Printf(" Max: %.2f\n", result["max"])
	fmt.Printf(" Sum: %.2f\n", result["sum"])
	return nil
}
// callHealth sends a "health" request with an empty body and pretty-prints
// the worker's health payload as indented JSON.
//
// The request marshal error is now checked instead of being discarded.
func callHealth(framer *framing.Framer) error {
	req, err := protocol.NewRequest(5, "health", map[string]interface{}{})
	if err != nil {
		return err
	}
	// Send request
	reqData, err := req.Marshal()
	if err != nil {
		return err
	}
	if err := framer.WriteMessage(reqData); err != nil {
		return err
	}
	// Read response
	respData, err := framer.ReadMessage()
	if err != nil {
		return err
	}
	var resp protocol.Response
	if err := resp.Unmarshal(respData); err != nil {
		return err
	}
	if !resp.OK {
		return fmt.Errorf("error from Python: %s", resp.ErrorMsg)
	}
	var result map[string]interface{}
	if err := resp.UnmarshalBody(&result); err != nil {
		return err
	}
	// MarshalIndent cannot fail on a map freshly decoded from JSON.
	jsonBytes, _ := json.MarshalIndent(result, "", " ")
	fmt.Printf("Health check response:\n%s\n", jsonBytes)
	return nil
}
// Package framing implements an enhanced framing protocol with request ID
// and CRC32C checksum for reliable multiplexed message transmission.
package framing
import (
"encoding/binary"
"fmt"
"hash/crc32"
)
// Frame header constants
const (
	// FrameHeaderSize is the fixed header size:
	// 2 (magic) + 4 (length) + 8 (request ID) + 4 (CRC32C) = 18 bytes.
	FrameHeaderSize = 18
	// Magic bytes ("PY") that identify valid frames on the wire.
	MagicByte1 = 0x50 // 'P'
	MagicByte2 = 0x59 // 'Y'
)
// FrameHeader represents the enhanced frame header that precedes every
// payload. All multi-byte fields are encoded big-endian on the wire.
type FrameHeader struct {
	Magic     [2]byte // Magic bytes for frame validation ('P','Y')
	Length    uint32  // Total frame length (including header)
	RequestID uint64  // Request ID for multiplexing
	CRC32C    uint32  // CRC32C checksum of the payload
}

// Frame represents a complete frame with header and payload.
type Frame struct {
	Header  FrameHeader
	Payload []byte
}

// crc32cTable is the precomputed Castagnoli-polynomial table shared by all
// CRC32C computations in this package.
var crc32cTable = crc32.MakeTable(crc32.Castagnoli)
// NewFrame builds a frame for the given request ID, computing the payload's
// CRC32C checksum and the total on-wire length up front.
func NewFrame(requestID uint64, payload []byte) *Frame {
	header := FrameHeader{
		Magic:     [2]byte{MagicByte1, MagicByte2},
		Length:    uint32(FrameHeaderSize + len(payload)),
		RequestID: requestID,
		CRC32C:    crc32.Checksum(payload, crc32cTable),
	}
	return &Frame{Header: header, Payload: payload}
}
// Marshal serializes the frame into its wire representation:
// [2B magic][4B length][8B request ID][4B CRC32C][payload], big-endian.
func (f *Frame) Marshal() []byte {
	out := make([]byte, f.Header.Length)
	out[0], out[1] = f.Header.Magic[0], f.Header.Magic[1]
	binary.BigEndian.PutUint32(out[2:6], f.Header.Length)
	binary.BigEndian.PutUint64(out[6:14], f.Header.RequestID)
	binary.BigEndian.PutUint32(out[14:18], f.Header.CRC32C)
	// copy is a no-op for an empty payload.
	copy(out[FrameHeaderSize:], f.Payload)
	return out
}
// UnmarshalFrame parses and validates a complete frame from data, checking
// in order: minimum size, magic bytes, declared length versus actual size,
// and payload CRC32C.
func UnmarshalFrame(data []byte) (*Frame, error) {
	if len(data) < FrameHeaderSize {
		return nil, fmt.Errorf("frame too short: %d bytes", len(data))
	}
	if data[0] != MagicByte1 || data[1] != MagicByte2 {
		return nil, fmt.Errorf("invalid magic bytes: %02x%02x", data[0], data[1])
	}
	hdr := FrameHeader{
		Magic:     [2]byte{data[0], data[1]},
		Length:    binary.BigEndian.Uint32(data[2:6]),
		RequestID: binary.BigEndian.Uint64(data[6:14]),
		CRC32C:    binary.BigEndian.Uint32(data[14:18]),
	}
	if int(hdr.Length) != len(data) {
		return nil, fmt.Errorf("frame length mismatch: header says %d, got %d", hdr.Length, len(data))
	}
	body := data[FrameHeaderSize:]
	if got := crc32.Checksum(body, crc32cTable); got != hdr.CRC32C {
		return nil, fmt.Errorf("CRC32C mismatch: expected %08x, got %08x", hdr.CRC32C, got)
	}
	return &Frame{Header: hdr, Payload: body}, nil
}
// ValidateChecksum reports whether the stored CRC32C matches the payload.
func (f *Frame) ValidateChecksum() bool {
	expected := f.Header.CRC32C
	return expected == crc32.Checksum(f.Payload, crc32cTable)
}

// UpdateChecksum recomputes the total length and CRC32C after the payload
// has been modified, keeping the header consistent with the payload.
func (f *Frame) UpdateChecksum() {
	f.Header.Length = uint32(FrameHeaderSize + len(f.Payload))
	f.Header.CRC32C = crc32.Checksum(f.Payload, crc32cTable)
}
// Package framing implements the 4-byte length prefixed framing protocol
// for reliable message transmission over Unix Domain Sockets.
package framing
import (
"encoding/binary"
"fmt"
"io"
)
const (
	// DefaultMaxFrameSize is the default maximum frame size (10MB)
	DefaultMaxFrameSize = 10 * 1024 * 1024
)

// Framer handles framing of messages over a stream. It supports a classic
// mode (4-byte length prefix only) and an enhanced mode that adds a request
// ID and CRC32C checksum per frame.
type Framer struct {
	rw           io.ReadWriter // underlying stream (typically a Unix domain socket)
	maxFrameSize int           // upper bound enforced on both reads and writes
	// Enhanced mode enables request ID and CRC32C
	enhancedMode bool
}
// NewFramer creates a classic-mode framer with the default max frame size.
func NewFramer(rw io.ReadWriter) *Framer {
	return NewFramerWithMaxSize(rw, DefaultMaxFrameSize)
}

// NewFramerWithMaxSize creates a classic-mode framer with a caller-chosen
// maximum frame size.
func NewFramerWithMaxSize(rw io.ReadWriter, maxSize int) *Framer {
	f := &Framer{rw: rw}
	f.maxFrameSize = maxSize
	f.enhancedMode = false
	return f
}

// NewEnhancedFramer creates a framer that reads and writes enhanced frames
// carrying a request ID and CRC32C checksum.
func NewEnhancedFramer(rw io.ReadWriter) *Framer {
	return &Framer{
		rw:           rw,
		maxFrameSize: DefaultMaxFrameSize,
		enhancedMode: true,
	}
}
// WriteMessage writes a framed message.
// Frame format: [4 bytes length (big-endian)] [message bytes]
//
// The length prefix and payload are assembled into one buffer and written
// with a single Write call. The previous two-call version always cost two
// syscalls per message and could interleave one message's header with
// another's body if multiple goroutines shared the framer.
func (f *Framer) WriteMessage(data []byte) error {
	if len(data) > f.maxFrameSize {
		return fmt.Errorf("message size %d exceeds max frame size %d", len(data), f.maxFrameSize)
	}
	// Assemble [length prefix | payload] contiguously.
	buf := make([]byte, 4+len(data))
	binary.BigEndian.PutUint32(buf[:4], uint32(len(data)))
	copy(buf[4:], data)
	if _, err := f.rw.Write(buf); err != nil {
		return fmt.Errorf("failed to write frame: %w", err)
	}
	return nil
}
// WriteFrame writes an enhanced frame (request ID + CRC32C). In classic mode
// it degrades to a plain length-prefixed write of the payload only.
func (f *Framer) WriteFrame(frame *Frame) error {
	if !f.enhancedMode {
		return f.WriteMessage(frame.Payload)
	}
	if size := len(frame.Payload); size > f.maxFrameSize {
		return fmt.Errorf("payload size %d exceeds max frame size %d", size, f.maxFrameSize)
	}
	// Marshal header and payload into one contiguous buffer and write it.
	if _, err := f.rw.Write(frame.Marshal()); err != nil {
		return fmt.Errorf("failed to write frame: %w", err)
	}
	return nil
}
// ReadMessage reads one length-prefixed message, enforcing the configured
// maximum frame size. A clean EOF on the length prefix is returned as io.EOF
// so callers can distinguish connection close from a truncated frame.
func (f *Framer) ReadMessage() ([]byte, error) {
	var lengthBuf [4]byte
	if _, err := io.ReadFull(f.rw, lengthBuf[:]); err != nil {
		if err == io.EOF {
			return nil, io.EOF
		}
		return nil, fmt.Errorf("failed to read frame length: %w", err)
	}
	size := binary.BigEndian.Uint32(lengthBuf[:])
	if int(size) > f.maxFrameSize {
		return nil, fmt.Errorf("frame size %d exceeds max frame size %d", size, f.maxFrameSize)
	}
	payload := make([]byte, size)
	if _, err := io.ReadFull(f.rw, payload); err != nil {
		return nil, fmt.Errorf("failed to read frame data: %w", err)
	}
	return payload, nil
}
// ReadFrame reads an enhanced frame with request ID and CRC32C. In classic
// mode it falls back to ReadMessage and wraps the payload in a Frame with a
// zero request ID.
//
// The declared frame length is now validated in both directions before any
// allocation: a length below FrameHeaderSize previously produced a negative
// payload size and made make() panic on a corrupted or malicious header; it
// is rejected with an error, like oversized frames.
func (f *Framer) ReadFrame() (*Frame, error) {
	if !f.enhancedMode {
		// Fall back to simple message read
		data, err := f.ReadMessage()
		if err != nil {
			return nil, err
		}
		// Create a simple frame with no request ID
		return &Frame{
			Payload: data,
		}, nil
	}
	// Peek at magic bytes first
	magicBuf := make([]byte, 2)
	if _, err := io.ReadFull(f.rw, magicBuf); err != nil {
		if err == io.EOF {
			return nil, io.EOF
		}
		return nil, fmt.Errorf("failed to read magic bytes: %w", err)
	}
	// Check magic bytes
	if magicBuf[0] != MagicByte1 || magicBuf[1] != MagicByte2 {
		return nil, fmt.Errorf("invalid magic bytes: %02x%02x", magicBuf[0], magicBuf[1])
	}
	// Read the rest of the header
	headerBuf := make([]byte, FrameHeaderSize-2) // -2 for magic bytes already read
	if _, err := io.ReadFull(f.rw, headerBuf); err != nil {
		return nil, fmt.Errorf("failed to read frame header: %w", err)
	}
	// Parse and validate the declared total length (header included)
	length := binary.BigEndian.Uint32(headerBuf[0:4])
	if int(length) < FrameHeaderSize {
		return nil, fmt.Errorf("frame size %d smaller than header size %d", length, FrameHeaderSize)
	}
	if int(length) > f.maxFrameSize+FrameHeaderSize {
		return nil, fmt.Errorf("frame size %d exceeds max frame size %d", length, f.maxFrameSize)
	}
	// Read payload
	payloadSize := int(length) - FrameHeaderSize
	payload := make([]byte, payloadSize)
	if payloadSize > 0 {
		if _, err := io.ReadFull(f.rw, payload); err != nil {
			return nil, fmt.Errorf("failed to read frame payload: %w", err)
		}
	}
	// Reconstruct complete frame data for unmarshaling
	completeData := make([]byte, length)
	copy(completeData[0:2], magicBuf)
	copy(completeData[2:FrameHeaderSize], headerBuf)
	if payloadSize > 0 {
		copy(completeData[FrameHeaderSize:], payload)
	}
	// Unmarshal and validate (this also verifies the CRC32C)
	return UnmarshalFrame(completeData)
}
// Package protocol defines the message types and communication protocol
// for pyproc worker communication over Unix Domain Sockets.
package protocol
import (
"encoding/json"
"errors"
"fmt"
)
// MessageType defines the type of message being sent
type MessageType string

const (
	// MessageTypeRequest is a regular request message
	MessageTypeRequest MessageType = "request"
	// MessageTypeResponse is a regular response message
	MessageTypeResponse MessageType = "response"
	// MessageTypeCancellation is a cancellation control message
	MessageTypeCancellation MessageType = "cancellation"
)

// Message is the envelope for all messages between Go and Python.
// Payload stays raw JSON so it can be decoded lazily once Type is known.
type Message struct {
	Type    MessageType     `json:"type"`
	Payload json.RawMessage `json:"payload"`
}

// Request represents a request from Go to Python
type Request struct {
	ID      uint64            `json:"id"`
	Method  string            `json:"method"`
	Body    json.RawMessage   `json:"body"` // JSON-encoded method arguments
	Headers map[string]string `json:"headers,omitempty"` // For trace context propagation
}

// Response represents a response from Python to Go
type Response struct {
	ID       uint64          `json:"id"`
	OK       bool            `json:"ok"` // true on success; ErrorMsg is set otherwise
	Body     json.RawMessage `json:"body,omitempty"`
	ErrorMsg string          `json:"error,omitempty"`
}

// CancellationRequest represents a cancellation signal for a specific request
type CancellationRequest struct {
	ID     uint64 `json:"id"`     // Request ID to cancel
	Reason string `json:"reason"` // Reason for cancellation (e.g., "context cancelled", "timeout")
}
// NewRequest creates a request for method, JSON-encoding body into the
// request's raw Body field. It fails only if body cannot be marshaled.
func NewRequest(id uint64, method string, body interface{}) (*Request, error) {
	encoded, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request body: %w", err)
	}
	req := &Request{ID: id, Method: method, Body: encoded}
	return req, nil
}
// NewResponse creates a successful response, JSON-encoding body into the
// response's raw Body field. It fails only if body cannot be marshaled.
func NewResponse(id uint64, body interface{}) (*Response, error) {
	encoded, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal response body: %w", err)
	}
	resp := &Response{ID: id, OK: true, Body: encoded}
	return resp, nil
}
// NewErrorResponse creates a failed response carrying err's message and no body.
func NewErrorResponse(id uint64, err error) *Response {
	resp := Response{ID: id, OK: false}
	resp.ErrorMsg = err.Error()
	return &resp
}
// Marshal serializes the request to JSON
func (r *Request) Marshal() ([]byte, error) {
	return json.Marshal(r)
}

// Unmarshal deserializes the request from JSON
func (r *Request) Unmarshal(data []byte) error {
	return json.Unmarshal(data, r)
}

// UnmarshalBody unmarshals the request body into the given interface.
// It assumes Body holds valid JSON (as produced by NewRequest).
func (r *Request) UnmarshalBody(v interface{}) error {
	return json.Unmarshal(r.Body, v)
}
// Marshal serializes the response to JSON
func (r *Response) Marshal() ([]byte, error) {
	return json.Marshal(r)
}

// Unmarshal deserializes the response from JSON
func (r *Response) Unmarshal(data []byte) error {
	return json.Unmarshal(data, r)
}

// UnmarshalBody unmarshals the response body into the given interface.
// Error responses (and successes without a body) have a nil Body; that case
// is reported explicitly rather than passed to json.Unmarshal.
func (r *Response) UnmarshalBody(v interface{}) error {
	if r.Body == nil {
		return fmt.Errorf("response body is nil")
	}
	return json.Unmarshal(r.Body, v)
}
// Error converts a failed response into a Go error. It returns nil for a
// successful response and a generic error when no message was provided.
func (r *Response) Error() error {
	switch {
	case r.OK:
		return nil
	case r.ErrorMsg == "":
		return fmt.Errorf("unknown error")
	default:
		return errors.New(r.ErrorMsg)
	}
}
// NewCancellationRequest creates a cancellation signal for request id,
// annotated with a human-readable reason.
func NewCancellationRequest(id uint64, reason string) *CancellationRequest {
	c := CancellationRequest{ID: id}
	c.Reason = reason
	return &c
}
// WrapMessage JSON-encodes payload and wraps it in a typed message envelope.
func WrapMessage(msgType MessageType, payload interface{}) (*Message, error) {
	raw, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal payload: %w", err)
	}
	return &Message{Type: msgType, Payload: raw}, nil
}
// UnwrapMessage decodes a message envelope from raw JSON bytes; the payload
// remains raw for the caller to decode according to the message type.
func UnwrapMessage(data []byte) (*Message, error) {
	msg := new(Message)
	if err := json.Unmarshal(data, msg); err != nil {
		return nil, fmt.Errorf("failed to unmarshal message: %w", err)
	}
	return msg, nil
}
package pyproc
import (
"fmt"
"os"
)
// Codec defines the interface for encoding/decoding messages exchanged with
// Python workers.
type Codec interface {
	// Marshal serializes a value to bytes
	Marshal(v interface{}) ([]byte, error)
	// Unmarshal deserializes bytes to a value
	Unmarshal(data []byte, v interface{}) error
	// Name returns the name of the codec
	Name() string
}

// CodecType represents the type of codec to use
type CodecType string

const (
	// CodecJSON uses JSON encoding (default)
	CodecJSON CodecType = "json"
	// CodecMessagePack uses MessagePack encoding
	CodecMessagePack CodecType = "msgpack"
	// CodecProtobuf uses Protocol Buffers encoding
	CodecProtobuf CodecType = "protobuf"
)
// GetJSONCodecType returns the JSON codec implementation being used.
// The PYPROC_JSON_CODEC environment variable, when set, overrides the
// compile-time (build-tag) selection.
func GetJSONCodecType() string {
	if override := os.Getenv("PYPROC_JSON_CODEC"); override != "" {
		return override
	}
	// Report the compile-time selected codec.
	var c JSONCodec
	return c.Name()
}
// NewCodec creates a codec for the given type. An empty type selects JSON,
// the default; protobuf is declared but not yet implemented.
func NewCodec(codecType CodecType) (Codec, error) {
	if codecType == "" || codecType == CodecJSON {
		return &JSONCodec{}, nil
	}
	switch codecType {
	case CodecMessagePack:
		return &MessagePackCodec{}, nil
	case CodecProtobuf:
		// TODO: Implement in Phase 3
		return nil, fmt.Errorf("protobuf codec not yet implemented")
	}
	return nil, fmt.Errorf("unknown codec type: %s", codecType)
}
//go:build !json_goccy && !json_segmentio
package pyproc
import (
"encoding/json"
)
// JSONCodec implements Codec using the standard library's encoding/json.
type JSONCodec struct{}

// Marshal serializes a value to JSON bytes.
func (c *JSONCodec) Marshal(v interface{}) ([]byte, error) {
	data, err := json.Marshal(v)
	return data, err
}

// Unmarshal deserializes JSON bytes into v.
func (c *JSONCodec) Unmarshal(data []byte, v interface{}) error {
	return json.Unmarshal(data, v)
}

// Name identifies this implementation as the stdlib-backed JSON codec.
func (c *JSONCodec) Name() string {
	const name = "json-stdlib"
	return name
}
package pyproc
import (
"github.com/vmihailenco/msgpack/v5"
)
// MessagePackCodec implements Codec using MessagePack encoding via the
// vmihailenco/msgpack library.
type MessagePackCodec struct{}

// Marshal serializes a value to MessagePack bytes.
func (c *MessagePackCodec) Marshal(v interface{}) ([]byte, error) {
	data, err := msgpack.Marshal(v)
	return data, err
}

// Unmarshal deserializes MessagePack bytes into v.
func (c *MessagePackCodec) Unmarshal(data []byte, v interface{}) error {
	return msgpack.Unmarshal(data, v)
}

// Name identifies this codec as MessagePack.
func (c *MessagePackCodec) Name() string {
	const name = "msgpack"
	return name
}
// Package pyproc provides a Go library for calling Python functions
// without CGO, using Unix Domain Sockets for high-performance IPC.
package pyproc
import (
"fmt"
"time"
"github.com/spf13/viper"
)
// Config holds all configuration for pyproc, grouped by concern and mapped
// from the YAML config file / PYPROC_* environment via viper.
type Config struct {
	Pool     PoolConfig     `mapstructure:"pool"`
	Python   PythonConfig   `mapstructure:"python"`
	Socket   SocketConfig   `mapstructure:"socket"`
	Protocol ProtocolConfig `mapstructure:"protocol"`
	Logging  LoggingConfig  `mapstructure:"logging"`
	Metrics  MetricsConfig  `mapstructure:"metrics"`
}

// PoolConfig defines worker pool settings.
// Duration fields are configured as plain numbers and scaled in LoadConfig.
type PoolConfig struct {
	Workers              int           `mapstructure:"workers"`
	MaxInFlight          int           `mapstructure:"max_in_flight"`
	MaxInFlightPerWorker int           `mapstructure:"max_in_flight_per_worker"`
	StartTimeout         time.Duration `mapstructure:"start_timeout"`
	HealthInterval       time.Duration `mapstructure:"health_interval"`
	Restart              RestartConfig `mapstructure:"restart"`
}

// RestartConfig defines the exponential-backoff restart policy for workers.
type RestartConfig struct {
	MaxAttempts    int           `mapstructure:"max_attempts"`
	InitialBackoff time.Duration `mapstructure:"initial_backoff"`
	MaxBackoff     time.Duration `mapstructure:"max_backoff"`
	Multiplier     float64       `mapstructure:"multiplier"`
}

// PythonConfig defines Python runtime settings
type PythonConfig struct {
	Executable   string            `mapstructure:"executable"`
	WorkerScript string            `mapstructure:"worker_script"`
	Env          map[string]string `mapstructure:"env"`
}

// SocketConfig defines Unix domain socket settings
type SocketConfig struct {
	Dir         string `mapstructure:"dir"`
	Prefix      string `mapstructure:"prefix"`
	Permissions uint32 `mapstructure:"permissions"`
}

// ProtocolConfig defines protocol settings
type ProtocolConfig struct {
	MaxFrameSize      int           `mapstructure:"max_frame_size"`
	RequestTimeout    time.Duration `mapstructure:"request_timeout"`
	ConnectionTimeout time.Duration `mapstructure:"connection_timeout"`
}

// LoggingConfig defines logging settings
type LoggingConfig struct {
	Level        string `mapstructure:"level"`
	Format       string `mapstructure:"format"`
	TraceEnabled bool   `mapstructure:"trace_enabled"`
}

// MetricsConfig defines metrics collection settings
type MetricsConfig struct {
	Enabled  bool   `mapstructure:"enabled"`
	Endpoint string `mapstructure:"endpoint"`
	Path     string `mapstructure:"path"`
}
// LoadConfig loads configuration from file and environment.
//
// Resolution order: the explicit configPath if given, otherwise a
// config.yaml searched in ., ./config, and /etc/pyproc; PYPROC_* environment
// variables override file values; built-in defaults (setDefaults) fill the
// rest. A missing config file is not an error.
func LoadConfig(configPath string) (*Config, error) {
	v := viper.New()
	// Set defaults
	setDefaults(v)
	// Set config file
	if configPath != "" {
		v.SetConfigFile(configPath)
	} else {
		v.SetConfigName("config")
		v.SetConfigType("yaml")
		v.AddConfigPath(".")
		v.AddConfigPath("./config")
		v.AddConfigPath("/etc/pyproc")
	}
	// Read environment variables
	v.SetEnvPrefix("PYPROC")
	v.AutomaticEnv()
	// Read config file
	if err := v.ReadInConfig(); err != nil {
		// It's ok if config file doesn't exist, we have defaults
		if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
			return nil, fmt.Errorf("failed to read config: %w", err)
		}
	}
	// Unmarshal config
	var cfg Config
	if err := v.Unmarshal(&cfg); err != nil {
		return nil, fmt.Errorf("failed to unmarshal config: %w", err)
	}
	// Convert duration fields (viper reads them as seconds)
	// NOTE(review): this assumes the config supplies bare numbers (e.g. 30),
	// which Unmarshal decodes as nanoseconds before the scaling below. If a
	// file ever used a duration string such as "30s" that was decoded as a
	// real duration, this multiplication would overscale it — confirm the
	// expected config format before changing.
	cfg.Pool.StartTimeout *= time.Second
	cfg.Pool.HealthInterval *= time.Second
	cfg.Pool.Restart.InitialBackoff *= time.Millisecond
	cfg.Pool.Restart.MaxBackoff *= time.Millisecond
	cfg.Protocol.RequestTimeout *= time.Second
	cfg.Protocol.ConnectionTimeout *= time.Second
	return &cfg, nil
}
// setDefaults registers the built-in default for every configuration key so
// LoadConfig produces a usable Config even with no file and no environment
// overrides. Durations are plain numbers here — seconds for
// timeouts/intervals, milliseconds for restart backoff — and are scaled to
// time.Duration in LoadConfig.
func setDefaults(v *viper.Viper) {
	// Pool defaults
	v.SetDefault("pool.workers", 4)
	v.SetDefault("pool.max_in_flight", 10)
	v.SetDefault("pool.max_in_flight_per_worker", 1)
	v.SetDefault("pool.start_timeout", 30)
	v.SetDefault("pool.health_interval", 30)
	v.SetDefault("pool.restart.max_attempts", 5)
	v.SetDefault("pool.restart.initial_backoff", 1000)
	v.SetDefault("pool.restart.max_backoff", 30000)
	v.SetDefault("pool.restart.multiplier", 2.0)
	// Python defaults
	v.SetDefault("python.executable", "python3")
	v.SetDefault("python.worker_script", "./worker.py")
	v.SetDefault("python.env", map[string]string{
		"PYTHONUNBUFFERED": "1",
	})
	// Socket defaults
	v.SetDefault("socket.dir", "/tmp")
	v.SetDefault("socket.prefix", "pyproc")
	v.SetDefault("socket.permissions", 0600)
	// Protocol defaults
	v.SetDefault("protocol.max_frame_size", 10485760) // 10MB
	v.SetDefault("protocol.request_timeout", 60)
	v.SetDefault("protocol.connection_timeout", 5)
	// Logging defaults
	v.SetDefault("logging.level", "info")
	v.SetDefault("logging.format", "json")
	v.SetDefault("logging.trace_enabled", true)
	// Metrics defaults
	v.SetDefault("metrics.enabled", true)
	v.SetDefault("metrics.endpoint", ":9090")
	v.SetDefault("metrics.path", "/metrics")
}
package pyproc
import (
"context"
"fmt"
"net"
"time"
)
// defaultSleepDuration is the pause between connection attempts.
const defaultSleepDuration = 100 * time.Millisecond

// ConnectToWorker connects to a worker via Unix domain socket, retrying
// every defaultSleepDuration until the dial succeeds or timeout elapses.
//
// On timeout the returned error now wraps the most recent dial failure so
// callers can see why the socket was unreachable (the previous version
// discarded the underlying error).
func ConnectToWorker(socketPath string, timeout time.Duration) (net.Conn, error) {
	// Set connection timeout
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	for {
		conn, dialErr := net.Dial("unix", socketPath)
		if dialErr == nil {
			return conn, nil
		}
		// Sleep before retry, or give up if the timeout expires during the sleep.
		if err := sleepWithCtx(ctx, defaultSleepDuration); err != nil {
			return nil, fmt.Errorf("failed to connect to worker at %s after %v: %w", socketPath, timeout, dialErr)
		}
	}
}
func sleepWithCtx(ctx context.Context, d time.Duration) error {
// Wait a bit before retrying
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
package pyproc
import (
"encoding/json"
"fmt"
"net/http"
)
// LivenessHandler returns an http.Handler for Kubernetes liveness probes.
// It always responds HTTP 200 with {"status":"ok"}: if this handler can run
// at all, the process is alive. The pool argument is intentionally unused.
func LivenessHandler(_ *Pool) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"status":"ok"}`)) // best-effort write
	})
}
// ReadinessHandler returns an http.Handler for Kubernetes readiness probes.
// It responds HTTP 200 when at least one worker is healthy and the pool is
// not shutting down; otherwise HTTP 503. Either way the JSON body reports
// the healthy and total worker counts.
func ReadinessHandler(pool *Pool) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		health := pool.Health()
		status, code := "not_ready", http.StatusServiceUnavailable
		if health.HealthyWorkers > 0 && !pool.shutdown.Load() {
			status, code = "ready", http.StatusOK
		}
		w.WriteHeader(code)
		body, _ := json.Marshal(map[string]interface{}{
			"status":          status,
			"healthy_workers": health.HealthyWorkers,
			"total_workers":   health.TotalWorkers,
		})
		_, _ = w.Write(body) // best-effort write
	})
}
// StartupHandler returns an http.Handler for Kubernetes startup probes.
// It responds HTTP 200 with {"status":"started"} once the pool has completed
// its Start() sequence; before that it responds HTTP 503 with
// {"status":"not_started"}.
func StartupHandler(pool *Pool) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		status, code := "not_started", http.StatusServiceUnavailable
		if pool.started.Load() {
			status, code = "started", http.StatusOK
		}
		w.WriteHeader(code)
		_, _ = fmt.Fprintf(w, `{"status":%q}`, status) // best-effort write
	})
}
package pyproc
import (
"context"
"log/slog"
"os"
"sync/atomic"
)
// traceIDKey is the unexported context-key type for trace IDs; using a
// dedicated struct type avoids collisions with context keys from other
// packages.
type traceIDKey struct{}

// traceIDCounter is a process-wide monotonic counter used to mint unique
// trace IDs; the first ID handed out by WithTraceID is 1.
var traceIDCounter atomic.Uint64

// Logger wraps slog.Logger with trace ID support. The embedded *slog.Logger
// provides the full slog API; the *Context methods on Logger additionally
// prepend the context's trace ID as a "trace_id" attribute.
type Logger struct {
	*slog.Logger
	// traceEnabled controls whether the *Context methods look up and log
	// the context's trace ID.
	traceEnabled bool
}
// NewLogger builds a *Logger that writes to stdout. cfg.Format selects the
// handler ("json" uses the JSON handler, anything else the text handler),
// cfg.Level sets the minimum log level, and cfg.TraceEnabled turns on
// trace-ID propagation in the *Context logging methods.
func NewLogger(cfg LoggingConfig) *Logger {
	opts := &slog.HandlerOptions{Level: parseLogLevel(cfg.Level)}
	var handler slog.Handler
	if cfg.Format == "json" {
		handler = slog.NewJSONHandler(os.Stdout, opts)
	} else {
		handler = slog.NewTextHandler(os.Stdout, opts)
	}
	return &Logger{
		Logger:       slog.New(handler),
		traceEnabled: cfg.TraceEnabled,
	}
}
// WithTraceID returns a copy of ctx carrying a freshly minted trace ID.
// IDs come from a process-wide counter and are therefore unique within
// this process.
func WithTraceID(ctx context.Context) context.Context {
	next := traceIDCounter.Add(1)
	return context.WithValue(ctx, traceIDKey{}, next)
}

// GetTraceID returns the trace ID stored in ctx by WithTraceID, and whether
// one was present.
func GetTraceID(ctx context.Context) (uint64, bool) {
	if v := ctx.Value(traceIDKey{}); v != nil {
		id, ok := v.(uint64)
		return id, ok
	}
	return 0, false
}
// withTraceID returns args with the context's trace ID prepended as a
// "trace_id" key/value pair. It returns args unchanged when tracing is
// disabled or the context carries no trace ID. This shared helper replaces
// the prologue that was previously duplicated across all four *Context
// logging methods.
func (l *Logger) withTraceID(ctx context.Context, args []any) []any {
	if l.traceEnabled {
		if traceID, ok := GetTraceID(ctx); ok {
			return append([]any{"trace_id", traceID}, args...)
		}
	}
	return args
}

// InfoContext logs an info message with trace ID if enabled.
func (l *Logger) InfoContext(ctx context.Context, msg string, args ...any) {
	l.Logger.InfoContext(ctx, msg, l.withTraceID(ctx, args)...)
}

// ErrorContext logs an error message with trace ID if enabled.
func (l *Logger) ErrorContext(ctx context.Context, msg string, args ...any) {
	l.Logger.ErrorContext(ctx, msg, l.withTraceID(ctx, args)...)
}

// DebugContext logs a debug message with trace ID if enabled.
func (l *Logger) DebugContext(ctx context.Context, msg string, args ...any) {
	l.Logger.DebugContext(ctx, msg, l.withTraceID(ctx, args)...)
}

// WarnContext logs a warning message with trace ID if enabled.
func (l *Logger) WarnContext(ctx context.Context, msg string, args ...any) {
	l.Logger.WarnContext(ctx, msg, l.withTraceID(ctx, args)...)
}
// withAttr returns a derived logger carrying one extra key/value attribute
// while preserving the trace-enabled setting. This shared helper replaces
// the struct-literal boilerplate previously duplicated in WithWorker,
// WithMethod, and WithRequestID.
func (l *Logger) withAttr(key string, value any) *Logger {
	return &Logger{
		Logger:       l.With(key, value),
		traceEnabled: l.traceEnabled,
	}
}

// WithWorker returns a logger with worker ID attached.
func (l *Logger) WithWorker(workerID string) *Logger {
	return l.withAttr("worker_id", workerID)
}

// WithMethod returns a logger with method name attached.
func (l *Logger) WithMethod(method string) *Logger {
	return l.withAttr("method", method)
}

// WithRequestID returns a logger with request ID attached.
func (l *Logger) WithRequestID(requestID uint64) *Logger {
	return l.withAttr("request_id", requestID)
}
func parseLogLevel(level string) slog.Level {
switch level {
case "debug":
return slog.LevelDebug
case "info":
return slog.LevelInfo
case "warn":
return slog.LevelWarn
case "error":
return slog.LevelError
default:
return slog.LevelInfo
}
}
package pyproc
import (
"fmt"
"io"
"net/http"
)
// MetricsHandler returns an http.Handler serving Prometheus text exposition
// format built from the given PoolWithMetrics. The format is hand-written;
// no external Prometheus dependency is required.
func MetricsHandler(pool *PoolWithMetrics) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
		writeMetrics(w, pool.GetMetrics(), pool.Health())
	})
}
// writeMetrics renders the metrics snapshot and pool health to w in
// Prometheus text exposition format. Writes are best-effort; errors from
// the underlying writer are intentionally ignored.
func writeMetrics(w io.Writer, snap MetricsSnapshot, health HealthStatus) {
	emit := func(format string, args ...interface{}) {
		fmt.Fprintf(w, format, args...) //nolint:errcheck
	}
	// Request counters
	emit("# HELP pyproc_requests_total Total number of requests.\n")
	emit("# TYPE pyproc_requests_total counter\n")
	emit("pyproc_requests_total{status=\"success\"} %d\n", snap.RequestsSucceeded)
	emit("pyproc_requests_total{status=\"failed\"} %d\n", snap.RequestsFailed)
	emit("pyproc_requests_total{status=\"timeout\"} %d\n", snap.RequestsTimeout)
	// Latency percentiles
	emit("# HELP pyproc_request_duration_seconds Request latency percentiles in seconds.\n")
	emit("# TYPE pyproc_request_duration_seconds gauge\n")
	emit("pyproc_request_duration_seconds{quantile=\"0.5\"} %f\n", snap.LatencyP50.Seconds())
	emit("pyproc_request_duration_seconds{quantile=\"0.95\"} %f\n", snap.LatencyP95.Seconds())
	emit("pyproc_request_duration_seconds{quantile=\"0.99\"} %f\n", snap.LatencyP99.Seconds())
	// Worker gauges
	emit("# HELP pyproc_workers_total Total number of workers.\n")
	emit("# TYPE pyproc_workers_total gauge\n")
	emit("pyproc_workers_total %d\n", health.TotalWorkers)
	emit("# HELP pyproc_workers_healthy Number of healthy workers.\n")
	emit("# TYPE pyproc_workers_healthy gauge\n")
	emit("pyproc_workers_healthy %d\n", health.HealthyWorkers)
	// Inflight
	emit("# HELP pyproc_inflight_requests Number of in-flight requests.\n")
	emit("# TYPE pyproc_inflight_requests gauge\n")
	emit("pyproc_inflight_requests %d\n", snap.QueueDepth)
	// Worker restarts
	emit("# HELP pyproc_worker_restarts_total Total worker restarts.\n")
	emit("# TYPE pyproc_worker_restarts_total counter\n")
	emit("pyproc_worker_restarts_total %d\n", snap.WorkerRestarts)
}
package pyproc
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/YuminosukeSato/pyproc/internal/framing"
"github.com/YuminosukeSato/pyproc/internal/protocol"
"github.com/YuminosukeSato/pyproc/pkg/pyproc/telemetry"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// PoolOptions provides additional options for creating a pool.
type PoolOptions struct {
	Config       PoolConfig   // Base pool configuration
	WorkerConfig WorkerConfig // Configuration for each worker
	// ExternalMode, when true, tells the pool to connect to pre-existing
	// worker processes (e.g. Kubernetes sidecar containers) instead of
	// spawning child processes. ExternalSocketPaths must list the UDS paths.
	ExternalMode        bool
	ExternalSocketPaths []string
	// ExternalMaxRetries controls the number of connection retry attempts
	// for external workers. If zero, defaults to 10.
	ExternalMaxRetries int
	// ExternalRetryInterval is the initial retry interval for external
	// workers. Each retry doubles the interval. If zero, defaults to 500ms.
	ExternalRetryInterval time.Duration
}

// workerHandle abstracts the operations the pool needs from a worker, so
// that spawned and external worker implementations can be managed uniformly.
type workerHandle interface {
	// Start launches or connects to the worker.
	Start(context.Context) error
	// Stop terminates or disconnects from the worker.
	Stop() error
	// IsHealthy reports whether the worker can currently serve requests.
	IsHealthy(context.Context) bool
	// GetSocketPath returns the Unix domain socket path for the worker.
	GetSocketPath() string
}
// Pool manages multiple Python workers with load balancing.
type Pool struct {
	opts    PoolOptions
	logger  *Logger
	workers []*poolWorker
	// nextIdx is the round-robin cursor used by acquireWorker.
	nextIdx atomic.Uint64
	// shutdown is set once by Shutdown; Call refuses new work after it.
	shutdown atomic.Bool
	// started is set once Start completes; used by the startup probe.
	started atomic.Bool
	// activeCallsWG counts in-flight Call invocations; callsMu serializes
	// Add against Shutdown's Wait so no Add can race past the Wait.
	activeCallsWG sync.WaitGroup
	callsMu       sync.Mutex
	wg            sync.WaitGroup
	// Backpressure control
	semaphore       chan struct{} // global in-flight limit (MaxInFlight)
	workerAvailable chan struct{} // signaled when a per-worker slot frees up
	shutdownCh      chan struct{} // closed on Shutdown to unblock waiters
	// Health monitoring
	healthMu     sync.RWMutex
	healthStatus HealthStatus
	healthCancel context.CancelFunc
	// Request tracking for cancellation
	activeRequests   map[uint64]*activeRequest
	activeRequestsMu sync.RWMutex
	// Observability
	tracer trace.Tracer
}

// activeRequest tracks an in-flight request for cancellation support.
type activeRequest struct {
	id        uint64
	workerIdx int
	conn      net.Conn
	// cancelOnce ensures the cancellation message/connection close runs
	// at most once per request.
	cancelOnce sync.Once
	// done is closed when the request completes, stopping the monitor.
	done chan struct{}
}

// poolWorker wraps a Worker with connection pooling.
type poolWorker struct {
	worker workerHandle
	// connPool holds idle connections for reuse (capacity
	// MaxInFlightPerWorker).
	connPool chan net.Conn
	// inflightGate limits concurrent requests per worker.
	inflightGate chan struct{}
	// requestID mints per-worker request IDs.
	requestID atomic.Uint64
	// healthy reflects the latest health-check result.
	healthy atomic.Bool
}

// HealthStatus represents the health of the pool.
type HealthStatus struct {
	TotalWorkers   int
	HealthyWorkers int
	LastCheck      time.Time
}
// NewPool creates a new worker pool. When opts.ExternalMode is true the pool
// connects to pre-existing worker processes at the paths given in
// opts.ExternalSocketPaths instead of spawning child processes. Zero-valued
// tuning fields receive defaults (MaxInFlight 10, MaxInFlightPerWorker 1,
// HealthInterval 30s), and a JSON logger at info level is used when logger
// is nil. opts.Config.Workers must be positive.
func NewPool(opts PoolOptions, logger *Logger) (*Pool, error) {
	if opts.ExternalMode {
		return newExternalPool(opts, logger)
	}
	if opts.Config.Workers <= 0 {
		return nil, errors.New("workers must be > 0")
	}
	// Apply defaults for unset tuning knobs.
	if opts.Config.MaxInFlight <= 0 {
		opts.Config.MaxInFlight = 10
	}
	if opts.Config.MaxInFlightPerWorker <= 0 {
		opts.Config.MaxInFlightPerWorker = 1
	}
	if opts.Config.HealthInterval <= 0 {
		opts.Config.HealthInterval = 30 * time.Second
	}
	if logger == nil {
		logger = NewLogger(LoggingConfig{Level: "info", Format: "json"})
	}
	p := &Pool{
		opts:            opts,
		logger:          logger,
		workers:         make([]*poolWorker, opts.Config.Workers),
		semaphore:       make(chan struct{}, opts.Config.MaxInFlight),
		workerAvailable: make(chan struct{}, opts.Config.Workers*opts.Config.MaxInFlightPerWorker),
		shutdownCh:      make(chan struct{}),
		activeRequests:  make(map[uint64]*activeRequest),
	}
	// Create one worker per slot, each with its own socket path and ID.
	for i := range p.workers {
		cfg := opts.WorkerConfig
		cfg.ID = fmt.Sprintf("worker-%d", i)
		cfg.SocketPath = fmt.Sprintf("%s-%d", opts.WorkerConfig.SocketPath, i)
		if cfg.StartTimeout == 0 {
			cfg.StartTimeout = 5 * time.Second
		}
		p.workers[i] = &poolWorker{
			worker:       NewWorker(cfg, logger),
			connPool:     make(chan net.Conn, opts.Config.MaxInFlightPerWorker),
			inflightGate: make(chan struct{}, opts.Config.MaxInFlightPerWorker),
		}
	}
	return p, nil
}
// newExternalPool creates a pool whose workers are ExternalWorker instances
// connected to pre-existing processes at opts.ExternalSocketPaths. The
// worker count is derived from the number of socket paths; other zero-valued
// tuning fields receive the same defaults as NewPool.
func newExternalPool(opts PoolOptions, logger *Logger) (*Pool, error) {
	if len(opts.ExternalSocketPaths) == 0 {
		return nil, errors.New("ExternalSocketPaths must not be empty in ExternalMode")
	}
	n := len(opts.ExternalSocketPaths)
	opts.Config.Workers = n
	// Apply defaults for unset tuning knobs.
	if opts.Config.MaxInFlight <= 0 {
		opts.Config.MaxInFlight = 10
	}
	if opts.Config.MaxInFlightPerWorker <= 0 {
		opts.Config.MaxInFlightPerWorker = 1
	}
	if opts.Config.HealthInterval <= 0 {
		opts.Config.HealthInterval = 30 * time.Second
	}
	if logger == nil {
		logger = NewLogger(LoggingConfig{Level: "info", Format: "json"})
	}
	p := &Pool{
		opts:            opts,
		logger:          logger,
		workers:         make([]*poolWorker, n),
		semaphore:       make(chan struct{}, opts.Config.MaxInFlight),
		workerAvailable: make(chan struct{}, n*opts.Config.MaxInFlightPerWorker),
		shutdownCh:      make(chan struct{}),
		activeRequests:  make(map[uint64]*activeRequest),
	}
	for i, sockPath := range opts.ExternalSocketPaths {
		p.workers[i] = &poolWorker{
			worker: NewExternalWorkerWithOptions(ExternalWorkerOptions{
				SocketPath:     sockPath,
				ConnectTimeout: opts.WorkerConfig.StartTimeout,
				MaxRetries:     opts.ExternalMaxRetries,
				RetryInterval:  opts.ExternalRetryInterval,
			}),
			connPool:     make(chan net.Conn, opts.Config.MaxInFlightPerWorker),
			inflightGate: make(chan struct{}, opts.Config.MaxInFlightPerWorker),
		}
	}
	return p, nil
}
// WithTracer sets the OpenTelemetry tracer for the pool and returns the
// pool to allow chaining. This enables distributed tracing for all
// Pool.Call() operations. If not set, no tracing is performed (zero
// overhead). Not safe to call concurrently with Call.
func (p *Pool) WithTracer(tracer trace.Tracer) *Pool {
	p.tracer = tracer
	return p
}
// Start starts all workers in the pool, pre-populates one connection per
// worker, and launches the background health monitor. If any worker fails
// to start, every worker started so far is stopped and the error is
// returned. On success, p.started is set so the startup probe reports ready.
func (p *Pool) Start(ctx context.Context) error {
	p.logger.Info("starting worker pool", "workers", p.opts.Config.Workers)
	// Start all workers; roll back already-started workers on failure.
	for i, pw := range p.workers {
		if err := pw.worker.Start(ctx); err != nil {
			// Stop already started workers
			for j := 0; j < i; j++ {
				_ = p.workers[j].worker.Stop()
			}
			return fmt.Errorf("failed to start worker %d: %w", i, err)
		}
		// Don't mark as healthy until first health check succeeds
		// pw.healthy.Store(true) - removed, will be set by health check
	}
	// Pre-populate connection pools (synchronous to avoid race conditions)
	// We do minimal pre-population to reduce startup latency while avoiding complexity
	for _, pw := range p.workers {
		// Pre-populate just one connection per worker for faster first call
		conn, err := p.connect(pw.worker.GetSocketPath())
		if err != nil {
			// Best-effort: a failed warm-up connection is not fatal; Call
			// will dial on demand.
			p.logger.Debug("failed to pre-populate connection", "error", err)
			continue
		}
		select {
		case pw.connPool <- conn:
		default:
			// Pool is full (shouldn't happen with MaxInFlightPerWorker=1), close connection
			if err := conn.Close(); err != nil {
				p.logger.Error("failed to close connection", "error", err)
			}
		}
	}
	// Give workers time to stabilize before health check
	time.Sleep(100 * time.Millisecond)
	// Start health monitoring
	healthCtx, cancel := context.WithCancel(context.Background())
	p.healthCancel = cancel
	p.wg.Add(1)
	go p.healthMonitor(healthCtx)
	// Initial health check so Call can pick a healthy worker immediately.
	p.updateHealthStatus()
	p.started.Store(true)
	p.logger.Info("worker pool started successfully")
	return nil
}
// Call invokes a method on one of the workers using round-robin selection.
// input is JSON-encoded into the request and the worker's response body is
// decoded into output. It blocks until the response arrives, ctx is
// cancelled, or the pool shuts down. Backpressure is enforced by a global
// semaphore (MaxInFlight) plus a per-worker gate (MaxInFlightPerWorker).
func (p *Pool) Call(ctx context.Context, method string, input interface{}, output interface{}) error {
	// Start tracing span if tracer is configured
	var span trace.Span
	if p.tracer != nil {
		ctx, span = p.tracer.Start(ctx, "pyproc.Pool.Call",
			trace.WithSpanKind(trace.SpanKindClient),
			trace.WithAttributes(
				attribute.String("pyproc.method", method),
			),
		)
		defer func() {
			// Span will be ended here, status is set in error paths
			span.End()
		}()
	}
	// callsMu pairs with Shutdown: it guarantees the WaitGroup Add cannot
	// race past Shutdown's Wait.
	p.callsMu.Lock()
	if p.shutdown.Load() {
		p.callsMu.Unlock()
		err := errors.New("pool is shut down")
		if span != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, err.Error())
		}
		return err
	}
	p.activeCallsWG.Add(1)
	p.callsMu.Unlock()
	defer p.activeCallsWG.Done()
	// Acquire semaphore for backpressure
	select {
	case p.semaphore <- struct{}{}:
		defer func() { <-p.semaphore }()
	case <-ctx.Done():
		err := ctx.Err()
		if span != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, err.Error())
		}
		return err
	case <-p.shutdownCh:
		err := errors.New("pool is shut down")
		if span != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, err.Error())
		}
		return err
	}
	// Pick a healthy worker and reserve one of its in-flight slots.
	pw, workerIdx, err := p.acquireWorker(ctx)
	if err != nil {
		if span != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, err.Error())
		}
		return err
	}
	defer func() {
		// Release the per-worker slot and wake one waiter in acquireWorker.
		<-pw.inflightGate
		p.signalWorkerAvailable()
	}()
	// Add worker ID to span attributes
	if span != nil {
		span.SetAttributes(attribute.Int("pyproc.worker_id", workerIdx))
	}
	// Get connection from pool
	var conn net.Conn
	select {
	case pooledConn, ok := <-pw.connPool:
		if !ok {
			err := errors.New("connection pool closed")
			if span != nil {
				span.RecordError(err)
				span.SetStatus(codes.Error, err.Error())
			}
			return err
		}
		conn = pooledConn
	default:
		// Create new connection if pool is empty
		var err error
		conn, err = p.connect(pw.worker.GetSocketPath())
		if err != nil {
			if span != nil {
				span.RecordError(err)
				span.SetStatus(codes.Error, "failed to connect")
			}
			return fmt.Errorf("failed to connect: %w", err)
		}
	}
	// Generate request ID (unique per worker)
	reqID := pw.requestID.Add(1)
	// Track active request for cancellation
	activeReq := &activeRequest{
		id:        reqID,
		workerIdx: workerIdx,
		conn:      conn,
		done:      make(chan struct{}),
	}
	p.activeRequestsMu.Lock()
	p.activeRequests[reqID] = activeReq
	p.activeRequestsMu.Unlock()
	// Monitor context for cancellation
	go p.monitorCancellation(ctx, activeReq)
	// Flag to track if connection was closed due to cancellation or I/O
	// failure; a closed connection must never be returned to the pool.
	connClosed := false
	// Clean up active request and return connection on exit
	defer func() {
		close(activeReq.done)
		p.activeRequestsMu.Lock()
		delete(p.activeRequests, reqID)
		p.activeRequestsMu.Unlock()
		// Return connection to pool only if not closed
		if !connClosed && !p.shutdown.Load() {
			select {
			case pw.connPool <- conn:
			default:
				_ = conn.Close()
			}
		} else if !connClosed {
			_ = conn.Close()
		}
	}()
	// Send request
	req, err := protocol.NewRequest(reqID, method, input)
	if err != nil {
		if span != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, "failed to create request")
		}
		return err
	}
	// Inject trace context into request headers if tracing is enabled
	if span != nil {
		if req.Headers == nil {
			req.Headers = make(map[string]string)
		}
		telemetry.InjectTraceContext(ctx, req.Headers)
	}
	// For now, send in legacy format for backward compatibility
	// TODO: Switch to wrapped format once Python side is fully tested
	framer := framing.NewFramer(conn)
	reqData, err := req.Marshal()
	if err != nil {
		if span != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, "failed to marshal request")
		}
		return err
	}
	if err := framer.WriteMessage(reqData); err != nil {
		connClosed = true
		_ = conn.Close() // Connection is bad, don't return to pool
		if span != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, "failed to send request")
		}
		return err
	}
	// Read response
	respData, err := framer.ReadMessage()
	if err != nil {
		// Check if error is due to cancellation (monitorCancellation closes
		// the connection when ctx is cancelled, which surfaces as a read error)
		select {
		case <-ctx.Done():
			connClosed = true
			err := ctx.Err()
			if span != nil {
				span.RecordError(err)
				span.SetStatus(codes.Error, err.Error())
			}
			return err
		default:
			connClosed = true
			_ = conn.Close() // Connection is bad, don't return to pool
			if span != nil {
				span.RecordError(err)
				span.SetStatus(codes.Error, "failed to read response")
			}
			return err
		}
	}
	// For now, handle legacy format
	// TODO: Switch to wrapped format once Python side is fully tested
	var resp protocol.Response
	if err := resp.Unmarshal(respData); err != nil {
		err := fmt.Errorf("failed to unmarshal response: %w", err)
		if span != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, "unmarshal failed")
		}
		return err
	}
	if !resp.OK {
		err := resp.Error()
		if span != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, "python worker error")
		}
		return err
	}
	// Handle special methods for testing
	if method == "echo_worker_id" {
		// Add worker ID to response
		var result map[string]interface{}
		if err := json.Unmarshal(resp.Body, &result); err == nil {
			result["worker_id"] = float64(workerIdx)
			modifiedBody, _ := json.Marshal(result)
			resp.Body = modifiedBody
		}
	}
	err = resp.UnmarshalBody(output)
	if err != nil {
		if span != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, "failed to unmarshal output")
		}
		return err
	}
	// Success: set span status to OK
	if span != nil {
		span.SetStatus(codes.Ok, "")
	}
	return nil
}
// Shutdown gracefully shuts down all workers. It is idempotent: only the
// first call performs the shutdown; later calls return nil immediately.
// The sequence is: refuse new calls, cancel health monitoring, wait for
// in-flight calls, drain and close connection pools, stop workers, and
// finally wait for background goroutines.
func (p *Pool) Shutdown(_ context.Context) error {
	if !p.shutdown.CompareAndSwap(false, true) {
		return nil // Already shutting down
	}
	close(p.shutdownCh)
	p.logger.Info("shutting down worker pool")
	// Cancel health monitoring
	if p.healthCancel != nil {
		p.healthCancel()
	}
	// Wait for in-flight calls to complete before closing pools.
	// The lock ensures all in-progress Call() goroutines have finished
	// their activeCallsWG.Add(1) before we start waiting.
	p.callsMu.Lock()
	p.activeCallsWG.Wait()
	p.callsMu.Unlock()
	// Close all connection pools and drain any idle connections.
	for _, pw := range p.workers {
		close(pw.connPool)
		for conn := range pw.connPool {
			_ = conn.Close()
		}
	}
	// Stop all workers, collecting per-worker failures.
	var errs []error
	for i, pw := range p.workers {
		if err := pw.worker.Stop(); err != nil {
			errs = append(errs, fmt.Errorf("worker %d: %w", i, err))
		}
	}
	// Wait for background goroutines (health monitor).
	p.wg.Wait()
	if len(errs) > 0 {
		// errors.Join keeps each wrapped error reachable via errors.Is/As,
		// unlike flattening them into a single formatted string.
		return fmt.Errorf("shutdown errors: %w", errors.Join(errs...))
	}
	p.logger.Info("worker pool shut down successfully")
	return nil
}
// signalWorkerAvailable performs a non-blocking send on workerAvailable to
// wake one goroutine waiting in acquireWorker. Dropping the signal when the
// channel is full is fine: the buffer already guarantees a pending wakeup.
func (p *Pool) signalWorkerAvailable() {
	select {
	case p.workerAvailable <- struct{}{}:
	default:
	}
}
// acquireWorker selects a healthy worker via round-robin and reserves one of
// its in-flight slots (pw.inflightGate). If every healthy worker is at
// capacity, it blocks until a slot frees up, ctx is cancelled, or the pool
// shuts down. The caller must release the slot by receiving from
// pw.inflightGate when done.
func (p *Pool) acquireWorker(ctx context.Context) (*poolWorker, int, error) {
	workers := p.workers
	if len(workers) == 0 {
		return nil, -1, errors.New("no workers available")
	}
	for {
		// Advance the round-robin cursor once per pass so successive calls
		// start at different workers.
		startIdx := int(p.nextIdx.Add(1) - 1)
		healthyFound := false
		for i := 0; i < len(workers); i++ {
			idx := (startIdx + i) % len(workers)
			pw := workers[idx]
			if !pw.healthy.Load() {
				continue
			}
			healthyFound = true
			// Non-blocking attempt to reserve a slot on this worker.
			select {
			case pw.inflightGate <- struct{}{}:
				return pw, idx, nil
			default:
			}
		}
		if !healthyFound {
			return nil, -1, errors.New("no healthy workers available")
		}
		// All healthy workers are busy: wait for a slot, cancellation, or
		// shutdown, then retry the scan.
		select {
		case <-ctx.Done():
			return nil, -1, ctx.Err()
		case <-p.shutdownCh:
			return nil, -1, errors.New("pool is shut down")
		case <-p.workerAvailable:
		}
	}
}
// Health returns a snapshot of the pool's health as of the most recent
// health check (see updateHealthStatus). Safe for concurrent use.
func (p *Pool) Health() HealthStatus {
	p.healthMu.RLock()
	defer p.healthMu.RUnlock()
	return p.healthStatus
}
// IsHealthy reports whether the worker is currently usable: it must be in
// the running state and its process must still be alive.
// TODO: Implement an actual health RPC call instead of a liveness check.
func (w *Worker) IsHealthy(_ context.Context) bool {
	return w.state.Load() == int32(WorkerStateRunning) && w.IsRunning()
}
// connect dials the worker's Unix domain socket, wrapping any dial failure
// with the socket path for context.
func (p *Pool) connect(socketPath string) (net.Conn, error) {
	c, err := net.Dial("unix", socketPath)
	if err == nil {
		return c, nil
	}
	return nil, fmt.Errorf("failed to connect to %s: %w", socketPath, err)
}
// healthMonitor periodically refreshes worker health at the configured
// HealthInterval until ctx is cancelled (by Shutdown via healthCancel).
// Runs as a background goroutine started by Start; signals p.wg on exit.
func (p *Pool) healthMonitor(ctx context.Context) {
	defer p.wg.Done()
	ticker := time.NewTicker(p.opts.Config.HealthInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			p.updateHealthStatus()
		}
	}
}
// updateHealthStatus probes every worker (with a 5s overall budget), stores
// each worker's healthy flag, and publishes an aggregate HealthStatus
// snapshot. A warning is logged when any worker is unhealthy.
func (p *Pool) updateHealthStatus() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	healthyCount := 0
	for _, pw := range p.workers {
		ok := pw.worker.IsHealthy(ctx)
		pw.healthy.Store(ok)
		if ok {
			healthyCount++
		}
	}
	p.healthMu.Lock()
	p.healthStatus = HealthStatus{
		TotalWorkers:   len(p.workers),
		HealthyWorkers: healthyCount,
		LastCheck:      time.Now(),
	}
	p.healthMu.Unlock()
	if healthyCount < len(p.workers) {
		p.logger.Warn("some workers are unhealthy",
			"healthy", healthyCount, "total", len(p.workers))
	}
}
// monitorCancellation watches ctx for cancellation while req is in flight.
// On cancellation it sends a best-effort cancellation message and then
// closes the connection, which unblocks the blocked read in Call.
// cancelOnce guarantees the teardown runs at most once. The goroutine exits
// when the request completes normally (req.done is closed by Call's defer).
func (p *Pool) monitorCancellation(ctx context.Context, req *activeRequest) {
	select {
	case <-ctx.Done():
		// Context was cancelled, send cancellation message and close connection
		req.cancelOnce.Do(func() {
			p.logger.Debug("context cancelled, sending cancellation", "request_id", req.id)
			// Try to send cancellation message first
			p.sendCancellation(req)
			// Then close the connection to ensure cancellation
			if req.conn != nil {
				_ = req.conn.Close()
			}
		})
	case <-req.done:
		// Request completed normally
		return
	}
}
// sendCancellation sends a cancellation message for req to the Python
// worker over the request's own connection. It is best-effort: failures are
// only logged, because the caller closes the connection afterwards anyway,
// which also aborts the request.
func (p *Pool) sendCancellation(req *activeRequest) {
	p.logger.Debug("sending cancellation", "request_id", req.id)
	// Create cancellation request
	cancelReq := protocol.NewCancellationRequest(req.id, "context cancelled")
	// Wrap in message envelope
	msgEnvelope, err := protocol.WrapMessage(protocol.MessageTypeCancellation, cancelReq)
	if err != nil {
		p.logger.Error("failed to create cancellation message", "error", err)
		return
	}
	// Marshal to JSON
	data, err := json.Marshal(msgEnvelope)
	if err != nil {
		p.logger.Error("failed to marshal cancellation message", "error", err)
		return
	}
	// Send cancellation message using the same connection
	// This is best-effort, if it fails we still close the connection
	framer := framing.NewFramer(req.conn)
	if err := framer.WriteMessage(data); err != nil {
		p.logger.Debug("failed to send cancellation message, will rely on connection close", "error", err)
	}
}
package pyproc
import (
"context"
"fmt"
"time"
)
// CallTyped is a type-safe wrapper for Pool.Call using Go generics, and the
// RECOMMENDED way to call Python functions from Go.
//
// Benefits:
//   - Compile-time type checking for request and response
//   - No runtime type assertions needed
//   - Clear function signatures and better IDE autocomplete
//   - Identical performance to untyped Call() - zero overhead
//
// Type Parameters:
//   - TIn: the input type (must be JSON-serializable)
//   - TOut: the output type (must match the Python response structure)
//
// Example:
//
//	type PredictRequest struct {
//		Value float64 `json:"value"`
//	}
//	type PredictResponse struct {
//		Result float64 `json:"result"`
//	}
//
//	result, err := pyproc.CallTyped[PredictRequest, PredictResponse](
//		ctx, pool, "predict", PredictRequest{Value: 42},
//	)
//	if err != nil {
//		return err
//	}
//	fmt.Printf("Result: %v\n", result.Result) // Type-safe access
//
// Errors from the underlying call are wrapped with the method name using %w
// so callers can still unwrap them with errors.Is/As.
func CallTyped[TIn any, TOut any](ctx context.Context, pool *Pool, method string, input TIn) (TOut, error) {
	var out TOut
	if err := pool.Call(ctx, method, input, &out); err != nil {
		return out, fmt.Errorf("call %s failed: %w", method, err)
	}
	return out, nil
}
// CallTypedWithTransport is the PoolWithTransport counterpart of CallTyped:
// a generic, type-safe wrapper around PoolWithTransport.Call. Errors are
// wrapped with the method name using %w.
func CallTypedWithTransport[TIn any, TOut any](ctx context.Context, pool *PoolWithTransport, method string, input TIn) (TOut, error) {
	var out TOut
	if err := pool.Call(ctx, method, input, &out); err != nil {
		return out, fmt.Errorf("call %s failed: %w", method, err)
	}
	return out, nil
}
// TypedPool provides a type-safe wrapper around Pool with predefined input/output types.
// Use this when you want to reuse the same type pair for multiple method calls.
//
// Benefits over direct CallTyped:
//   - Type parameters specified once at creation
//   - Cleaner code when calling multiple methods with same types
//   - Full access to pool lifecycle methods (Start, Shutdown, Health)
//
// Example usage:
//
//	type Request struct { Value int `json:"value"` }
//	type Response struct { Result int `json:"result"` }
//
//	pool, err := pyproc.NewPool(opts, logger)
//	if err != nil {
//		return err
//	}
//
//	typedPool := pyproc.NewTypedPool[Request, Response](pool)
//	if err := typedPool.Start(ctx); err != nil {
//		return err
//	}
//	defer typedPool.Shutdown(ctx)
//
//	// Multiple calls with type safety
//	resp1, err := typedPool.Call(ctx, "method1", Request{Value: 1})
//	resp2, err := typedPool.Call(ctx, "method2", Request{Value: 2})
type TypedPool[TIn any, TOut any] struct {
	// pool is the underlying untyped pool; lifecycle methods delegate to it.
	pool *Pool
}

// NewTypedPool creates a new typed pool wrapper with predefined input/output types.
// The returned TypedPool can call any Python method that accepts TIn and returns TOut.
// It does not take ownership of starting or stopping the pool; use the
// wrapper's Start/Shutdown or manage the underlying pool directly.
func NewTypedPool[TIn any, TOut any](pool *Pool) *TypedPool[TIn, TOut] {
	return &TypedPool[TIn, TOut]{
		pool: pool,
	}
}
// Call executes a method with type safety by delegating to CallTyped with
// this pool's predefined TIn/TOut types.
func (tp *TypedPool[TIn, TOut]) Call(ctx context.Context, method string, input TIn) (TOut, error) {
	return CallTyped[TIn, TOut](ctx, tp.pool, method, input)
}

// Start starts all workers in the underlying pool.
func (tp *TypedPool[TIn, TOut]) Start(ctx context.Context) error {
	return tp.pool.Start(ctx)
}

// Shutdown gracefully shuts down the underlying pool.
func (tp *TypedPool[TIn, TOut]) Shutdown(ctx context.Context) error {
	return tp.pool.Shutdown(ctx)
}

// Health returns the health status of the underlying pool.
func (tp *TypedPool[TIn, TOut]) Health() HealthStatus {
	return tp.pool.Health()
}
// TypedWorkerClient provides a type-safe client for a specific Python worker method.
// Use this when you're repeatedly calling the same method and want the cleanest API.
//
// Benefits:
//   - Method name specified once at creation
//   - Simplest call syntax: client.Call(ctx, input)
//   - Built-in batch execution with BatchCall
//   - Type safety for both single and batch operations
//
// Example usage:
//
//	type Request struct { Value int `json:"value"` }
//	type Response struct { Result int `json:"result"` }
//
//	pool, err := pyproc.NewPool(opts, logger)
//	if err != nil {
//		return err
//	}
//
//	// Create a client for the "predict" method
//	predictClient := pyproc.NewTypedWorkerClient[Request, Response](pool, "predict")
//
//	// Single call - cleanest syntax
//	resp, err := predictClient.Call(ctx, Request{Value: 42})
//
//	// Batch call - parallel execution
//	inputs := []Request{{Value: 1}, {Value: 2}, {Value: 3}}
//	responses, errors := predictClient.BatchCall(ctx, inputs)
type TypedWorkerClient[TIn any, TOut any] struct {
	// pool is the underlying worker pool used for every call.
	pool *Pool
	// method is the fixed Python method name this client invokes.
	method string
}

// NewTypedWorkerClient creates a type-safe client for a specific Python worker method.
// The returned client will always call the specified method with type safety.
func NewTypedWorkerClient[TIn any, TOut any](pool *Pool, method string) *TypedWorkerClient[TIn, TOut] {
	return &TypedWorkerClient[TIn, TOut]{
		pool:   pool,
		method: method,
	}
}

// Call executes the predefined method with type safety.
// This is the simplest way to call a Python method with full type safety.
func (tc *TypedWorkerClient[TIn, TOut]) Call(ctx context.Context, input TIn) (TOut, error) {
	return CallTyped[TIn, TOut](ctx, tc.pool, tc.method, input)
}
// BatchCall executes multiple requests in parallel using goroutines.
// Returns a slice of results and a slice of errors (one per input).
// Results and errors are guaranteed to be in the same order as inputs.
// If ctx is cancelled before all requests finish, every not-yet-completed
// slot receives the same timeout/cancellation error and the method returns
// immediately; late results are discarded.
//
// Example:
//
//	inputs := []Request{{Value: 1}, {Value: 2}, {Value: 3}}
//	results, errors := client.BatchCall(ctx, inputs)
//	for i := range inputs {
//		if errors[i] != nil {
//			log.Printf("Request %d failed: %v", i, errors[i])
//			continue
//		}
//		log.Printf("Result %d: %v", i, results[i])
//	}
func (tc *TypedWorkerClient[TIn, TOut]) BatchCall(ctx context.Context, inputs []TIn) ([]TOut, []error) {
	results := make([]TOut, len(inputs))
	errors := make([]error, len(inputs))
	if len(inputs) == 0 {
		return results, errors
	}
	// start is passed to timeoutErrorForContext so the error can describe
	// how long the batch ran before the context ended.
	start := time.Now()
	// Use goroutines for parallel execution
	type result struct {
		index  int
		output TOut
		err    error
	}
	// Buffered to len(inputs) so worker goroutines never block on send,
	// even if we return early on cancellation.
	resultCh := make(chan result, len(inputs))
	for i, input := range inputs {
		go func(idx int, in TIn) {
			out, err := tc.Call(ctx, in)
			resultCh <- result{index: idx, output: out, err: err}
		}(i, input)
	}
	// Collect results, but stop waiting if ctx is done.
	completed := make([]bool, len(inputs))
	remaining := len(inputs)
	for remaining > 0 {
		select {
		case res := <-resultCh:
			if !completed[res.index] {
				results[res.index] = res.output
				errors[res.index] = res.err
				completed[res.index] = true
				remaining--
			}
		case <-ctx.Done():
			// Fill every incomplete slot with a descriptive error built by
			// timeoutErrorForContext (defined elsewhere in this package).
			timeoutErr := timeoutErrorForContext(ctx, start)
			for i := range inputs {
				if !completed[i] {
					errors[i] = timeoutErr
				}
			}
			return results, errors
		}
	}
	return results, errors
}
// Example usage types for common patterns
// PredictRequest represents a prediction request (example payload type).
type PredictRequest struct {
	// Value is the numeric input passed to the worker.
	Value float64 `json:"value"`
}
// PredictResponse represents a prediction response (example payload type).
type PredictResponse struct {
	// Result is the numeric output returned by the worker.
	Result float64 `json:"result"`
}
// TransformRequest represents a text transformation request (example payload type).
type TransformRequest struct {
	// Text is the input string to transform.
	Text string `json:"text"`
}
// TransformResponse represents a text transformation response (example payload type).
type TransformResponse struct {
	// TransformedText is the worker's transformed output string.
	TransformedText string `json:"transformed_text"`
	// WordCount is the number of words reported by the worker.
	WordCount int `json:"word_count"`
}
// BatchRequest represents a batch processing request (example payload type).
type BatchRequest struct {
	// Items holds one free-form JSON object per item to process.
	Items []map[string]interface{} `json:"items"`
}
// BatchResponse represents a batch processing response (example payload type).
type BatchResponse struct {
	// Results holds one free-form JSON object per processed item.
	Results []map[string]interface{} `json:"results"`
	// Count is the number of results reported by the worker.
	Count int `json:"count"`
}
// StatsRequest represents a statistics computation request (example payload type).
type StatsRequest struct {
	// Numbers is the sample set to compute statistics over.
	Numbers []float64 `json:"numbers"`
}
// StatsResponse represents a statistics computation response (example payload type).
type StatsResponse struct {
	Mean   float64 `json:"mean"`    // arithmetic mean
	Median float64 `json:"median"`  // middle value
	StdDev float64 `json:"std_dev"` // standard deviation
	Min    float64 `json:"min"`     // smallest sample
	Max    float64 `json:"max"`     // largest sample
}
package pyproc
import (
"context"
"slices"
"sync"
"sync/atomic"
"time"
)
// PoolMetrics tracks metrics for connection pooling. Counter and gauge
// fields are atomic and safe for concurrent use; the latency window is
// guarded by latencyMu.
type PoolMetrics struct {
	// Connection metrics
	ConnectionsCreated   atomic.Uint64
	ConnectionsDestroyed atomic.Uint64
	ConnectionsActive    atomic.Int32
	ConnectionsIdle      atomic.Int32
	// Request metrics
	RequestsTotal     atomic.Uint64
	RequestsSucceeded atomic.Uint64
	RequestsFailed    atomic.Uint64
	RequestsTimeout   atomic.Uint64
	// Latency tracking: a sliding window of the most recent samples
	// (capped at maxLatencies), used for percentile estimation.
	latencyMu    sync.RWMutex
	latencies    []time.Duration
	maxLatencies int
	// Worker metrics
	WorkerRestarts atomic.Uint64
	WorkerFailures atomic.Uint64
	// Pool utilization, stored as an integer percent (0-100);
	// GetMetricsSnapshot converts it to a 0-1 fraction.
	PoolUtilization atomic.Uint64
	QueueDepth      atomic.Int32
}
// NewPoolMetrics constructs a metrics tracker that retains up to the
// last 10,000 request latencies for percentile calculation.
func NewPoolMetrics() *PoolMetrics {
	const latencyWindow = 10000
	m := &PoolMetrics{maxLatencies: latencyWindow}
	m.latencies = make([]time.Duration, 0, latencyWindow)
	return m
}
// RecordLatency appends one latency sample to the sliding window,
// evicting the oldest sample when the window is already full.
func (m *PoolMetrics) RecordLatency(latency time.Duration) {
	m.latencyMu.Lock()
	defer m.latencyMu.Unlock()
	window := m.latencies
	if len(window) >= m.maxLatencies {
		window = window[1:] // drop the oldest entry
	}
	m.latencies = append(window, latency)
}
// GetLatencyPercentile calculates the given latency percentile (0-100)
// over the retained sample window. Returns 0 when no samples exist.
//
// Selection is delegated to percentileFromSorted so the clamp/index
// logic is not duplicated (it was previously copied inline here, which
// risked the two implementations drifting apart).
func (m *PoolMetrics) GetLatencyPercentile(percentile float64) time.Duration {
	m.latencyMu.RLock()
	defer m.latencyMu.RUnlock()
	if len(m.latencies) == 0 {
		return 0
	}
	// Sort a copy so the window keeps its insertion order.
	sorted := make([]time.Duration, len(m.latencies))
	copy(sorted, m.latencies)
	slices.Sort(sorted)
	return percentileFromSorted(sorted, percentile)
}
// GetMetricsSnapshot returns a snapshot of current metrics.
// Percentiles are computed inline to avoid nested RLock on latencyMu,
// which would deadlock under concurrent writes (Go RWMutex is not reentrant).
//
// The Timestamp field is left at its zero value; PoolWithMetrics.GetMetrics
// stamps it after calling this.
func (m *PoolMetrics) GetMetricsSnapshot() MetricsSnapshot {
	m.latencyMu.RLock()
	var p50, p95, p99 time.Duration
	if len(m.latencies) > 0 {
		// Sort a copy; the window itself must keep insertion order.
		sorted := make([]time.Duration, len(m.latencies))
		copy(sorted, m.latencies)
		slices.Sort(sorted)
		p50 = percentileFromSorted(sorted, 50)
		p95 = percentileFromSorted(sorted, 95)
		p99 = percentileFromSorted(sorted, 99)
	}
	m.latencyMu.RUnlock()
	return MetricsSnapshot{
		ConnectionsCreated:   m.ConnectionsCreated.Load(),
		ConnectionsDestroyed: m.ConnectionsDestroyed.Load(),
		ConnectionsActive:    m.ConnectionsActive.Load(),
		ConnectionsIdle:      m.ConnectionsIdle.Load(),
		RequestsTotal:        m.RequestsTotal.Load(),
		RequestsSucceeded:    m.RequestsSucceeded.Load(),
		RequestsFailed:       m.RequestsFailed.Load(),
		RequestsTimeout:      m.RequestsTimeout.Load(),
		WorkerRestarts:       m.WorkerRestarts.Load(),
		WorkerFailures:       m.WorkerFailures.Load(),
		// Stored as an integer percent; expose as a 0-1 fraction.
		PoolUtilization: float64(m.PoolUtilization.Load()) / 100.0,
		QueueDepth:      m.QueueDepth.Load(),
		LatencyP50:      p50,
		LatencyP95:      p95,
		LatencyP99:      p99,
	}
}
// percentileFromSorted returns the value at the given percentile from a pre-sorted slice.
func percentileFromSorted(sorted []time.Duration, percentile float64) time.Duration {
index := int(float64(len(sorted)-1) * percentile / 100.0)
if index < 0 {
index = 0
}
if index >= len(sorted) {
index = len(sorted) - 1
}
return sorted[index]
}
// MetricsSnapshot represents a point-in-time metrics snapshot.
// All fields are plain values, so a snapshot can be read without locks.
type MetricsSnapshot struct {
	// Connections
	ConnectionsCreated   uint64
	ConnectionsDestroyed uint64
	ConnectionsActive    int32
	ConnectionsIdle      int32
	// Requests
	RequestsTotal     uint64
	RequestsSucceeded uint64
	RequestsFailed    uint64
	RequestsTimeout   uint64
	// Workers
	WorkerRestarts uint64
	WorkerFailures uint64
	// Performance. PoolUtilization is a 0-1 fraction (active / total
	// connections); latencies are approximate percentiles over the
	// retained sample window.
	PoolUtilization float64
	QueueDepth      int32
	LatencyP50      time.Duration
	LatencyP95      time.Duration
	LatencyP99      time.Duration
	// Timestamp of when the snapshot was taken (set by GetMetrics).
	Timestamp time.Time
}
// PoolWithMetrics wraps a pool with metrics collection. The embedded
// Pool's methods remain available; Call shadows Pool.Call so every
// request is counted and timed.
type PoolWithMetrics struct {
	*Pool
	metrics *PoolMetrics // per-pool counters and latency window
}
// NewPoolWithMetrics builds a Pool and wraps it with a fresh metrics
// tracker. Returns an error when the underlying pool cannot be created.
func NewPoolWithMetrics(opts PoolOptions, logger *Logger) (*PoolWithMetrics, error) {
	inner, err := NewPool(opts, logger)
	if err != nil {
		return nil, err
	}
	wrapped := &PoolWithMetrics{Pool: inner, metrics: NewPoolMetrics()}
	return wrapped, nil
}
// Call wraps Pool.Call with metrics: it counts the request, tracks queue
// depth, records latency, classifies the outcome (success / timeout /
// failure), and refreshes pool utilization.
func (p *PoolWithMetrics) Call(ctx context.Context, method string, input interface{}, output interface{}) error {
	p.metrics.RequestsTotal.Add(1)
	p.metrics.QueueDepth.Add(1)
	defer p.metrics.QueueDepth.Add(-1)

	start := time.Now()
	err := p.Pool.Call(ctx, method, input, output)
	p.metrics.RecordLatency(time.Since(start))

	// ctx.Err() returns context.DeadlineExceeded directly (never
	// wrapped), so direct comparison is safe here.
	switch {
	case err == nil:
		p.metrics.RequestsSucceeded.Add(1)
	case ctx.Err() == context.DeadlineExceeded:
		p.metrics.RequestsTimeout.Add(1)
	default:
		p.metrics.RequestsFailed.Add(1)
	}

	// Refresh utilization as active/(active+idle), stored as an
	// integer percent.
	active := p.metrics.ConnectionsActive.Load()
	total := active + p.metrics.ConnectionsIdle.Load()
	if total > 0 {
		p.metrics.PoolUtilization.Store(uint64(active * 100 / total))
	}
	return err
}
// GetMetrics returns a point-in-time snapshot of the collected metrics,
// stamped with the current wall-clock time.
func (p *PoolWithMetrics) GetMetrics() MetricsSnapshot {
	snap := p.metrics.GetMetricsSnapshot()
	snap.Timestamp = time.Now()
	return snap
}
// ResetMetrics resets all metrics counters by swapping in a fresh tracker.
//
// NOTE(review): the swap is a plain pointer write, so a concurrent Call
// may record into either the old or the new tracker. Confirm callers only
// reset while the pool is quiescent, or guard p.metrics with an
// atomic.Pointer if concurrent resets are needed.
func (p *PoolWithMetrics) ResetMetrics() {
	p.metrics = NewPoolMetrics()
}
package pyproc
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/YuminosukeSato/pyproc/internal/protocol"
)
// TransportPool manages a pool of transports for load balancing.
// Calls are distributed round-robin across healthy transports.
type TransportPool struct {
	transports []Transport
	nextIdx    atomic.Uint64 // monotonically increasing round-robin cursor
	logger     *Logger
	mu         sync.RWMutex // guards transports against concurrent Close
}
// NewTransportPool builds one transport per config entry. If any
// transport fails to construct, the ones created so far are closed and
// the error is returned.
func NewTransportPool(configs []TransportConfig, logger *Logger) (*TransportPool, error) {
	if len(configs) == 0 {
		return nil, errors.New("at least one transport config is required")
	}
	pool := &TransportPool{
		transports: make([]Transport, 0, len(configs)),
		logger:     logger,
	}
	for i, cfg := range configs {
		t, err := NewTransport(cfg, logger)
		if err != nil {
			// Roll back: close everything built before the failure.
			for _, created := range pool.transports {
				_ = created.Close()
			}
			return nil, fmt.Errorf("failed to create transport %d: %w", i, err)
		}
		pool.transports = append(pool.transports, t)
	}
	return pool, nil
}
// Call selects a healthy transport via round-robin and forwards the
// request, trying the next transport on failure.
//
// Returns the first successful response. When every transport is
// unhealthy or fails, the returned error now wraps the last underlying
// failure (previously it was discarded, which made diagnosing
// all-transports-down situations needlessly hard).
func (p *TransportPool) Call(ctx context.Context, req *protocol.Request) (*protocol.Response, error) {
	p.mu.RLock()
	defer p.mu.RUnlock()
	if len(p.transports) == 0 {
		return nil, errors.New("no transports available")
	}
	// Round-robin starting point; Add(1)-1 yields a unique start per call.
	startIdx := p.nextIdx.Add(1) - 1
	var lastErr error
	for i := 0; i < len(p.transports); i++ {
		idx := (startIdx + uint64(i)) % uint64(len(p.transports))
		transport := p.transports[idx]
		if !transport.IsHealthy() {
			continue
		}
		resp, err := transport.Call(ctx, req)
		if err == nil {
			return resp, nil
		}
		lastErr = err
		p.logger.Warn("transport call failed, trying next",
			"index", idx,
			"error", err)
	}
	if lastErr != nil {
		return nil, fmt.Errorf("all transports failed: %w", lastErr)
	}
	// No transport was healthy enough to even attempt the call.
	return nil, errors.New("all transports failed")
}
// Close closes all transports in the pool and clears the slice; later
// Calls fail with "no transports available".
//
// Individual close failures are aggregated with errors.Join so callers
// can still inspect them via errors.Is / errors.As (the previous %v
// formatting flattened them into an opaque string).
func (p *TransportPool) Close() error {
	p.mu.Lock()
	defer p.mu.Unlock()
	var errs []error
	for i, transport := range p.transports {
		if err := transport.Close(); err != nil {
			errs = append(errs, fmt.Errorf("transport %d: %w", i, err))
		}
	}
	p.transports = nil
	if len(errs) > 0 {
		return fmt.Errorf("failed to close transports: %w", errors.Join(errs...))
	}
	return nil
}
// Health reports how many transports currently pass their health check
// and the total number of transports in the pool.
func (p *TransportPool) Health() (healthy, total int) {
	p.mu.RLock()
	defer p.mu.RUnlock()
	for _, t := range p.transports {
		if t.IsHealthy() {
			healthy++
		}
	}
	return healthy, len(p.transports)
}
// PoolWithTransport updates the Pool to use the Transport interface: it
// still owns the Python worker processes, but routes calls through a
// TransportPool instead of talking to workers directly.
type PoolWithTransport struct {
	opts          PoolOptions
	logger        *Logger
	transportPool *TransportPool
	workers       []*Worker // Still manage worker processes
	shutdown      atomic.Bool
	wg            sync.WaitGroup
	// Backpressure control: one slot per allowed in-flight call
	// (capacity MaxInFlight).
	semaphore chan struct{}
	// Health monitoring
	healthMu     sync.RWMutex
	healthStatus HealthStatus
	healthCancel context.CancelFunc // stops the healthMonitor goroutine
}
// NewPoolWithTransport validates options, applies defaults, and builds a
// pool whose calls go through the Transport abstraction. Workers are
// created here but not started — call Start to launch them.
func NewPoolWithTransport(opts PoolOptions, logger *Logger) (*PoolWithTransport, error) {
	if opts.Config.Workers <= 0 {
		return nil, errors.New("workers must be > 0")
	}
	// Fill in defaults for any unset limits and intervals.
	if opts.Config.MaxInFlight <= 0 {
		opts.Config.MaxInFlight = 10
	}
	if opts.Config.MaxInFlightPerWorker <= 0 {
		opts.Config.MaxInFlightPerWorker = 1
	}
	if opts.Config.HealthInterval <= 0 {
		opts.Config.HealthInterval = 30 * time.Second
	}
	if logger == nil {
		logger = NewLogger(LoggingConfig{Level: "info", Format: "json"})
	}

	p := &PoolWithTransport{
		opts:      opts,
		logger:    logger,
		workers:   make([]*Worker, opts.Config.Workers),
		semaphore: make(chan struct{}, opts.Config.MaxInFlight),
	}

	// Each worker owns one Python process bound to its own socket path.
	// Security configuration will be handled in WorkerConfig directly.
	for i := range p.workers {
		cfg := opts.WorkerConfig
		cfg.ID = fmt.Sprintf("worker-%d", i)
		cfg.SocketPath = fmt.Sprintf("%s-%d", opts.WorkerConfig.SocketPath, i)
		if cfg.StartTimeout == 0 {
			cfg.StartTimeout = 5 * time.Second
		}
		p.workers[i] = NewWorker(cfg, logger)
	}
	return p, nil
}
// Start starts all workers and creates transports.
//
// Startup order: launch every worker process, wait briefly for their
// sockets to come up, build one UDS transport per worker, then begin
// periodic health monitoring. On failure, everything started so far is
// stopped before returning.
func (p *PoolWithTransport) Start(ctx context.Context) error {
	p.logger.Info("starting worker pool with transports", "workers", p.opts.Config.Workers)
	// Start all workers
	for i, worker := range p.workers {
		if err := worker.Start(ctx); err != nil {
			// Stop already started workers
			for j := 0; j < i; j++ {
				_ = p.workers[j].Stop()
			}
			return fmt.Errorf("failed to start worker %d: %w", i, err)
		}
	}
	// Give workers time to stabilize.
	// NOTE(review): a fixed 100ms sleep assumes workers bind their
	// sockets quickly — confirm whether an explicit readiness probe
	// would be more robust here.
	time.Sleep(100 * time.Millisecond)
	// Create transport configurations for each worker
	configs := make([]TransportConfig, len(p.workers))
	for i, worker := range p.workers {
		configs[i] = TransportConfig{
			Type:    "uds",
			Address: worker.GetSocketPath(),
			Options: map[string]interface{}{
				"timeout":      5 * time.Second,
				"idle_timeout": 30 * time.Second,
			},
		}
	}
	// Create transport pool
	transportPool, err := NewTransportPool(configs, p.logger)
	if err != nil {
		// Stop all workers if transport creation fails
		for _, worker := range p.workers {
			_ = worker.Stop()
		}
		return fmt.Errorf("failed to create transport pool: %w", err)
	}
	p.transportPool = transportPool
	// Start health monitoring with its own context so Shutdown can stop
	// it independently of the caller's ctx.
	healthCtx, cancel := context.WithCancel(context.Background())
	p.healthCancel = cancel
	p.wg.Add(1)
	go p.healthMonitor(healthCtx)
	// Initial health check
	p.updateHealthStatus()
	p.logger.Info("worker pool with transports started successfully")
	return nil
}
// Call invokes a method through the transport pool, applying
// backpressure via the in-flight semaphore and honoring context
// cancellation while waiting for a slot.
func (p *PoolWithTransport) Call(ctx context.Context, method string, input interface{}, output interface{}) error {
	if p.shutdown.Load() {
		return errors.New("pool is shut down")
	}
	// Block until an in-flight slot frees up or the caller gives up.
	select {
	case <-ctx.Done():
		return ctx.Err()
	case p.semaphore <- struct{}{}:
	}
	defer func() { <-p.semaphore }()

	req, err := protocol.NewRequest(0, method, input)
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}
	resp, err := p.transportPool.Call(ctx, req)
	if err != nil {
		return fmt.Errorf("transport call failed: %w", err)
	}
	if !resp.OK {
		return resp.Error()
	}
	return resp.UnmarshalBody(output)
}
// Shutdown gracefully shuts down the pool: it stops health monitoring,
// closes the transport pool, stops all workers, and waits for background
// goroutines. Idempotent — later calls return nil immediately.
//
// Worker stop failures are aggregated with errors.Join so individual
// errors stay inspectable via errors.Is / errors.As (the previous %v
// formatting flattened them into an opaque string).
func (p *PoolWithTransport) Shutdown(_ context.Context) error {
	if !p.shutdown.CompareAndSwap(false, true) {
		return nil // Already shutting down
	}
	p.logger.Info("shutting down worker pool with transports")
	// Cancel health monitoring first so it stops touching transports.
	if p.healthCancel != nil {
		p.healthCancel()
	}
	// Close the transport pool before the workers so no new calls race
	// against worker teardown.
	if p.transportPool != nil {
		if err := p.transportPool.Close(); err != nil {
			p.logger.Error("failed to close transport pool", "error", err)
		}
	}
	// Stop all workers
	var errs []error
	for i, worker := range p.workers {
		if err := worker.Stop(); err != nil {
			errs = append(errs, fmt.Errorf("worker %d: %w", i, err))
		}
	}
	// Wait for goroutines
	p.wg.Wait()
	if len(errs) > 0 {
		return fmt.Errorf("shutdown errors: %w", errors.Join(errs...))
	}
	p.logger.Info("worker pool with transports shut down successfully")
	return nil
}
// healthMonitor re-checks transport health every HealthInterval until
// its context is canceled (by Shutdown). Runs as a pool goroutine
// tracked by wg.
func (p *PoolWithTransport) healthMonitor(ctx context.Context) {
	defer p.wg.Done()
	ticker := time.NewTicker(p.opts.Config.HealthInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			p.updateHealthStatus()
		case <-ctx.Done():
			return
		}
	}
}
// updateHealthStatus refreshes the cached HealthStatus from the
// transport pool and logs a warning when any transport is unhealthy.
func (p *PoolWithTransport) updateHealthStatus() {
	healthy, total := p.transportPool.Health()
	status := HealthStatus{
		TotalWorkers:   total,
		HealthyWorkers: healthy,
		LastCheck:      time.Now(),
	}
	p.healthMu.Lock()
	p.healthStatus = status
	p.healthMu.Unlock()
	if healthy < total {
		p.logger.Warn("some transports are unhealthy",
			"healthy", healthy, "total", total)
	}
}
// Health returns the most recently computed health status of the pool.
func (p *PoolWithTransport) Health() HealthStatus {
	p.healthMu.RLock()
	status := p.healthStatus
	p.healthMu.RUnlock()
	return status
}
package pyproc
import (
"fmt"
"os"
"path/filepath"
)
// SocketManager manages Unix domain socket files: path generation,
// cleanup, directory creation, and file permissions.
type SocketManager struct {
	dir         string      // directory that holds the socket files
	prefix      string      // filename prefix for every managed socket
	permissions os.FileMode // mode applied by SetSocketPermissions
}
// NewSocketManager builds a socket manager from the given configuration.
func NewSocketManager(cfg SocketConfig) *SocketManager {
	sm := &SocketManager{}
	sm.dir = cfg.Dir
	sm.prefix = cfg.Prefix
	sm.permissions = os.FileMode(cfg.Permissions)
	return sm
}
// GenerateSocketPath returns the socket file path for the given worker,
// of the form <dir>/<prefix>-<workerID>.sock.
func (sm *SocketManager) GenerateSocketPath(workerID string) string {
	return filepath.Join(sm.dir, sm.prefix+"-"+workerID+".sock")
}
// CleanupSocket removes a socket file if it exists. A missing file is
// not an error.
//
// The previous implementation did Stat-then-Remove, which is a TOCTOU
// race (the file could appear or vanish between the two calls). Removing
// directly and tolerating "not exist" is atomic and simpler.
func (sm *SocketManager) CleanupSocket(socketPath string) error {
	if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
		return fmt.Errorf("failed to remove socket file: %w", err)
	}
	return nil
}
// CleanupAllSockets removes every socket file in the managed directory
// matching "<prefix>-*.sock". Returns the last removal error observed,
// or nil when everything (or nothing) was removed.
func (sm *SocketManager) CleanupAllSockets() error {
	pattern := fmt.Sprintf("%s-*.sock", sm.prefix)
	matches, err := filepath.Glob(filepath.Join(sm.dir, pattern))
	if err != nil {
		return fmt.Errorf("failed to glob socket files: %w", err)
	}
	var lastErr error
	for _, path := range matches {
		removeErr := os.Remove(path)
		if removeErr == nil || os.IsNotExist(removeErr) {
			continue
		}
		lastErr = fmt.Errorf("failed to remove socket %s: %w", path, removeErr)
	}
	return lastErr
}
// EnsureSocketDir creates the socket directory (and any missing
// parents) with mode 0755; it is a no-op when the directory exists.
func (sm *SocketManager) EnsureSocketDir() error {
	err := os.MkdirAll(sm.dir, 0755)
	if err != nil {
		return fmt.Errorf("failed to create socket directory: %w", err)
	}
	return nil
}
// SetSocketPermissions applies the manager's configured file mode to the
// given socket file.
func (sm *SocketManager) SetSocketPermissions(socketPath string) error {
	err := os.Chmod(socketPath, sm.permissions)
	if err != nil {
		return fmt.Errorf("failed to set socket permissions: %w", err)
	}
	return nil
}
package pyproc
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net"
"time"
)
// HMACAuth provides HMAC-based authentication for socket connections
// via a challenge-response handshake keyed by a shared secret.
type HMACAuth struct {
	// secret is the shared HMAC-SHA256 key; both peers must hold the
	// same bytes for the handshake to succeed.
	secret []byte
}
// NewHMACAuth creates an authenticator that signs and verifies
// challenges with the given shared secret.
func NewHMACAuth(secret []byte) *HMACAuth {
	return &HMACAuth{secret: secret}
}
// GenerateSecret returns a fresh 32-byte cryptographically random key
// suitable for use as an HMAC secret.
func GenerateSecret() ([]byte, error) {
	key := make([]byte, 32)
	_, err := rand.Read(key)
	if err != nil {
		return nil, fmt.Errorf("failed to generate secret: %w", err)
	}
	return key, nil
}
// AuthenticateClient performs the client side of the challenge-response
// handshake: read a 32-byte challenge from the server, reply with
// HMAC-SHA256(secret, challenge), then read a 1-byte verdict where 1
// means accepted. The wire order here is the protocol — do not reorder.
func (h *HMACAuth) AuthenticateClient(conn net.Conn) error {
	// Set timeout for auth handshake so a stalled peer cannot hang us.
	if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
		return fmt.Errorf("failed to set deadline: %w", err)
	}
	defer func() { _ = conn.SetDeadline(time.Time{}) }() // Reset deadline
	// Read challenge from server
	challenge := make([]byte, 32)
	if _, err := io.ReadFull(conn, challenge); err != nil {
		return fmt.Errorf("failed to read challenge: %w", err)
	}
	// Compute HMAC response
	mac := hmac.New(sha256.New, h.secret)
	mac.Write(challenge)
	response := mac.Sum(nil)
	// Send response
	if _, err := conn.Write(response); err != nil {
		return fmt.Errorf("failed to send response: %w", err)
	}
	// Read authentication result (1 byte: 1 = accepted, 0 = rejected)
	result := make([]byte, 1)
	if _, err := io.ReadFull(conn, result); err != nil {
		return fmt.Errorf("failed to read auth result: %w", err)
	}
	if result[0] != 1 {
		return fmt.Errorf("authentication failed")
	}
	return nil
}
// AuthenticateServer performs the server side of the challenge-response
// handshake: send a 32-byte random challenge, read the client's 32-byte
// HMAC-SHA256 response, verify it in constant time, and send a 1-byte
// verdict (1 = accepted, 0 = rejected). Mirrors AuthenticateClient's
// wire order — do not reorder.
func (h *HMACAuth) AuthenticateServer(conn net.Conn) error {
	// Set timeout for auth handshake so a stalled peer cannot hang us.
	if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
		return fmt.Errorf("failed to set deadline: %w", err)
	}
	defer func() { _ = conn.SetDeadline(time.Time{}) }() // Reset deadline
	// Generate random challenge
	challenge := make([]byte, 32)
	if _, err := rand.Read(challenge); err != nil {
		return fmt.Errorf("failed to generate challenge: %w", err)
	}
	// Send challenge to client
	if _, err := conn.Write(challenge); err != nil {
		return fmt.Errorf("failed to send challenge: %w", err)
	}
	// Read response from client
	response := make([]byte, 32)
	if _, err := io.ReadFull(conn, response); err != nil {
		return fmt.Errorf("failed to read response: %w", err)
	}
	// Verify HMAC using hmac.Equal (constant-time comparison).
	mac := hmac.New(sha256.New, h.secret)
	mac.Write(challenge)
	expected := mac.Sum(nil)
	if !hmac.Equal(response, expected) {
		// Authentication failed; best-effort rejection byte, then error.
		_, _ = conn.Write([]byte{0})
		return fmt.Errorf("HMAC verification failed")
	}
	// Authentication succeeded
	if _, err := conn.Write([]byte{1}); err != nil {
		return fmt.Errorf("failed to send auth success: %w", err)
	}
	return nil
}
// HMACListener wraps a listener with HMAC authentication: every accepted
// connection must complete the challenge-response handshake.
type HMACListener struct {
	net.Listener
	auth *HMACAuth // performs the server side of the handshake
}
// NewHMACListener wraps listener so that every accepted connection must
// pass an HMAC challenge-response handshake keyed by secret.
func NewHMACListener(listener net.Listener, secret []byte) *HMACListener {
	l := &HMACListener{Listener: listener}
	l.auth = NewHMACAuth(secret)
	return l
}
// Accept waits for a connection and runs the server side of the HMAC
// handshake; connections that fail to authenticate are closed and
// reported as an error.
func (l *HMACListener) Accept() (net.Conn, error) {
	conn, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	if authErr := l.auth.AuthenticateServer(conn); authErr != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("authentication failed: %w", authErr)
	}
	return conn, nil
}
// SecureConn wraps a connection with HMAC authentication state.
type SecureConn struct {
	net.Conn
	// authenticated records that the HMAC handshake completed
	// successfully (always true for DialSecure results).
	authenticated bool
}
// DialSecure connects to address over network and performs the client
// side of the HMAC handshake before returning the connection. On
// handshake failure the connection is closed.
func DialSecure(network, address string, secret []byte) (*SecureConn, error) {
	conn, err := net.Dial(network, address)
	if err != nil {
		return nil, err
	}
	if authErr := NewHMACAuth(secret).AuthenticateClient(conn); authErr != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("authentication failed: %w", authErr)
	}
	return &SecureConn{Conn: conn, authenticated: true}, nil
}
// IsAuthenticated returns whether the connection is authenticated.
// Connections produced by DialSecure are always authenticated.
func (c *SecureConn) IsAuthenticated() bool {
	return c.authenticated
}
// SecretFromString derives a 32-byte secret by hashing the input with
// SHA-256, turning an arbitrary passphrase into a fixed-length key.
func SecretFromString(s string) []byte {
	sum := sha256.Sum256([]byte(s))
	return sum[:]
}
// SecretFromHex decodes a hex-encoded secret into its raw bytes,
// returning the decoder's error for malformed input.
func SecretFromHex(hexStr string) ([]byte, error) {
	raw, err := hex.DecodeString(hexStr)
	return raw, err
}
package pyproc
import (
"errors"
"fmt"
"net"
"os"
"path/filepath"
)
// SocketSecurityConfig defines security settings for Unix domain sockets.
// The zero value is not usable; obtain defaults from
// DefaultSocketSecurityConfig.
type SocketSecurityConfig struct {
	// SocketDir is the directory where socket files will be created
	// Default: /run/pyproc if running as root, /tmp/pyproc otherwise
	SocketDir string
	// SocketPerms defines the permissions for socket files
	// Default: 0600 (read/write for owner only)
	SocketPerms os.FileMode
	// DirPerms defines the permissions for the socket directory
	// Default: 0750 (rwxr-x--- for owner and group)
	DirPerms os.FileMode
	// AllowedUIDs is a list of UIDs that are allowed to connect
	// If empty, any UID can connect (but still verified)
	AllowedUIDs []uint32
	// AllowedGIDs is a list of GIDs that are allowed to connect
	// If empty, any GID can connect (but still verified)
	AllowedGIDs []uint32
	// RequireSameUser if true, only allows connections from the same UID
	// as the server's effective UID (checked first, before the lists).
	RequireSameUser bool
}
// DefaultSocketSecurityConfig returns a locked-down default: owner-only
// socket files (0600), owner/group directory (0750), and same-UID peers
// required. The socket directory is /run/pyproc when running as root,
// otherwise a "pyproc" subdirectory of the OS temp dir.
func DefaultSocketSecurityConfig() SocketSecurityConfig {
	dir := filepath.Join(os.TempDir(), "pyproc")
	if os.Geteuid() == 0 {
		dir = "/run/pyproc"
	}
	return SocketSecurityConfig{
		SocketDir:       dir,
		SocketPerms:     0600,
		DirPerms:        0750,
		RequireSameUser: true,
	}
}
// SecureSocketPath prepares the configured socket directory and returns
// the full path for socketName inside it. Any stale file at that path is
// removed so a fresh listener can bind.
func SecureSocketPath(config SocketSecurityConfig, socketName string) (string, error) {
	// Create the directory, then chmod it explicitly: MkdirAll does not
	// adjust permissions when the directory already exists.
	if err := os.MkdirAll(config.SocketDir, config.DirPerms); err != nil {
		return "", fmt.Errorf("failed to create socket directory %s: %w", config.SocketDir, err)
	}
	if err := os.Chmod(config.SocketDir, config.DirPerms); err != nil {
		return "", fmt.Errorf("failed to set permissions on socket directory: %w", err)
	}
	socketPath := filepath.Join(config.SocketDir, socketName)
	// Remove any stale socket. os.RemoveAll returns nil when the path
	// does not exist, so the old "!os.IsNotExist(err)" guard was dead
	// code and has been dropped.
	if err := os.RemoveAll(socketPath); err != nil {
		return "", fmt.Errorf("failed to remove existing socket file: %w", err)
	}
	return socketPath, nil
}
// SetSocketPermissions applies perms to the socket file at socketPath,
// returning any error from the underlying chmod.
func SetSocketPermissions(socketPath string, perms os.FileMode) error {
	err := os.Chmod(socketPath, perms)
	return err
}
// VerifyPeerCredentials verifies the credentials of a peer connection using SO_PEERCRED.
//
// Policy, checked in order: (1) with RequireSameUser, the peer's UID must
// equal this process's effective UID; (2) if AllowedUIDs is non-empty the
// peer UID must be listed; (3) if AllowedGIDs is non-empty the peer GID
// must be listed. The first failing check produces the returned error.
func VerifyPeerCredentials(conn net.Conn, config SocketSecurityConfig) error {
	unixConn, ok := conn.(*net.UnixConn)
	if !ok {
		return errors.New("connection is not a Unix domain socket")
	}
	// Get the underlying file descriptor
	rawConn, err := unixConn.SyscallConn()
	if err != nil {
		return fmt.Errorf("failed to get raw connection: %w", err)
	}
	var peerCreds *PeerCredentials
	var credErr error
	// Get peer credentials using SO_PEERCRED; Control runs the closure
	// with the fd guaranteed valid for its duration.
	err = rawConn.Control(func(fd uintptr) {
		peerCreds, credErr = getPeerCredentials(int(fd))
	})
	if err != nil {
		return fmt.Errorf("failed to control connection: %w", err)
	}
	if credErr != nil {
		return fmt.Errorf("failed to get peer credentials: %w", credErr)
	}
	if peerCreds == nil {
		return errors.New("peer credentials are nil")
	}
	// Verify credentials against configuration
	if config.RequireSameUser {
		currentUID := uint32(os.Geteuid())
		if peerCreds.UID != currentUID {
			return fmt.Errorf("peer UID %d does not match server UID %d", peerCreds.UID, currentUID)
		}
	}
	// Check allowed UIDs if specified
	if len(config.AllowedUIDs) > 0 {
		allowed := false
		for _, uid := range config.AllowedUIDs {
			if peerCreds.UID == uid {
				allowed = true
				break
			}
		}
		if !allowed {
			return fmt.Errorf("peer UID %d is not in allowed list", peerCreds.UID)
		}
	}
	// Check allowed GIDs if specified
	if len(config.AllowedGIDs) > 0 {
		allowed := false
		for _, gid := range config.AllowedGIDs {
			if peerCreds.GID == gid {
				allowed = true
				break
			}
		}
		if !allowed {
			return fmt.Errorf("peer GID %d is not in allowed list", peerCreds.GID)
		}
	}
	return nil
}
// getPeerCredentials is implemented in platform-specific files:
// - socket_security_linux.go for Linux (using SO_PEERCRED)
// - socket_security_darwin.go for macOS (using LOCAL_PEERCRED)
// SecureListener creates a Unix domain socket listener with security
// features: every accepted connection has its peer credentials verified
// against the configured policy before being handed to the caller.
type SecureListener struct {
	net.Listener
	config SocketSecurityConfig // policy applied in Accept
}
// NewSecureListener creates a Unix socket listener at a hardened path:
// the directory and socket file receive the permissions from config, and
// every Accept verifies peer credentials.
func NewSecureListener(socketPath string, config SocketSecurityConfig) (*SecureListener, error) {
	path, err := SecureSocketPath(config, filepath.Base(socketPath))
	if err != nil {
		return nil, err
	}
	listener, err := net.Listen("unix", path)
	if err != nil {
		return nil, fmt.Errorf("failed to create listener: %w", err)
	}
	// Tighten the socket file itself; close the listener on failure so
	// no half-configured socket is left behind.
	if permErr := SetSocketPermissions(path, config.SocketPerms); permErr != nil {
		_ = listener.Close()
		return nil, fmt.Errorf("failed to set socket permissions: %w", permErr)
	}
	return &SecureListener{Listener: listener, config: config}, nil
}
// Accept waits for a connection and verifies its peer credentials;
// connections failing verification are closed and reported as an error.
func (l *SecureListener) Accept() (net.Conn, error) {
	conn, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	if verifyErr := VerifyPeerCredentials(conn, l.config); verifyErr != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("peer verification failed: %w", verifyErr)
	}
	return conn, nil
}
//go:build linux
package pyproc
import (
"fmt"
"syscall"
"unsafe"
)
// getPeerCredentials retrieves the peer credentials using SO_PEERCRED (Linux-specific).
//
// Issues a raw getsockopt(2) via Syscall6 and copies the kernel-filled
// syscall.Ucred into the platform-independent PeerCredentials struct.
// NOTE(review): syscall.GetsockoptUcred would express this without
// unsafe — confirm before switching, since that also changes the error
// type from errno to error.
func getPeerCredentials(fd int) (*PeerCredentials, error) {
	ucred := &syscall.Ucred{}
	ucredLen := uint32(syscall.SizeofUcred)
	// Get peer credentials using SO_PEERCRED; ucredLen is value-result.
	_, _, errno := syscall.Syscall6(
		syscall.SYS_GETSOCKOPT,
		uintptr(fd),
		uintptr(syscall.SOL_SOCKET),
		uintptr(syscall.SO_PEERCRED),
		uintptr(unsafe.Pointer(ucred)),
		uintptr(unsafe.Pointer(&ucredLen)),
		0,
	)
	if errno != 0 {
		return nil, fmt.Errorf("getsockopt SO_PEERCRED failed: %v", errno)
	}
	// Convert to platform-independent PeerCredentials
	return &PeerCredentials{
		UID: ucred.Uid,
		GID: ucred.Gid,
		PID: ucred.Pid,
	}, nil
}
// Package telemetry provides OpenTelemetry tracing infrastructure for pyproc.
//
// This package implements distributed tracing support with the following key features:
// - Zero-overhead no-op mode when tracing is disabled
// - Backward compatibility (existing API unchanged)
// - Trace context propagation over Unix Domain Sockets
// - Integration with Pool.Call() for automatic span creation
//
// Usage:
//
// // Initialize telemetry provider
// provider, shutdown := telemetry.NewProvider(telemetry.Config{
// ServiceName: "my-service",
// Enabled: true,
// })
// defer shutdown(context.Background())
//
// // Create pool with telemetry
// pool, _ := pyproc.NewPool(poolOpts, logger)
// pool.WithTracer(provider.Tracer("my-service"))
//
// // Calls are automatically traced
// ctx := context.Background()
// var result map[string]interface{}
// pool.Call(ctx, "predict", input, &result)
package telemetry
import (
"context"
"fmt"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.24.0"
"go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/noop"
)
// Config holds configuration for the telemetry provider.
type Config struct {
	// ServiceName is the name of the service for tracing.
	// Defaults to "pyproc" when empty.
	ServiceName string
	// Enabled determines whether tracing is active
	// When false, a no-op tracer is used with zero overhead
	Enabled bool
	// SamplingRate controls the fraction of traces to record (0.0 to 1.0)
	// Default is 1.0 (record all traces). A value of exactly 0 is
	// treated as unset and also records everything; use Enabled=false
	// to disable tracing.
	SamplingRate float64
	// ExporterType determines which exporter to use
	// Supported values: "stdout", "otlp" (future)
	// Default is "stdout"; unknown values fall back to a no-op provider.
	ExporterType string
}
// Provider wraps an OpenTelemetry TracerProvider together with the
// shutdown hook that flushes it.
type Provider struct {
	provider trace.TracerProvider
	// shutdown flushes and stops the provider; may be a no-op.
	// Shutdown handles a nil value safely.
	shutdown func(context.Context) error
}
// NewProvider creates a new telemetry provider based on the given configuration.
// Returns a Provider and a shutdown function that should be called on application exit.
//
// When Config.Enabled is false, returns a no-op provider with zero overhead.
// Exporter construction failures and unknown exporter types also fall
// back to a no-op provider rather than returning an error.
func NewProvider(cfg Config) (*Provider, func(context.Context) error) {
	if !cfg.Enabled {
		return &Provider{
			provider: noop.NewTracerProvider(),
			shutdown: func(context.Context) error { return nil },
		}, func(context.Context) error { return nil }
	}
	// Set defaults. Note: SamplingRate == 0 is treated as "unset" and
	// becomes 1.0, so the NeverSample branch below is only reachable
	// with an explicitly negative rate.
	if cfg.ServiceName == "" {
		cfg.ServiceName = "pyproc"
	}
	if cfg.SamplingRate == 0 {
		cfg.SamplingRate = 1.0
	}
	if cfg.ExporterType == "" {
		cfg.ExporterType = "stdout"
	}
	// Create resource with service name
	res, err := resource.New(
		context.Background(),
		resource.WithAttributes(
			semconv.ServiceNameKey.String(cfg.ServiceName),
		),
	)
	if err != nil {
		// Fallback to default resource if creation fails
		res = resource.Default()
	}
	// Create exporter based on configuration
	var exporter sdktrace.SpanExporter
	switch cfg.ExporterType {
	case "stdout":
		exporter, err = stdouttrace.New(
			stdouttrace.WithPrettyPrint(),
		)
		if err != nil {
			// Fallback to no-op on error
			return &Provider{
				provider: noop.NewTracerProvider(),
				shutdown: func(context.Context) error { return nil },
			}, func(context.Context) error { return nil }
		}
	default:
		// Future: Add OTLP exporter support. Unknown types degrade to
		// a no-op provider instead of erroring.
		return &Provider{
			provider: noop.NewTracerProvider(),
			shutdown: func(context.Context) error { return nil },
		}, func(context.Context) error { return nil }
	}
	// Create sampler based on sampling rate
	var sampler sdktrace.Sampler
	if cfg.SamplingRate >= 1.0 {
		sampler = sdktrace.AlwaysSample()
	} else if cfg.SamplingRate <= 0.0 {
		sampler = sdktrace.NeverSample()
	} else {
		sampler = sdktrace.TraceIDRatioBased(cfg.SamplingRate)
	}
	// Create TracerProvider with batched export.
	tp := sdktrace.NewTracerProvider(
		sdktrace.WithBatcher(exporter),
		sdktrace.WithResource(res),
		sdktrace.WithSampler(sampler),
	)
	// Set as global provider so otel.Tracer() callers see it too.
	otel.SetTracerProvider(tp)
	shutdown := func(ctx context.Context) error {
		return tp.Shutdown(ctx)
	}
	return &Provider{
		provider: tp,
		shutdown: shutdown,
	}, shutdown
}
// Tracer returns a named tracer from the underlying tracer provider.
func (p *Provider) Tracer(name string, opts ...trace.TracerOption) trace.Tracer {
	tp := p.provider
	return tp.Tracer(name, opts...)
}
// Shutdown gracefully shuts down the provider, flushing any remaining spans.
// It is safe to call on a provider whose shutdown function was never set.
func (p *Provider) Shutdown(ctx context.Context) error {
	if p.shutdown == nil {
		return nil
	}
	return p.shutdown(ctx)
}
// IsEnabled reports whether tracing is active, i.e. the underlying provider
// is not the no-op implementation installed when tracing is disabled.
func (p *Provider) IsEnabled() bool {
	_, isNoop := p.provider.(noop.TracerProvider)
	return !isNoop
}
// ExtractTraceContext extracts OpenTelemetry trace context from a map.
// This is used for propagating trace context across process boundaries (UDS).
//
// The trace context is stored in W3C Trace Context format:
//   - "traceparent": "00-<trace-id>-<span-id>-<flags>"
//   - "tracestate": "<vendor-specific-state>" (optional)
//
// On any parse failure the original context is returned unchanged, so a
// malformed header degrades to "no remote parent" rather than an error.
//
// Returns a new context with the extracted span context.
func ExtractTraceContext(ctx context.Context, headers map[string]string) context.Context {
	if headers == nil {
		return ctx
	}
	// Parse traceparent header (W3C Trace Context format)
	// Format: version-trace-id-parent-id-flags
	traceparent := headers["traceparent"]
	if traceparent == "" {
		return ctx
	}
	// Split the fixed-width fields with Sscanf; the widths (2/32/16/2 hex
	// chars) match the W3C traceparent layout exactly, and the literal '-'
	// separators must be present for all four fields to scan.
	// NOTE(review): the version field is captured but not validated here —
	// presumably only "00" is expected; confirm against the spec if stricter
	// validation is ever needed.
	var version, traceID, spanID, flags string
	n, err := fmt.Sscanf(traceparent, "%2s-%32s-%16s-%2s", &version, &traceID, &spanID, &flags)
	if err != nil || n != 4 {
		return ctx
	}
	// Parse trace ID (32 hex chars -> 16 bytes).
	tid, err := trace.TraceIDFromHex(traceID)
	if err != nil {
		return ctx
	}
	// Parse span ID (16 hex chars -> 8 bytes).
	sid, err := trace.SpanIDFromHex(spanID)
	if err != nil {
		return ctx
	}
	// Parse flags byte (e.g. "01" = sampled).
	var flagsByte byte
	_, err = fmt.Sscanf(flags, "%02x", &flagsByte)
	if err != nil {
		return ctx
	}
	// Create span context; Remote marks it as originating from another process.
	spanCtx := trace.NewSpanContext(trace.SpanContextConfig{
		TraceID:    tid,
		SpanID:     sid,
		TraceFlags: trace.TraceFlags(flagsByte),
		Remote:     true,
	})
	// Return context with span context
	return trace.ContextWithRemoteSpanContext(ctx, spanCtx)
}
// InjectTraceContext injects OpenTelemetry trace context into a map.
// This is used for propagating trace context across process boundaries (UDS).
//
// The trace context is stored in W3C Trace Context format:
//   - "traceparent": "00-<trace-id>-<span-id>-<flags>"
//
// If headers is nil or the context does not contain a valid span context,
// the function returns without modifying headers.
func InjectTraceContext(ctx context.Context, headers map[string]string) {
	// Guard against a nil map: writing to one would panic. This mirrors the
	// nil check in ExtractTraceContext.
	if headers == nil {
		return
	}
	spanCtx := trace.SpanContextFromContext(ctx)
	if !spanCtx.IsValid() {
		return
	}
	// Format traceparent header
	// TraceFlags is a byte, format as 2-digit hex
	traceparent := fmt.Sprintf("00-%s-%s-%02x",
		spanCtx.TraceID().String(),
		spanCtx.SpanID().String(),
		byte(spanCtx.TraceFlags()),
	)
	headers["traceparent"] = traceparent
	// Future: Add tracestate support if needed
}
package pyproc
import (
"context"
"fmt"
"time"
)
// TimeoutKind identifies the source of a timeout.
type TimeoutKind string

// Timeout kind constants identify the source of a timeout.
const (
	TimeoutKindContext   TimeoutKind = "Context"
	TimeoutKindPerCall   TimeoutKind = "PerCall"
	TimeoutKindTransport TimeoutKind = "Transport"
)

// TimeoutError represents a classified timeout error.
// It unwraps to context.DeadlineExceeded for errors.Is compatibility.
type TimeoutError struct {
	Kind    TimeoutKind   // which deadline source triggered the timeout
	Timeout time.Duration // configured timeout; 0 when unknown
	Cause   error         // underlying error, typically context.DeadlineExceeded
}

// Error returns a human-readable timeout error message. It is safe to call
// on a nil receiver.
func (e *TimeoutError) Error() string {
	switch {
	case e == nil:
		return "timeout"
	case e.Timeout > 0:
		return fmt.Sprintf("%s timeout after %s", e.Kind, e.Timeout)
	default:
		return fmt.Sprintf("%s timeout", e.Kind)
	}
}

// Unwrap returns the underlying cause (nil for a nil receiver).
func (e *TimeoutError) Unwrap() error {
	if e == nil {
		return nil
	}
	return e.Cause
}

// newTimeoutError builds a TimeoutError with the given classification.
func newTimeoutError(kind TimeoutKind, timeout time.Duration, cause error) *TimeoutError {
	return &TimeoutError{Kind: kind, Timeout: timeout, Cause: cause}
}
// timeoutDuration returns the non-negative time remaining from start until
// deadline. A zero deadline (unset) or an already-passed deadline yields 0.
func timeoutDuration(start, deadline time.Time) time.Duration {
	if deadline.IsZero() {
		return 0
	}
	if remaining := deadline.Sub(start); remaining > 0 {
		return remaining
	}
	return 0
}
// effectiveDeadline returns the earliest deadline based on context, per-call,
// and transport defaults. The returned kind indicates which source won, and
// the bool reports whether any deadline applies at all.
func effectiveDeadline(ctx context.Context, start time.Time, perCall *time.Duration, transportDefault time.Duration) (time.Time, TimeoutKind, bool) {
	var (
		best     time.Time
		bestKind TimeoutKind
		found    bool
	)
	// consider adopts the candidate if it is the first seen or strictly
	// earlier than the current best; ties keep the earlier-registered source.
	consider := func(candidate time.Time, kind TimeoutKind) {
		if !found || candidate.Before(best) {
			best, bestKind, found = candidate, kind, true
		}
	}
	if ctxDeadline, hasCtx := ctx.Deadline(); hasCtx {
		consider(ctxDeadline, TimeoutKindContext)
	}
	if perCall != nil {
		consider(start.Add(*perCall), TimeoutKindPerCall)
	}
	if transportDefault > 0 {
		consider(start.Add(transportDefault), TimeoutKindTransport)
	}
	return best, bestKind, found
}
// timeoutErrorForContext converts a done context into an error: deadline
// expiry becomes a classified *TimeoutError, any other cause (e.g.
// cancellation) is returned as-is.
func timeoutErrorForContext(ctx context.Context, start time.Time) error {
	if err := ctx.Err(); err != context.DeadlineExceeded {
		return err
	}
	d, hasDeadline := ctx.Deadline()
	if !hasDeadline {
		return newTimeoutError(TimeoutKindContext, 0, context.DeadlineExceeded)
	}
	return newTimeoutError(TimeoutKindContext, timeoutDuration(start, d), context.DeadlineExceeded)
}
package pyproc
import (
"context"
"fmt"
"github.com/YuminosukeSato/pyproc/internal/protocol"
)
// Transport defines the interface for communication with Python workers.
// This abstraction allows for different transport mechanisms (UDS, gRPC, etc.)
// to be swapped behind a common request/response API.
type Transport interface {
	// Call sends a request and blocks until the matching response arrives
	// or ctx is done.
	Call(ctx context.Context, req *protocol.Request) (*protocol.Response, error)
	// Close closes the transport connection and releases its resources.
	Close() error
	// IsHealthy reports whether the transport is usable for new calls.
	IsHealthy() bool
}
// TransportConfig defines configuration for the transport layer.
type TransportConfig struct {
	Type    string         // transport kind: "uds", "multiplexed", "grpc-tcp", "grpc-uds" ("" defaults to "uds")
	Address string         // Unix socket path or network address, depending on Type
	Options map[string]any // transport-specific options, e.g. "timeout", "codec", "request_timeout"
}
// NewTransport constructs the Transport implementation selected by
// config.Type. An empty type defaults to "uds" for backward compatibility.
func NewTransport(config TransportConfig, logger *Logger) (Transport, error) {
	switch config.Type {
	case "", "uds":
		// Plain request/response transport over a Unix socket.
		return NewUDSTransport(config, logger)
	case "multiplexed":
		// Concurrent transport that correlates responses by request ID.
		return NewMultiplexedTransport(config, logger)
	case "grpc-tcp", "grpc-uds":
		return NewGRPCTransport(config, logger)
	default:
		return nil, fmt.Errorf("unknown transport type: %s", config.Type)
	}
}
package pyproc
import (
// "context"
"fmt"
// "sync"
// "time"
// "google.golang.org/grpc"
// "google.golang.org/grpc/credentials/insecure"
// "google.golang.org/grpc/keepalive"
// pyprocv1 "github.com/YuminosukeSato/pyproc/api/v1"
// "github.com/YuminosukeSato/pyproc/internal/protocol"
)
// GRPCTransport implements Transport using gRPC
// TODO: Uncomment when gRPC implementation is ready
/*
type GRPCTransport struct {
config TransportConfig
logger *Logger
conn *grpc.ClientConn
client pyprocv1.PyProcServiceClient
mu sync.RWMutex
closed bool
healthy bool
}
*/
// NewGRPCTransport creates a new gRPC transport.
//
// The gRPC code path is not implemented yet, so this constructor always
// returns an error regardless of its arguments.
func NewGRPCTransport(_ TransportConfig, _ *Logger) (Transport, error) {
	return nil, fmt.Errorf("gRPC transport is not yet implemented")
	// Original implementation commented out for future use:
	/*
		if config.Address == "" {
			return nil, fmt.Errorf("address is required for gRPC transport")
		}
		transport := &GRPCTransport{
			config:  config,
			logger:  logger,
			healthy: false,
		}
		// Connect to gRPC server
		if err := transport.connect(); err != nil {
			return nil, err
		}
		return transport, nil
	*/
}
/*
// connect establishes the gRPC connection
func (t *GRPCTransport) connect() error {
t.mu.Lock()
defer t.mu.Unlock()
// Close existing connection if any
if t.conn != nil {
_ = t.conn.Close()
}
// Configure gRPC options
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 10 * time.Second,
Timeout: 3 * time.Second,
PermitWithoutStream: true,
}),
}
// Determine target based on transport type
var target string
switch t.config.Type {
case "grpc-tcp":
target = t.config.Address
case "grpc-uds":
target = "unix://" + t.config.Address
default:
return fmt.Errorf("unsupported gRPC transport type: %s", t.config.Type)
}
// grpc.NewClient doesn't use context for initial connection
// The connection is established lazily on first RPC
conn, err := grpc.NewClient(target, opts...)
if err != nil {
return fmt.Errorf("failed to connect to %s: %w", target, err)
}
t.conn = conn
t.client = pyprocv1.NewPyProcServiceClient(conn)
t.healthy = true
t.logger.Debug("gRPC transport connected", "address", t.config.Address, "type", t.config.Type)
return nil
}
*/
/*
// Call sends a request and receives a response via gRPC
func (t *GRPCTransport) Call(ctx context.Context, req *protocol.Request) (*protocol.Response, error) {
t.mu.RLock()
client := t.client
closed := t.closed
t.mu.RUnlock()
if closed {
return nil, fmt.Errorf("transport is closed")
}
if client == nil {
return nil, fmt.Errorf("gRPC client not initialized")
}
// Create gRPC request using the already marshaled Body
grpcReq := &pyprocv1.CallRequest{
Id: req.ID,
Method: req.Method,
Input: req.Body,
}
// Make gRPC call
grpcResp, err := client.Call(ctx, grpcReq)
if err != nil {
t.mu.Lock()
t.healthy = false
t.mu.Unlock()
return nil, fmt.Errorf("gRPC call failed: %w", err)
}
// Convert gRPC response to protocol.Response
resp := &protocol.Response{
ID: grpcResp.Id,
OK: grpcResp.Ok,
Body: grpcResp.Body,
ErrorMsg: grpcResp.ErrorMessage,
}
t.mu.Lock()
t.healthy = true
t.mu.Unlock()
return resp, nil
}
// Close closes the gRPC connection
func (t *GRPCTransport) Close() error {
t.mu.Lock()
defer t.mu.Unlock()
if t.closed {
return nil
}
t.closed = true
t.healthy = false
if t.conn != nil {
err := t.conn.Close()
t.conn = nil
t.client = nil
return err
}
return nil
}
// IsHealthy checks if the transport is healthy
func (t *GRPCTransport) IsHealthy() bool {
t.mu.RLock()
defer t.mu.RUnlock()
if t.closed || t.conn == nil || t.client == nil {
return false
}
// Optionally perform a health check RPC
if healthCheckEnabled, ok := t.config.Options["health_check"].(bool); ok && healthCheckEnabled {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
_, err := t.client.HealthCheck(ctx, &pyprocv1.HealthCheckRequest{})
if err != nil {
return false
}
}
return t.healthy
}
*/
package pyproc
import (
"context"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/YuminosukeSato/pyproc/internal/framing"
"github.com/YuminosukeSato/pyproc/internal/protocol"
)
// MultiplexedTransport implements Transport with request multiplexing:
// many concurrent Call invocations share a single connection, and a
// background reader goroutine routes each response to its caller by
// request ID.
type MultiplexedTransport struct {
	config TransportConfig // transport configuration (address, options)
	logger *Logger
	conn   net.Conn
	framer *framing.Framer
	// Request tracking
	requestID atomic.Uint64              // monotonically increasing request ID generator
	pending   map[uint64]*pendingRequest // in-flight requests keyed by request ID
	mu        sync.RWMutex               // guards pending
	writeMu   sync.Mutex                 // serializes frame writes on conn
	// Connection state
	closed    atomic.Bool   // set once the transport is shut down
	closeOnce sync.Once     // ensures Close's teardown runs exactly once
	closeCh   chan struct{} // closed to signal shutdown to readLoop
	// Reader goroutine
	readerWg sync.WaitGroup // tracks the readLoop goroutine for Close to wait on
}
// pendingRequest tracks an in-flight request. Both channels are buffered
// (capacity 1) so the reader goroutine never blocks delivering a result.
type pendingRequest struct {
	id          uint64                  // request ID this entry is registered under
	responseCh  chan *protocol.Response // receives the matched response
	errCh       chan error              // receives a connection/close error
	cleanupOnce sync.Once               // ensures the pending-map entry is removed once
}
// cleanup removes this request from the transport's pending map. It is
// idempotent: repeated calls (caller defer, error paths, Close) are no-ops
// after the first.
func (p *pendingRequest) cleanup(t *MultiplexedTransport) {
	p.cleanupOnce.Do(func() {
		t.mu.Lock()
		defer t.mu.Unlock()
		delete(t.pending, p.id)
	})
}
// marshalRequest serializes a request for the wire. Declared as a package
// variable so tests can substitute a failing or instrumented implementation.
var marshalRequest = func(req *protocol.Request) ([]byte, error) {
	return req.Marshal()
}
// NewMultiplexedTransport creates a multiplexed transport, dials the
// configured Unix socket, and starts the background reader goroutine.
func NewMultiplexedTransport(config TransportConfig, logger *Logger) (*MultiplexedTransport, error) {
	if config.Address == "" {
		return nil, fmt.Errorf("address is required for multiplexed transport")
	}
	t := &MultiplexedTransport{
		config:  config,
		logger:  logger,
		pending: make(map[uint64]*pendingRequest),
		closeCh: make(chan struct{}),
	}
	if err := t.connect(); err != nil {
		return nil, err
	}
	// The reader goroutine owns all reads from the connection; Close waits
	// on readerWg for it to exit.
	t.readerWg.Add(1)
	go t.readLoop()
	return t, nil
}
// connect dials the configured Unix socket, honoring the optional "timeout"
// option (default 5s), and wires up the framer.
func (t *MultiplexedTransport) connect() error {
	dialTimeout := 5 * time.Second
	if v, ok := t.config.Options["timeout"].(time.Duration); ok {
		dialTimeout = v
	}
	conn, err := net.DialTimeout("unix", t.config.Address, dialTimeout)
	if err != nil {
		return fmt.Errorf("failed to connect to %s: %w", t.config.Address, err)
	}
	t.conn = conn
	t.framer = framing.NewFramer(conn)
	t.logger.Debug("multiplexed transport connected", "address", t.config.Address)
	return nil
}
// readLoop continuously reads response frames from the connection and routes
// each one to the pending request with the matching request ID. It runs in a
// dedicated goroutine (started by NewMultiplexedTransport) and exits either
// when the transport is closed or on the first unrecoverable read error.
func (t *MultiplexedTransport) readLoop() {
	defer t.readerWg.Done()
	for {
		// Exit promptly if shutdown has been signalled between reads.
		select {
		case <-t.closeCh:
			return
		default:
		}
		// Read a frame; this blocks until data arrives or the conn fails.
		frame, err := t.framer.ReadFrame()
		if err != nil {
			if t.closed.Load() {
				return // Expected on shutdown
			}
			t.logger.Error("failed to read frame", "error", err)
			t.handleReadError(err)
			return
		}
		// Parse response; a malformed payload is logged and skipped so one
		// bad frame does not take down the whole connection.
		var resp protocol.Response
		if err := resp.Unmarshal(frame.Payload); err != nil {
			t.logger.Error("failed to unmarshal response", "error", err)
			continue
		}
		// Prefer request ID from payload; fall back to frame header if present.
		if resp.ID == 0 && frame.Header.RequestID != 0 {
			resp.ID = frame.Header.RequestID
		}
		// Find pending request under the read lock.
		t.mu.RLock()
		pending, ok := t.pending[resp.ID]
		t.mu.RUnlock()
		if !ok {
			t.logger.Warn("received response for unknown request", "id", resp.ID)
			continue
		}
		// Deliver response without blocking; responseCh is buffered (size 1),
		// so the default branch only fires if the caller has already gone.
		select {
		case pending.responseCh <- &resp:
		default:
			// Caller may have already timed out or exited
		}
	}
}
// handleReadError fails all in-flight requests with the given connection
// error and marks the transport closed. Called from readLoop when a frame
// read fails while the transport is not already shutting down.
func (t *MultiplexedTransport) handleReadError(err error) {
	// Snapshot the pending set under the read lock; cleanup below takes the
	// write lock, so we must not hold the read lock while notifying.
	t.mu.RLock()
	pendingList := make([]*pendingRequest, 0, len(t.pending))
	for _, pending := range t.pending {
		pendingList = append(pendingList, pending)
	}
	t.mu.RUnlock()
	// Notify all pending requests of the error (non-blocking; errCh is
	// buffered, so default only fires if a result was already delivered).
	for _, pending := range pendingList {
		select {
		case pending.errCh <- fmt.Errorf("connection error: %w", err):
		default:
		}
		pending.cleanup(t)
	}
	// Close the transport.
	// NOTE(review): this check-then-close of closeCh is not atomic. If
	// Close() runs concurrently, both paths could observe the channel open
	// and attempt close(), which would panic — consider routing both through
	// a shared sync.Once. Verify whether the call paths make this reachable.
	t.closed.Store(true)
	select {
	case <-t.closeCh:
	default:
		close(t.closeCh)
	}
}
// Call sends a request and receives a response. It assigns a fresh request
// ID, registers the call in the pending map so readLoop can route the
// response back, writes the frame, and then waits for the response, a
// connection error, context cancellation, or the transport-level timeout —
// whichever comes first.
func (t *MultiplexedTransport) Call(ctx context.Context, req *protocol.Request) (*protocol.Response, error) {
	if t.closed.Load() {
		return nil, fmt.Errorf("transport is closed")
	}
	start := time.Now()
	// Generate request ID (atomic, so concurrent Calls get distinct IDs).
	requestID := t.requestID.Add(1)
	req.ID = requestID
	// Create pending request; channels are buffered so readLoop never blocks.
	pending := &pendingRequest{
		id:         requestID,
		responseCh: make(chan *protocol.Response, 1),
		errCh:      make(chan error, 1),
	}
	// Register pending request
	t.mu.Lock()
	t.pending[requestID] = pending
	t.mu.Unlock()
	// Clean up on exit (idempotent via cleanupOnce).
	defer pending.cleanup(t)
	// Marshal request
	reqData, err := marshalRequest(req)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}
	// Create and write frame. Check for cancellation once before writing so
	// an already-dead context doesn't cause a pointless write.
	frame := framing.NewFrame(requestID, reqData)
	select {
	case <-ctx.Done():
		return nil, timeoutErrorForContext(ctx, start)
	default:
	}
	// writeMu serializes writes; reads are owned solely by readLoop.
	t.writeMu.Lock()
	err = t.framer.WriteFrame(frame)
	t.writeMu.Unlock()
	if err != nil {
		return nil, fmt.Errorf("failed to write frame: %w", err)
	}
	// Transport-level default timeout, overridable via "request_timeout".
	transportTimeout := 30 * time.Second
	if timeoutVal, ok := t.config.Options["request_timeout"].(time.Duration); ok {
		transportTimeout = timeoutVal
	}
	// Pick the earliest applicable deadline (context vs transport default).
	deadline, kind, hasDeadline := effectiveDeadline(ctx, start, nil, transportTimeout)
	// Only arm a timer when the winning deadline is NOT the context's own —
	// the context case is already covered by the <-ctx.Done() branch below.
	var timerCh <-chan time.Time
	var timer *time.Timer
	if hasDeadline && kind != TimeoutKindContext {
		timeout := timeoutDuration(start, deadline)
		timer = time.NewTimer(timeout)
		timerCh = timer.C
		defer timer.Stop()
	}
	// Wait for response. A nil timerCh blocks forever, effectively disabling
	// that branch when no non-context deadline applies.
	select {
	case resp := <-pending.responseCh:
		return resp, nil
	case err := <-pending.errCh:
		return nil, err
	case <-ctx.Done():
		// NOTE(review): direct == comparison; errors.Is would also match
		// wrapped deadline errors — confirm whether that matters here.
		if ctx.Err() == context.DeadlineExceeded {
			return nil, newTimeoutError(TimeoutKindContext, timeoutDuration(start, deadline), context.DeadlineExceeded)
		}
		return nil, ctx.Err()
	case <-timerCh:
		return nil, newTimeoutError(kind, timeoutDuration(start, deadline), context.DeadlineExceeded)
	}
}
// Close closes the transport: it signals shutdown, closes the connection
// (which unblocks readLoop's pending read), waits for the reader goroutine
// to exit, and fails any requests still in flight. Safe to call multiple
// times; only the first call performs teardown.
func (t *MultiplexedTransport) Close() error {
	var closeErr error
	t.closeOnce.Do(func() {
		t.closed.Store(true)
		// Signal shutdown to readLoop (guarded in case it already closed).
		select {
		case <-t.closeCh:
		default:
			close(t.closeCh)
		}
		// Close connection; this interrupts any blocked ReadFrame.
		if t.conn != nil {
			closeErr = t.conn.Close()
		}
		// Wait for reader to finish
		t.readerWg.Wait()
		// Clean up pending requests: snapshot under read lock, then notify
		// and remove each (cleanup takes the write lock).
		t.mu.RLock()
		pendingList := make([]*pendingRequest, 0, len(t.pending))
		for _, pending := range t.pending {
			pendingList = append(pendingList, pending)
		}
		t.mu.RUnlock()
		for _, pending := range pendingList {
			select {
			case pending.errCh <- fmt.Errorf("transport closed"):
			default:
			}
			pending.cleanup(t)
		}
	})
	return closeErr
}
// IsHealthy reports whether the transport is open and has a connection.
func (t *MultiplexedTransport) IsHealthy() bool {
	if t.closed.Load() {
		return false
	}
	return t.conn != nil
}
package pyproc
import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/YuminosukeSato/pyproc/internal/framing"
"github.com/YuminosukeSato/pyproc/internal/protocol"
)
// UDSTransport implements Transport using Unix Domain Sockets. Calls are
// serialized: the mutex is held for the full write/read round trip, so a
// single UDSTransport handles one request at a time.
type UDSTransport struct {
	config   TransportConfig // transport configuration (address, options)
	logger   *Logger
	conn     net.Conn        // current socket connection (nil when disconnected)
	framer   *framing.Framer // length-prefixed message framing over conn
	codec    Codec           // payload codec selected at construction ("codec" option)
	mu       sync.Mutex      // serializes Call/Close/IsHealthy and guards all fields below
	closed   bool            // set permanently by Close
	healthy  bool            // cleared on I/O errors, restored on (re)connect
	lastUsed time.Time       // last successful round trip; drives idle health checks
}
// NewUDSTransport creates a new UDS transport and eagerly establishes the
// connection. The "codec" option selects the payload codec (JSON by default).
func NewUDSTransport(config TransportConfig, logger *Logger) (*UDSTransport, error) {
	if config.Address == "" {
		return nil, fmt.Errorf("address is required for UDS transport")
	}
	// Resolve the codec, defaulting to JSON.
	codecType := CodecJSON
	if name, ok := config.Options["codec"].(string); ok {
		codecType = CodecType(name)
	}
	codec, err := NewCodec(codecType)
	if err != nil {
		return nil, fmt.Errorf("failed to create codec: %w", err)
	}
	t := &UDSTransport{
		config:  config,
		logger:  logger,
		codec:   codec,
		healthy: false,
	}
	// Fail fast if the worker socket is not reachable.
	if err := t.connect(); err != nil {
		return nil, err
	}
	return t, nil
}
// connect establishes (or re-establishes) the UDS connection, honoring the
// optional "timeout" option (default 5s). It takes the transport lock.
func (t *UDSTransport) connect() error {
	t.mu.Lock()
	defer t.mu.Unlock()
	// Drop any stale connection before dialing a fresh one.
	if t.conn != nil {
		_ = t.conn.Close()
	}
	dialTimeout := 5 * time.Second
	if v, ok := t.config.Options["timeout"].(time.Duration); ok {
		dialTimeout = v
	}
	conn, err := net.DialTimeout("unix", t.config.Address, dialTimeout)
	if err != nil {
		return fmt.Errorf("failed to connect to %s: %w", t.config.Address, err)
	}
	t.conn = conn
	t.framer = framing.NewFramer(conn)
	t.healthy = true
	t.lastUsed = time.Now()
	t.logger.Debug("UDS transport connected", "address", t.config.Address)
	return nil
}
// Call sends a request and receives a response. The transport lock is held
// for the whole round trip, so concurrent callers are serialized on one
// connection. A dead connection is transparently re-dialed first.
func (t *UDSTransport) Call(ctx context.Context, req *protocol.Request) (*protocol.Response, error) {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.closed {
		return nil, fmt.Errorf("transport is closed")
	}
	// Check connection health; reconnect lazily if a previous call failed.
	if !t.healthy || t.conn == nil {
		if err := t.reconnect(); err != nil {
			return nil, fmt.Errorf("failed to reconnect: %w", err)
		}
	}
	// Propagate the context deadline to the socket so blocked reads/writes
	// fail at the right time; the deferred reset clears it for later calls.
	if deadline, ok := ctx.Deadline(); ok {
		if err := t.conn.SetDeadline(deadline); err != nil {
			return nil, fmt.Errorf("failed to set deadline: %w", err)
		}
		defer func() { _ = t.conn.SetDeadline(time.Time{}) }()
	}
	// Send request
	reqData, err := req.Marshal()
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}
	// I/O errors mark the transport unhealthy so the next Call reconnects.
	if err := t.framer.WriteMessage(reqData); err != nil {
		t.healthy = false
		return nil, fmt.Errorf("failed to write request: %w", err)
	}
	// Read response
	respData, err := t.framer.ReadMessage()
	if err != nil {
		t.healthy = false
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	// Unmarshal response
	var resp protocol.Response
	if err := resp.Unmarshal(respData); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}
	t.lastUsed = time.Now()
	return &resp, nil
}
// reconnect re-dials the socket after a failed call. Unlike connect, it
// assumes the caller already holds the transport lock (it is invoked from
// Call, which does).
func (t *UDSTransport) reconnect() error {
	if t.conn != nil {
		_ = t.conn.Close()
		t.conn = nil
	}
	dialTimeout := 5 * time.Second
	if v, ok := t.config.Options["timeout"].(time.Duration); ok {
		dialTimeout = v
	}
	conn, err := net.DialTimeout("unix", t.config.Address, dialTimeout)
	if err != nil {
		return fmt.Errorf("failed to reconnect to %s: %w", t.config.Address, err)
	}
	t.conn = conn
	t.framer = framing.NewFramer(conn)
	t.healthy = true
	t.lastUsed = time.Now()
	t.logger.Debug("UDS transport reconnected", "address", t.config.Address)
	return nil
}
// Close closes the transport connection. It is idempotent: repeated calls
// after the first return nil without touching the connection.
func (t *UDSTransport) Close() error {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.closed {
		return nil
	}
	t.closed = true
	t.healthy = false
	conn := t.conn
	t.conn = nil
	if conn == nil {
		return nil
	}
	return conn.Close()
}
// IsHealthy checks if the transport is healthy.
//
// If the connection has been idle longer than the "idle_timeout" option
// (default 30s), a short ping round trip verifies the connection is still
// alive before health is reported.
func (t *UDSTransport) IsHealthy() bool {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.closed || t.conn == nil {
		return false
	}
	// Check if connection has been idle too long.
	idleTimeout := 30 * time.Second
	if idleVal, ok := t.config.Options["idle_timeout"].(time.Duration); ok {
		idleTimeout = idleVal
	}
	if time.Since(t.lastUsed) > idleTimeout {
		// Try a simple ping to verify connection.
		if err := t.ping(); err != nil {
			t.healthy = false
			return false
		}
		// Fix: record the successful ping as activity. Previously lastUsed
		// was never refreshed here, so once idle, every subsequent health
		// check paid for a full ping round trip.
		t.lastUsed = time.Now()
	}
	return t.healthy
}
// ping sends a health check request: a "health" method call with an empty
// body, bounded by a 1-second socket deadline. Any marshal or I/O error is
// returned; the caller (IsHealthy) treats an error as "unhealthy".
// Assumes the caller holds the transport lock and t.conn is non-nil.
func (t *UDSTransport) ping() error {
	req, err := protocol.NewRequest(0, "health", nil)
	if err != nil {
		return err
	}
	reqData, err := req.Marshal()
	if err != nil {
		return err
	}
	// Set a short timeout for ping; the deferred reset restores the blocking
	// behavior for normal calls afterwards.
	_ = t.conn.SetDeadline(time.Now().Add(1 * time.Second))
	defer func() { _ = t.conn.SetDeadline(time.Time{}) }()
	if err := t.framer.WriteMessage(reqData); err != nil {
		return err
	}
	_, err = t.framer.ReadMessage()
	return err
}
package pyproc
import (
"context"
"fmt"
"net"
"os"
"os/exec"
"sync"
"sync/atomic"
"time"
)
// WorkerState represents the lifecycle state of a worker. It is stored in an
// atomic.Int32, hence the int32 underlying type.
type WorkerState int32

const (
	// WorkerStateStopped indicates the worker is not running
	WorkerStateStopped WorkerState = iota
	// WorkerStateStarting indicates the worker is in the process of starting
	WorkerStateStarting
	// WorkerStateRunning indicates the worker is running and ready to accept requests
	WorkerStateRunning
	// WorkerStateStopping indicates the worker is in the process of stopping
	WorkerStateStopping
)
// WorkerConfig defines configuration for a single worker process.
type WorkerConfig struct {
	ID           string            // worker identifier, used in log context
	SocketPath   string            // Unix socket path the Python worker listens on (exported via PYPROC_SOCKET_PATH)
	PythonExec   string            // Python interpreter to invoke
	WorkerScript string            // path to the Python worker script
	Env          map[string]string // extra environment variables for the child process
	StartTimeout time.Duration     // how long Start waits for the socket to become connectable
}
// Worker represents a single Python worker process: it spawns the process,
// waits for its socket to come up, monitors it for unexpected exits, and
// supports graceful stop/restart.
type Worker struct {
	cfg      WorkerConfig
	logger   *Logger
	cmd      *exec.Cmd    // child process handle; guarded by cmdMu
	cmdMu    sync.RWMutex // guards cmd, waitOnce, waitErr
	waitOnce sync.Once    // ensures cmd.Wait is called exactly once per process
	waitErr  error        // result of cmd.Wait, cached for repeated wait() calls
	state    atomic.Int32 // current WorkerState
	pid      atomic.Int32 // child PID, 0 when not running
	stopCh   chan struct{} // closed by Stop to signal intentional shutdown
	doneCh   chan struct{} // closed by monitor when it exits
}
// NewWorker creates a new worker instance. A nil logger is replaced with a
// default text logger at info level; the worker ID is attached to all of the
// worker's log output.
func NewWorker(cfg WorkerConfig, logger *Logger) *Worker {
	if logger == nil {
		logger = NewLogger(LoggingConfig{Level: "info", Format: "text"})
	}
	w := &Worker{
		cfg:    cfg,
		logger: logger.WithWorker(cfg.ID),
		stopCh: make(chan struct{}),
		doneCh: make(chan struct{}),
	}
	return w
}
// Start starts the worker process: it transitions Stopped -> Starting, spawns
// the Python process with PYPROC_SOCKET_PATH set, polls the Unix socket until
// it is connectable (bounded by StartTimeout and ctx), launches the monitor
// goroutine, and finally transitions to Running. On failure the worker is
// stopped and the error returned.
func (w *Worker) Start(ctx context.Context) error {
	// CAS guards against concurrent/duplicate Start calls.
	if !w.state.CompareAndSwap(int32(WorkerStateStopped), int32(WorkerStateStarting)) {
		return fmt.Errorf("worker already started or starting")
	}
	w.logger.InfoContext(ctx, "Starting worker",
		"socket_path", w.cfg.SocketPath,
		"script", w.cfg.WorkerScript)
	// Reset wait-related fields for new process
	w.cmdMu.Lock()
	w.waitOnce = sync.Once{}
	w.waitErr = nil
	w.cmdMu.Unlock()
	// Clean up any existing socket file so the child can bind it afresh.
	if err := os.Remove(w.cfg.SocketPath); err != nil && !os.IsNotExist(err) {
		w.logger.WarnContext(ctx, "Failed to remove existing socket file",
			"error", err)
	}
	// Create the command
	cmd := exec.CommandContext(ctx, w.cfg.PythonExec, w.cfg.WorkerScript)
	// Set environment variables: inherit ours, add per-worker overrides,
	// and tell the child where to listen.
	cmd.Env = os.Environ()
	for k, v := range w.cfg.Env {
		cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v))
	}
	cmd.Env = append(cmd.Env, fmt.Sprintf("PYPROC_SOCKET_PATH=%s", w.cfg.SocketPath))
	// Capture output for debugging
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	// Start the process
	if err := cmd.Start(); err != nil {
		w.state.Store(int32(WorkerStateStopped))
		return fmt.Errorf("failed to start worker process: %w", err)
	}
	w.cmdMu.Lock()
	w.cmd = cmd
	w.cmdMu.Unlock()
	w.pid.Store(int32(cmd.Process.Pid))
	w.logger.InfoContext(ctx, "Worker process started", "pid", cmd.Process.Pid)
	// Wait for the socket to be available: poll every 100ms until a dial
	// succeeds, StartTimeout elapses, or ctx is cancelled.
	socketReady := make(chan error, 1)
	go func() {
		ticker := time.NewTicker(100 * time.Millisecond)
		defer ticker.Stop()
		timeout := time.After(w.cfg.StartTimeout)
		for {
			select {
			case <-ticker.C:
				// Try to connect to the socket
				conn, err := net.Dial("unix", w.cfg.SocketPath)
				if err == nil {
					_ = conn.Close()
					socketReady <- nil
					return
				}
			case <-timeout:
				socketReady <- fmt.Errorf("worker start timeout after %v", w.cfg.StartTimeout)
				return
			case <-ctx.Done():
				socketReady <- ctx.Err()
				return
			}
		}
	}()
	// Start monitoring goroutine (watches for unexpected process exit).
	go w.monitor()
	// Wait for socket to be ready; on failure, tear the process back down.
	if err := <-socketReady; err != nil {
		if err := w.Stop(); err != nil {
			w.logger.Error("failed to stop worker after socket error", "error", err)
		}
		return err
	}
	w.state.Store(int32(WorkerStateRunning))
	w.logger.InfoContext(ctx, "Worker ready")
	return nil
}
// Stop stops the worker process: it signals intent via stopCh, sends SIGINT
// for a graceful shutdown, escalates to SIGKILL after 5 seconds, removes the
// socket file, and waits for the monitor goroutine to exit. Calling Stop on
// an already stopped/stopping worker is a no-op returning nil.
func (w *Worker) Stop() error {
	// Accept the transition from either Running or Starting; anything else
	// means another Stop already owns the shutdown.
	if !w.state.CompareAndSwap(int32(WorkerStateRunning), int32(WorkerStateStopping)) {
		// Also try from starting state
		if !w.state.CompareAndSwap(int32(WorkerStateStarting), int32(WorkerStateStopping)) {
			return nil // Already stopped or stopping
		}
	}
	w.logger.Info("Stopping worker")
	// Signal stop (tells monitor this exit is intentional).
	close(w.stopCh)
	// Get the command
	w.cmdMu.RLock()
	cmd := w.cmd
	w.cmdMu.RUnlock()
	if cmd != nil && cmd.Process != nil {
		// Try graceful shutdown first
		if err := cmd.Process.Signal(os.Interrupt); err != nil {
			w.logger.Warn("Failed to send interrupt signal", "error", err)
		}
		// Wait for process to exit with timeout
		done := make(chan error, 1)
		go func() {
			done <- w.wait()
		}()
		select {
		case <-done:
			// Process exited gracefully
		case <-time.After(5 * time.Second):
			// Force kill after timeout
			w.logger.Warn("Worker did not exit gracefully, forcing kill")
			if err := cmd.Process.Kill(); err != nil {
				w.logger.Error("Failed to kill worker process", "error", err)
			}
			<-done // Wait for process to be reaped
		}
	}
	// Clean up socket file
	if err := os.Remove(w.cfg.SocketPath); err != nil && !os.IsNotExist(err) {
		w.logger.Warn("Failed to remove socket file", "error", err)
	}
	// Wait for monitor to finish
	<-w.doneCh
	w.state.Store(int32(WorkerStateStopped))
	w.pid.Store(0)
	w.logger.Info("Worker stopped")
	return nil
}
// wait wraps cmd.Wait() to ensure it's called only once per process (calling
// Wait twice on an exec.Cmd is an error). The result is cached in waitErr so
// every caller — Stop's timeout goroutine and monitor — observes the same
// exit status.
func (w *Worker) wait() error {
	w.cmdMu.RLock()
	cmd := w.cmd
	w.cmdMu.RUnlock()
	if cmd != nil {
		w.waitOnce.Do(func() {
			// cmd.Wait blocks until the process exits and reaps it.
			err := cmd.Wait()
			w.cmdMu.Lock()
			w.waitErr = err
			w.cmdMu.Unlock()
		})
	}
	w.cmdMu.RLock()
	err := w.waitErr
	w.cmdMu.RUnlock()
	return err
}
// Restart restarts the worker process by stopping it and starting it again
// with the same configuration.
func (w *Worker) Restart(ctx context.Context) error {
	w.logger.InfoContext(ctx, "Restarting worker")
	if err := w.Stop(); err != nil {
		return fmt.Errorf("failed to stop worker: %w", err)
	}
	// Reset channels for new process
	// Note: We don't reset waitOnce here as it can cause race conditions
	// Each new process will get its own waitOnce via Start()
	// NOTE(review): stopCh/doneCh are replaced without holding a lock;
	// this looks safe only if Restart is never called concurrently with
	// other worker operations — confirm the caller guarantees that.
	w.stopCh = make(chan struct{})
	w.doneCh = make(chan struct{})
	if err := w.Start(ctx); err != nil {
		return fmt.Errorf("failed to start worker: %w", err)
	}
	return nil
}
// monitor monitors the worker process and handles unexpected exits. It runs
// in its own goroutine (started by Start) and closes doneCh on return, which
// Stop uses to know the monitor has finished.
func (w *Worker) monitor() {
	defer close(w.doneCh)
	w.cmdMu.RLock()
	cmd := w.cmd
	w.cmdMu.RUnlock()
	if cmd == nil {
		return
	}
	// Wait for either stop signal or process exit
	waitCh := make(chan error, 1)
	go func() {
		waitCh <- w.wait()
	}()
	select {
	case <-w.stopCh:
		// Normal stop requested; Stop handles the process teardown.
		return
	case err := <-waitCh:
		// Process exited on its own. Only flag it as unexpected if we
		// believed the worker was Running (Stop sets Stopping first).
		if w.state.Load() == int32(WorkerStateRunning) {
			if err != nil {
				w.logger.Error("Worker process exited unexpectedly", "error", err)
			} else {
				w.logger.Warn("Worker process exited unexpectedly with status 0")
			}
			w.state.Store(int32(WorkerStateStopped))
			w.pid.Store(0)
		}
	}
}
// IsRunning returns true if the worker is running. Reads the state
// atomically, so it is safe to call from any goroutine.
func (w *Worker) IsRunning() bool {
	return w.state.Load() == int32(WorkerStateRunning)
}

// GetState returns the current worker state (atomic read).
func (w *Worker) GetState() WorkerState {
	return WorkerState(w.state.Load())
}

// GetPID returns the process ID of the worker, or 0 when no process is running.
func (w *Worker) GetPID() int {
	return int(w.pid.Load())
}

// GetID returns the worker ID from the configuration.
func (w *Worker) GetID() string {
	return w.cfg.ID
}

// GetSocketPath returns the socket path from the configuration.
func (w *Worker) GetSocketPath() string {
	return w.cfg.SocketPath
}
package pyproc
import (
"context"
"fmt"
"net"
"sync/atomic"
"time"
)
// ExternalWorkerState represents the state of an external worker. It is
// stored in an atomic.Int32, hence the int32 underlying type.
type ExternalWorkerState int32

const (
	// ExternalWorkerStopped indicates the external worker is not connected.
	ExternalWorkerStopped ExternalWorkerState = iota
	// ExternalWorkerRunning indicates the external worker is connected and healthy.
	ExternalWorkerRunning
)
// Defaults for ExternalWorkerOptions fields left at their zero value.
const (
	defaultConnectTimeout = 5 * time.Second
	// defaultMaxRetries with defaultRetryInterval yields ~4 min total wait
	// (500ms + 1s + 2s + 4s + 8s + 16s + 32s + 64s + 128s ≈ 255.5s).
	defaultMaxRetries    = 10
	defaultRetryInterval = 500 * time.Millisecond
)
// ExternalWorkerOptions configures an ExternalWorker. Zero values are
// replaced with the package defaults by NewExternalWorkerWithOptions.
type ExternalWorkerOptions struct {
	// SocketPath is the Unix Domain Socket path to connect to.
	SocketPath string
	// ConnectTimeout controls how long each dial attempt waits.
	// If zero, defaults to 5s.
	ConnectTimeout time.Duration
	// MaxRetries is the maximum number of connection retry attempts in Start.
	// If zero, defaults to 10.
	MaxRetries int
	// RetryInterval is the initial interval between retries. Each subsequent
	// retry doubles the interval (exponential backoff).
	// If zero, defaults to 500ms.
	RetryInterval time.Duration
}
// ExternalWorker represents a pre-existing Python worker process managed
// outside of pyproc (e.g. a Kubernetes sidecar container). It connects to the
// worker via a well-known Unix Domain Socket path rather than spawning a child
// process.
type ExternalWorker struct {
	socketPath     string        // UDS path of the externally managed worker
	connectTimeout time.Duration // per-dial timeout
	maxRetries     int           // dial attempts made by Start
	retryInterval  time.Duration // initial backoff between Start attempts (doubles each retry)
	state          atomic.Int32  // current ExternalWorkerState
}
// NewExternalWorker creates a new ExternalWorker that connects to the given
// Unix Domain Socket path. The connectTimeout controls how long dial attempts
// wait; if zero, a default of 5 s is used. For retry support, use
// NewExternalWorkerWithOptions instead.
func NewExternalWorker(socketPath string, connectTimeout time.Duration) *ExternalWorker {
	opts := ExternalWorkerOptions{
		SocketPath:     socketPath,
		ConnectTimeout: connectTimeout,
		MaxRetries:     1, // no retry for backward compat
	}
	return NewExternalWorkerWithOptions(opts)
}
// NewExternalWorkerWithOptions creates a new ExternalWorker from options,
// filling in package defaults for any non-positive fields.
func NewExternalWorkerWithOptions(opts ExternalWorkerOptions) *ExternalWorker {
	w := &ExternalWorker{
		socketPath:     opts.SocketPath,
		connectTimeout: opts.ConnectTimeout,
		maxRetries:     opts.MaxRetries,
		retryInterval:  opts.RetryInterval,
	}
	if w.connectTimeout <= 0 {
		w.connectTimeout = defaultConnectTimeout
	}
	if w.maxRetries <= 0 {
		w.maxRetries = defaultMaxRetries
	}
	if w.retryInterval <= 0 {
		w.retryInterval = defaultRetryInterval
	}
	return w
}
// Start verifies that the external worker's socket is reachable. It retries
// with exponential backoff according to the configured MaxRetries and
// RetryInterval. It does not spawn a process; the worker must already be
// running.
//
// In production, callers should pass a context with a deadline to bound the
// total wait time (e.g. context.WithTimeout). Without a deadline, Start may
// block for the full backoff duration (~4 min with defaults).
func (w *ExternalWorker) Start(ctx context.Context) error {
	var lastErr error
	backoff := w.retryInterval
	for attempt := 0; attempt < w.maxRetries; attempt++ {
		// Sleep before every attempt except the first, doubling each time.
		if attempt > 0 {
			select {
			case <-ctx.Done():
				return fmt.Errorf("external worker connection cancelled: %w", ctx.Err())
			case <-time.After(backoff):
			}
			backoff *= 2
		}
		conn, err := net.DialTimeout("unix", w.socketPath, w.connectTimeout)
		if err != nil {
			lastErr = err
			continue
		}
		// Reachable — we only needed to prove connectivity.
		_ = conn.Close()
		w.state.Store(int32(ExternalWorkerRunning))
		return nil
	}
	return fmt.Errorf("external worker socket unreachable at %s after %d attempts: %w",
		w.socketPath, w.maxRetries, lastErr)
}
// Stop transitions the external worker to the stopped state. It does not
// terminate the remote process since pyproc does not own it.
func (w *ExternalWorker) Stop() error {
	// Only the local state flag changes; no connection or process to tear down.
	w.state.Store(int32(ExternalWorkerStopped))
	return nil
}
// IsHealthy returns true if the external worker's socket accepts a
// connection right now. Each call performs a fresh dial (and closes it
// immediately); it does not consult or update the stored state.
func (w *ExternalWorker) IsHealthy(_ context.Context) bool {
	if conn, err := net.DialTimeout("unix", w.socketPath, w.connectTimeout); err == nil {
		_ = conn.Close()
		return true
	}
	return false
}
// GetSocketPath returns the Unix Domain Socket path for this worker.
func (w *ExternalWorker) GetSocketPath() string {
	return w.socketPath
}