Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions Sources/MySQLDriver/Connection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,20 @@ public extension MySQL.Connection {
var pos = 0
msh.proto_version = data[pos]
pos += 1
msh.server_version = data[pos..<data.count].string()
pos += (msh.server_version?.utf8.count)! + 1
msh.conn_id = data[pos...pos+4].uInt32()
let version_len = data[pos...].firstIndex(of: 0)! - pos
msh.server_version = data[pos...pos+version_len].string()
pos += version_len + 1// add null string end char
msh.conn_id = data[pos..<pos+4].uInt32()
pos += 4
msh.scramble = Array(data[pos..<pos+8])
pos += 8 + 1
msh.cap_flags = data[pos...pos+2].uInt16()
msh.cap_flags = data[pos..<pos+2].uInt16()
pos += 2
msh.server_lang = data[pos];
pos += 1;
msh.server_status = data[pos...pos+2].uInt16()
msh.server_status = data[pos..<pos+2].uInt16()
pos += 2
msh.ext_cap_flags = data[pos...pos+2].uInt16()
msh.ext_cap_flags = data[pos..<pos+2].uInt16()
pos += 2
let auth_len = Int(data[pos]);
pos += 1;
Expand Down Expand Up @@ -135,7 +136,7 @@ public extension MySQL.Connection {
}
if self.mysql_authSwitch?.auth_name == "caching_sha2_password" {

var token = MySQL.Utils.calculateToken(self.passwd!, scramble: self.mysql_Handshake!.scramble!);
let token = MySQL.Utils.calculateToken(self.passwd!, scramble: self.mysql_Handshake!.scramble!);

try socket?.writePacket(token)

Expand All @@ -147,10 +148,10 @@ public extension MySQL.Connection {
}

}

private func readAuthSwitch() throws -> MySQL.mysql_auth_switch {
var msh = MySQL.mysql_auth_switch()

if let data = try socket?.readPacket() {
var pos = 0;
msh.status = data[pos]
Expand All @@ -164,7 +165,6 @@ public extension MySQL.Connection {
}
return msh;
}


private func readAuthResponse() throws {
if let data = try socket?.readPacket() {
Expand All @@ -179,7 +179,7 @@ public extension MySQL.Connection {
}

private func requestServerKey() throws {
var arr:[UInt8] = [2];
let arr:[UInt8] = [2];
try socket?.writePacket(arr);

try readServerKey();
Expand All @@ -192,7 +192,7 @@ public extension MySQL.Connection {

let serverPublicKey = Array(data[1..<data.count]);

var epwd = MySQL.Utils.encPasswd(self.passwd!, scramble: self.mysql_Handshake!.scramble!, key:serverPublicKey);
let epwd = MySQL.Utils.encPasswd(self.passwd!, scramble: self.mysql_Handshake!.scramble!, key:serverPublicKey);

try socket?.writePacket(epwd)

Expand Down Expand Up @@ -261,23 +261,20 @@ public extension MySQL.Connection {
}
arr.append(0)

arr.append(contentsOf:"mysql_native_password".utf8)
arr.append(contentsOf: mysql_Handshake!.auth_plugin!.utf8)
arr.append(0)

try socket?.writePacket(arr)

if mysql_Handshake?.auth_plugin == "mysql_native_password" {
// don't do auth switch
}
else {
self.mysql_authSwitch = try self.readAuthSwitch()
try self.sendAuthSwitchResponse();
}


}


func readColumns(_ count:Int) throws ->[Field]? {

self.columns = [Field](repeating:Field(), count: count)
Expand Down
46 changes: 13 additions & 33 deletions Sources/MySQLDriver/Socket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ extension Socket {
open class Socket {

let s : Int32
var bytesToRead : UInt32
var packnr : Int
var socketInUse = false
var addr : sockaddr_in?
Expand All @@ -44,7 +43,6 @@ open class Socket {

init(host : String, port : Int) throws {
// create socket to MySQL Server
bytesToRead = 0
packnr = 0
#if os(Linux)
s = socket(AF_INET, Int32(SOCK_STREAM.rawValue), 0)
Expand Down Expand Up @@ -139,20 +137,25 @@ open class Socket {
return String(cString:UnsafePointer(strerror(errno))) //?? "Error: \(errno)"
}

func readNUInt8(_ n:UInt32) throws -> [UInt8] {
func readNUInt8(_ n: UInt32) throws -> [UInt8] {
var buffer = [UInt8](repeating: 0, count: Int(n))
var read = 0

while read < Int(n) {
read += recv(s, &buffer[read], Int(n) - read, 0)
let result = recv(s, &buffer[read], Int(n) - read, 0)

if read <= 0 {
if result < 0 {
throw SocketError.recvFailed(Socket.descriptionOfLastError())
} else if result == 0 {
// Connection closed by peer, handle gracefully
if read == 0 {
throw SocketError.recvFailed("Connection closed by peer with no data read")
} else {
break // Exit the loop if some data has been read
}
}
}

if bytesToRead >= UInt32(n) {
bytesToRead -= UInt32(n)

read += result
}

return buffer
Expand All @@ -161,23 +164,19 @@ open class Socket {

func readHeader() throws -> (UInt32, Int) {
let b = try readNUInt8(3).uInt24()

let pn = try readNUInt8(1)[0]
bytesToRead = b

return (b, Int(pn))
}

func readPacket() throws -> [UInt8] {
let (len, pknr) = try readHeader()
bytesToRead = len
self.packnr = pknr
return try readNUInt8(len)
}

func writePacket(_ data:[UInt8]) throws {
try writeHeader(UInt32(data.count), pn: UInt8(self.packnr + 1))
try writeBuffer(data)
try writeBuffer(data)
}

func writeBuffer(_ buffer:[UInt8]) throws {
Expand All @@ -192,7 +191,6 @@ open class Socket {
#endif
if s <= 0 {
throw SocketError.writeFailed(Socket.descriptionOfLastError())

}
sent += s
}
Expand All @@ -213,22 +211,4 @@ open class Socket {
return isLittleEndian ? _OSSwapInt16(port) : port
#endif
}

/*
func lockSocket() {
while socketInUse {

}
socketInUse = true
}

func unlockSocket() {
socketInUse = false
}

func socketLocked() -> Bool {
return socketInUse
}
*/

}
34 changes: 21 additions & 13 deletions Sources/MySQLDriver/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -358,15 +358,10 @@ extension MySQL {
[kSecAttrKeyType: kSecAttrKeyTypeRSA,
kSecAttrKeyClass: kSecAttrKeyClassPublic] as CFDictionary, nil)

let key_size = SecKeyGetBlockSize(publickeysi!)

var encrypt_bytes = [UInt8](repeating: 0, count: key_size)

var output_size : Int = key_size

SecKeyEncrypt(publickeysi!, SecPadding.OAEP, data, data.count, &encrypt_bytes, &output_size)

return encrypt_bytes;
var unmanagedError: Unmanaged<CFError>?
let encrypt_data = SecKeyCreateEncryptedData(publickeysi!, .rsaEncryptionOAEPSHA1AESGCM, Data(data) as CFData, &unmanagedError)! as Data

return Array(encrypt_data)
}
return [];
}
Expand Down Expand Up @@ -773,12 +768,25 @@ extension Sequence where Iterator.Element == UInt8 {
let arr = self.map { (elem) -> UInt8 in
return elem
}

guard (arr.count > 0) && (arr[arr.count-1] == 0) else {
return ""

// Verificar que el array no esté vacío
guard !arr.isEmpty else {
return nil
}

// Verificar que termine en null terminator
guard arr.last == 0 else {
// Si no termina en null, convertir directamente
return String(bytes: self, encoding: .utf8)
}

return String(cString: UnsafePointer<UInt8>(arr))
// Crear string de forma segura
return arr.withUnsafeBufferPointer { bufferPointer in
guard let baseAddress = bufferPointer.baseAddress else {
return nil
}
return String(cString: baseAddress)
}
}

static func UInt24Array(_ val: UInt32) -> [UInt8]{
Expand Down