diff --git a/hypersock_http/client.odin b/hypersock_http/client.odin index 83b81b9..ed460b1 100644 --- a/hypersock_http/client.odin +++ b/hypersock_http/client.odin @@ -11,36 +11,35 @@ package hypersock_http * - Concurrent-safe operations */ +import "core:fmt" import "core:net" import "core:os" -import "core:fmt" -import "core:strings" import "core:strconv" -import "core:time" +import "core:strings" import "core:sync" -import "core:mem" +import "core:time" // Simple HTTP GET request -get :: proc(url: string) -> (int, []byte, os.Errno) { +get :: proc(url: string) -> (int, []byte, os.Error) { return client_get(client_default(), url) } // Simple HTTP POST request -post :: proc(url: string, body: []byte) -> (int, []byte, os.Errno) { +post :: proc(url: string, body: []byte) -> (int, []byte, os.Error) { return client_post(client_default(), url, body) } // HTTP GET with custom client -client_get :: proc(c: ^Client, url: string) -> (int, []byte, os.Errno) { +client_get :: proc(c: ^Client, url: string) -> (int, []byte, os.Error) { // Parse URL uri, ok := uri_parse(url) if !ok { - return 0, nil, os.EINVAL + return 0, nil, invalid_parameter_error() } - + // Get or create host client host_key := fmt.tprintf("%s:%d", uri.host, uri.port) - + sync.rw_mutex_lock(&c.mutex) hc, exists := c.host_clients[host_key] if !exists { @@ -48,42 +47,42 @@ client_get :: proc(c: ^Client, url: string) -> (int, []byte, os.Errno) { c.host_clients[host_key] = hc } sync.rw_mutex_unlock(&c.mutex) - + // Build request req: Request request_reset(&req) req.method = .GET req.uri = uri - + // Set default headers header_set(&req.header, "Host", uri.host) header_set(&req.header, "User-Agent", c.name) header_set(&req.header, "Accept", "*/*") header_set(&req.header, "Connection", "keep-alive") - + // Execute request resp: Response response_reset(&resp) - + err := host_client_do(hc, &req, &resp) if err != os.ERROR_NONE { return 0, nil, err } - + return resp.status_code, resp.body, os.ERROR_NONE } // HTTP POST with custom client -client_post :: proc(c: ^Client, url: string, body: []byte) -> (int, []byte, os.Errno) { +client_post :: proc(c: ^Client, url: string, body: []byte) -> (int, []byte, os.Error) { // Parse URL uri, ok := uri_parse(url) if !ok { - return 0, nil, os.EINVAL + return 0, nil, invalid_parameter_error() } - + // Get or create host client host_key := fmt.tprintf("%s:%d", uri.host, uri.port) - + sync.rw_mutex_lock(&c.mutex) hc, exists := c.host_clients[host_key] if !exists { @@ -91,14 +90,14 @@ client_post :: proc(c: ^Client, url: string, body: []byte) -> (int, []byte, os.E c.host_clients[host_key] = hc } sync.rw_mutex_unlock(&c.mutex) - + // Build request req: Request request_reset(&req) req.method = .POST req.uri = uri req.body = body - + // Set default headers header_set(&req.header, "Host", uri.host) header_set(&req.header, "User-Agent", c.name) @@ -106,16 +105,16 @@ client_post :: proc(c: ^Client, url: string, body: []byte) -> (int, []byte, os.E header_set(&req.header, "Content-Length", fmt.tprintf("%d", len(body))) header_set(&req.header, "Accept", "*/*") header_set(&req.header, "Connection", "keep-alive") - + // Execute request resp: Response response_reset(&resp) - + err := host_client_do(hc, &req, &resp) if err != os.ERROR_NONE { return 0, nil, err } - + return resp.status_code, resp.body, os.ERROR_NONE } @@ -137,14 +136,14 @@ host_client_new :: proc(client: ^Client, host: string, port: int, is_tls: bool) } // Execute request on host client -host_client_do :: proc(hc: ^HostClient, req: ^Request, resp: ^Response) -> os.Errno { +host_client_do :: proc(hc: ^HostClient, req: ^Request, resp: ^Response) -> os.Error { // Acquire connection from pool cc, err := acquire_conn(hc) if err != os.ERROR_NONE { return err } defer release_conn(hc, cc) - + // Set timeouts if configured if hc.write_timeout > 0 { net.set_option(cc.conn, net.Socket_Option.Send_Timeout, int(hc.write_timeout)) @@ -152,39 +151,39 @@ host_client_do :: proc(hc: ^HostClient, req: ^Request, resp: ^Response) -> os.Er if hc.read_timeout > 0 { net.set_option(cc.conn, net.Socket_Option.Receive_Timeout, int(hc.read_timeout)) } - + // Build and send HTTP request - err = write_request(cc.conn, req) + err = write_request(cc, req) if err != os.ERROR_NONE { // Connection is bad, don't return it to pool cc.conn = 0 return err } - + // Read HTTP response - err = read_response(cc.conn, resp, hc.max_response_body_size) + err = read_response(cc, resp, hc.max_response_body_size) if err != os.ERROR_NONE { // Connection is bad, don't return it to pool cc.conn = 0 return err } - + // Update last use time cc.last_use = time.now() - + return os.ERROR_NONE } // Acquire a connection from the pool -acquire_conn :: proc(hc: ^HostClient) -> (^clientConn, os.Errno) { +acquire_conn :: proc(hc: ^HostClient) -> (^clientConn, os.Error) { sync.mutex_lock(&hc.mutex) defer sync.mutex_unlock(&hc.mutex) - + // Try to find an idle connection now := time.now() for i := len(hc.conns) - 1; i >= 0; i -= 1 { cc := hc.conns[i] - + // Check if connection is still alive if cc.conn != 0 { // Remove from slice @@ -198,47 +197,48 @@ acquire_conn :: proc(hc: ^HostClient) -> (^clientConn, os.Errno) { free(cc) } } - + // Check if we can create a new connection if hc.conns_count >= hc.max_conns { - return nil, os.ENOENT // Too many open files + return nil, not_found_error() // Too many open files } - + // Create new connection socket, net_err := net.dial_tcp(hc.addr) if net_err != nil { - return nil, os.ECONNREFUSED + return nil, connection_refused_error() } - + // Perform TLS handshake if needed tls_socket: ^TLS_Socket is_tls_conn: bool = false - + if hc.is_tls { // Extract server name from host configuration server_name := hc.addr if idx := strings.index(server_name, ":"); idx != -1 { server_name = server_name[:idx] } - + // Check if client has insecure skip verify insecure_skip := false if hc.client.tls_config != nil { insecure_skip = hc.client.tls_config.insecure_skip_verify } - + // Perform TLS handshake - tls_socket, tls_err := perform_tls_handshake_on_socket(socket, server_name, insecure_skip) + tls_err: os.Error + tls_socket, tls_err = perform_tls_handshake_on_socket(socket, server_name, insecure_skip) if tls_err != os.ERROR_NONE { net.close(socket) fmt.println("TLS handshake failed:", tls_err) return nil, tls_err } - + is_tls_conn = true - fmt.println("TLS handshake completed successfully") + // fmt.println("TLS handshake completed successfully") } - + cc := new(clientConn) cc.conn = socket cc.tls_socket = tls_socket @@ -246,7 +246,7 @@ acquire_conn :: proc(hc: ^HostClient) -> (^clientConn, os.Errno) { cc.created = now cc.last_use = now hc.conns_count += 1 - + return cc, os.ERROR_NONE } @@ -265,9 +265,9 @@ release_conn :: proc(hc: ^HostClient, cc: ^clientConn) { free(cc) return } - + sync.mutex_lock(&hc.mutex) - + // Check if we should keep this connection if hc.conns_count > hc.max_conns { // Too many connections, close this one @@ -283,7 +283,7 @@ release_conn :: proc(hc: ^HostClient, cc: ^clientConn) { cc.last_use = time.now() append(&hc.conns, cc) } - + sync.mutex_unlock(&hc.mutex) } @@ -291,7 +291,7 @@ release_conn :: proc(hc: ^HostClient, cc: ^clientConn) { host_client_close :: proc(hc: ^HostClient) { sync.mutex_lock(&hc.mutex) defer sync.mutex_unlock(&hc.mutex) - + for cc in hc.conns { if cc.is_tls && cc.tls_socket != nil { tls_close(cc.tls_socket) @@ -306,66 +306,95 @@ host_client_close :: proc(hc: ^HostClient) { } // Write HTTP request to connection -write_request :: proc(conn: net.TCP_Socket, req: ^Request) -> os.Errno { +write_request :: proc(cc: ^clientConn, req: ^Request) -> os.Error { // Build HTTP request request := strings.builder_make() defer strings.builder_destroy(&request) - + // Request line - fmt.sbprintf(&request, "%s %s HTTP/1.1\r\n", method_to_string(req.method), req.uri.path) - + request_target := req.uri.path + if request_target == "" { + request_target = "/" + } + if req.uri.query != "" { + request_target = fmt.tprintf("%s?%s", request_target, req.uri.query) + } + fmt.sbprintf(&request, "%s %s HTTP/1.1\r\n", method_to_string(req.method), request_target) + // Headers sync.mutex_lock(&req.header.mutex) - for key, value in req.header.data { - fmt.sbprintf(&request, "%s: %s\r\n", key, value) + for key, values in req.header.data { + for value in values { + fmt.sbprintf(&request, "%s: %s\r\n", key, value) + } } sync.mutex_unlock(&req.header.mutex) - + if len(req.body) > 0 && !header_has(&req.header, "Content-Length") { + fmt.sbprintf(&request, "content-length: %d\r\n", len(req.body)) + } + // Empty line fmt.sbprintf(&request, "\r\n") - + // Body if len(req.body) > 0 { fmt.sbprintf(&request, "%s", string(req.body)) } - + // Send request data := transmute([]byte)strings.to_string(request) - _, send_err := net.send_tcp(conn, data) - if send_err != nil { - return os.ECONNREFUSED + if cc.is_tls && cc.tls_socket != nil { + _, send_err := tls_write(cc.tls_socket, data) + if send_err != os.ERROR_NONE { + return send_err + } + } else { + _, send_err := net.send_tcp(cc.conn, data) + if send_err != nil { + return connection_refused_error() + } } return os.ERROR_NONE } // Read HTTP response from connection -read_response :: proc(conn: net.TCP_Socket, resp: ^Response, max_body_size: int) -> os.Errno { +read_response :: proc(cc: ^clientConn, resp: ^Response, max_body_size: int) -> os.Error { // Read response into buffer buf := make([]byte, 65536) defer delete(buf) - - n, recv_err := net.recv_tcp(conn, buf) - if recv_err != nil { - return os.ECONNREFUSED + + n := 0 + if cc.is_tls && cc.tls_socket != nil { + recv_n, recv_err := tls_read(cc.tls_socket, buf) + if recv_err != os.ERROR_NONE { + return recv_err + } + n = recv_n + } else { + recv_n, recv_err := net.recv_tcp(cc.conn, buf) + if recv_err != nil { + return connection_refused_error() + } + n = recv_n } if n == 0 { - return os.ECONNREFUSED + return connection_refused_error() } - + response := string(buf[:n]) - + // Parse status line lines := strings.split(response, "\r\n") defer delete(lines) - + if len(lines) < 1 { - return os.EINVAL + return invalid_parameter_error() } - + // Parse status code using strconv parts := strings.split(lines[0], " ") defer delete(parts) - + if len(parts) >= 2 { // Parse status code as integer if code, ok := strconv.parse_int(parts[1], 10); ok { @@ -374,7 +403,7 @@ read_response :: proc(conn: net.TCP_Socket, resp: ^Response, max_body_size: int) resp.status_code = 200 // Default on parse error } } - + // Parse headers header_end := 0 for i := 1; i < len(lines); i += 1 { @@ -382,48 +411,50 @@ read_response :: proc(conn: net.TCP_Socket, resp: ^Response, max_body_size: int) header_end = i break } - + colon_idx := strings.index(lines[i], ":") if colon_idx != -1 { key := strings.to_lower(strings.trim_space(lines[i][:colon_idx])) - value := strings.trim_space(lines[i][colon_idx+1:]) + value := strings.trim_space(lines[i][colon_idx + 1:]) header_set(&resp.header, key, value) } } - + // Parse body if header_end > 0 && header_end + 1 < len(lines) { body_start := strings.index(response, "\r\n\r\n") if body_start != -1 { body_start += 4 body := buf[body_start:n] - + if max_body_size > 0 && len(body) > max_body_size { - return os.EINVAL // Buffer too small + return invalid_parameter_error() // Buffer too small } - + resp.body = make([]byte, len(body)) copy(resp.body, body) } } - + return os.ERROR_NONE } // Do performs an HTTP request with full control -do_request :: proc(c: ^Client, req: ^Request, resp: ^Response) -> os.Errno { +do_request :: proc(c: ^Client, req: ^Request, resp: ^Response) -> os.Error { // Parse URI if needed if req.uri.host == "" { - uri, ok := uri_parse(fmt.tprintf("%s://%s%s", req.uri.scheme, req.header.data["Host"], req.uri.path)) + uri, ok := uri_parse( + fmt.tprintf("%s://%s%s", req.uri.scheme, req.header.data["Host"], req.uri.path), + ) if !ok { - return os.EINVAL + return invalid_parameter_error() } req.uri = uri } - + // Get or create host client host_key := fmt.tprintf("%s:%d", req.uri.host, req.uri.port) - + sync.rw_mutex_lock(&c.mutex) hc, exists := c.host_clients[host_key] if !exists { @@ -431,77 +462,83 @@ do_request :: proc(c: ^Client, req: ^Request, resp: ^Response) -> os.Errno { c.host_clients[host_key] = hc } sync.rw_mutex_unlock(&c.mutex) - + return host_client_do(hc, req, resp) } // DoWithRedirects performs an HTTP request following redirects -do_with_redirects :: proc(c: ^Client, req: ^Request, resp: ^Response, max_redirects: int) -> os.Errno { +do_with_redirects :: proc( + c: ^Client, + req: ^Request, + resp: ^Response, + max_redirects: int, +) -> os.Error { // Execute the request with redirect handling // Note: This implementation is a simple version that follows redirects // For full implementation, see client_advanced.odin - + for i := 0; i <= max_redirects; i += 1 { // Execute request err := do_request(c, req, resp) if err != os.ERROR_NONE { return err } - + // Check if response is a redirect (3xx) if resp.status_code < 300 || resp.status_code >= 400 { - return os.ERROR_NONE // Not a redirect, return normally + return os.ERROR_NONE // Not a redirect, return normally } - + // Get Location header location := header_get(&resp.header, "location") if location == "" { - return os.ERROR_NONE // No location header, return normally + return os.ERROR_NONE // No location header, return normally } - + // Build new URL base_url := fmt.tprintf("%s://%s%s", req.uri.scheme, req.uri.host, req.uri.path) redirect_url := resolve_base_url(base_url, location) if redirect_url == "" { - return os.EINVAL + return invalid_parameter_error() } - + // Parse new URL new_uri, ok := uri_parse(redirect_url) if !ok { - return os.EINVAL + return invalid_parameter_error() } - + // Update request with new URL req.uri = new_uri - + // For 301/302, change POST to GET if (resp.status_code == 301 || resp.status_code == 302) && req.method == .POST { req.method = .GET req.body = req.body[:0] } - + // Reset response for next request response_reset(resp) } - - return os.EINVAL + + return invalid_parameter_error() } // resolve_base_url resolves relative URL against base URL resolve_base_url :: proc(base_url, relative_url: string) -> string { // If relative URL is absolute, return it - if strings.has_prefix(relative_url, "http://") || strings.has_prefix(relative_url, "https://") { + if strings.has_prefix(relative_url, "http://") || + strings.has_prefix(relative_url, "https://") { return relative_url } - + // Parse base URL base_uri, ok := uri_parse(base_url) if !ok { return "" } - + // Handle path resolution if len(relative_url) > 0 && relative_url[0] == '/' { // Absolute path @@ -513,7 +550,7 @@ resolve_base_url :: proc(base_url, relative_url: string) -> string { // Remove last segment last_slash := strings.last_index_byte(base_path, '/') if last_slash != -1 { - base_path = base_path[:last_slash+1] + base_path = base_path[:last_slash + 1] } } return fmt.tprintf("%s://%s%s%s", base_uri.scheme, base_uri.host, base_path, relative_url) @@ -521,27 +558,27 @@ resolve_base_url :: proc(base_url, relative_url: string) -> string { } // GetTimeout performs GET with timeout -get_timeout :: proc(c: ^Client, url: string, timeout: time.Duration) -> (int, []byte, os.Errno) { +get_timeout :: proc(c: ^Client, url: string, timeout: time.Duration) -> (int, []byte, os.Error) { if timeout <= 0 { return client_get(c, url) } - + // Create request req: Request request_reset(&req) req.method = .GET req.uri, _ = uri_parse(url) req.timeout = timeout - + // Create response resp: Response response_reset(&resp) - + // Execute with timeout err := do_with_retry(c, &req, &resp, 3) if err != os.ERROR_NONE { return 0, nil, err } - + return resp.status_code, resp.body, os.ERROR_NONE } diff --git a/hypersock_http/client_advanced.odin b/hypersock_http/client_advanced.odin index 02b2686..33e013b 100644 --- a/hypersock_http/client_advanced.odin +++ b/hypersock_http/client_advanced.odin @@ -12,22 +12,21 @@ package hypersock_http * - TLS support (structure) */ -import "core:net" -import "core:os" import "core:fmt" -import "core:strings" +import "core:os" import "core:strconv" -import "core:time" +import "core:strings" import "core:sync" +import "core:time" // Cookie represents an HTTP cookie Cookie :: struct { - name: string, - value: string, - path: string, - domain: string, - expires: time.Time, - secure: bool, + name: string, + value: string, + path: string, + domain: string, + expires: time.Time, + secure: bool, http_only: bool, } @@ -38,96 +37,97 @@ Cookie_Jar :: struct { } // DoRedirects performs HTTP request following redirects -do_redirects :: proc(c: ^Client, req: ^Request, resp: ^Response, max_redirects: int) -> os.Errno { +do_redirects :: proc(c: ^Client, req: ^Request, resp: ^Response, max_redirects: int) -> os.Error { req_copy: Request request_reset(&req_copy) - + current_url := fmt.tprintf("%s://%s%s", req.uri.scheme, req.header.data["Host"], req.uri.path) - + redirects_remaining := max_redirects - + for redirects_remaining >= 0 { // Copy request for this iteration if req_copy.uri.host == "" { uri, ok := uri_parse(current_url) if !ok { - return os.EINVAL + return invalid_parameter_error() } req_copy.uri = uri } - + req_copy.method = req.method req_copy.body = req.body copy_headers(&req.header, &req_copy.header) - + // Execute request err := do_request(c, &req_copy, resp) if err != os.ERROR_NONE { return err } - + // Check if response is a redirect if !is_redirect_status(resp.status_code) { return os.ERROR_NONE } - + // Get location header location := header_get(&resp.header, "Location") if location == "" { - return os.EINVAL + return invalid_parameter_error() } - + // Resolve redirect URL redirect_url := resolve_redirect(current_url, location) if redirect_url == "" { - return os.EINVAL + return invalid_parameter_error() } - + // Check for redirect protocol change (security) - if strings.has_prefix(current_url, "https://") && strings.has_prefix(redirect_url, "http://") { - return os.EINVAL + if strings.has_prefix(current_url, "https://") && + strings.has_prefix(redirect_url, "http://") { + return invalid_parameter_error() } - + // Update current URL current_url = redirect_url - + // Parse new URL new_uri, ok := uri_parse(current_url) if !ok { - return os.EINVAL + return invalid_parameter_error() } req_copy.uri = new_uri - + // For 301/302 redirects, POST should change to GET - if (resp.status_code == Status_MovedPermanently || resp.status_code == Status_Found) && + if (resp.status_code == Status_MovedPermanently || resp.status_code == Status_Found) && req.method == .POST { req_copy.method = .GET req_copy.body = req_copy.body[:0] } - + redirects_remaining -= 1 } - - return os.EINVAL + + return invalid_parameter_error() } // DoWithRetry performs HTTP request with retry logic -do_with_retry :: proc(c: ^Client, req: ^Request, resp: ^Response, max_attempts: int) -> os.Errno { +do_with_retry :: proc(c: ^Client, req: ^Request, resp: ^Response, max_attempts: int) -> os.Error { attempt := 1 - base_delay := time.Duration(100 * 1e6) // 100ms in nanoseconds - + base_delay := time.Duration(100 * 1e6) // 100ms in nanoseconds + for attempt <= max_attempts { // Try the request err := do_request(c, req, resp) if err == os.ERROR_NONE { return os.ERROR_NONE } - + // Don't retry certain errors if should_not_retry(err, req.method) { return err } - + // Don't retry certain status codes if resp.status_code == Status_BadRequest || resp.status_code == Status_Unauthorized || @@ -135,42 +135,41 @@ do_with_retry :: proc(c: ^Client, req: ^Request, resp: ^Response, max_attempts: resp.status_code == Status_NotFound { return err } - + // Exponential backoff if attempt < max_attempts { delay := base_delay * time.Duration(1 << uint(attempt - 1)) time.sleep(delay) } - + attempt += 1 - + // Reset response for next attempt response_reset(resp) } - + return os.ERROR_NONE } - // Cookie operations // SetCookie sets a cookie in the jar set_cookie :: proc(jar: ^Cookie_Jar, cookie: ^Cookie) { sync.mutex_lock(&jar.mutex) defer sync.mutex_unlock(&jar.mutex) - + if jar.cookies == nil { jar.cookies = make(map[string][dynamic]Cookie) } - + domain := cookie.domain if domain == "" { return } - + cookies := &jar.cookies[domain] - + // Remove existing cookie with same name i := 0 for i < len(cookies) { @@ -182,27 +181,27 @@ set_cookie :: proc(jar: ^Cookie_Jar, cookie: ^Cookie) { if i < len(cookies) { ordered_remove(cookies, i) } - + // Add new cookie append(cookies, cookie^) } // GetCookies returns cookies for a URL -get_cookies :: proc(jar: ^Cookie_Jar, url_str: string) -> ([]Cookie, os.Errno) { +get_cookies :: proc(jar: ^Cookie_Jar, url_str: string) -> ([]Cookie, os.Error) { sync.mutex_lock(&jar.mutex) defer sync.mutex_unlock(&jar.mutex) - + if jar.cookies == nil { return nil, os.ERROR_NONE } - + uri, ok := uri_parse(url_str) if !ok { - return nil, os.EINVAL + return nil, invalid_parameter_error() } - + result: [dynamic]Cookie - + // Get cookies for exact domain if cookies, exists := jar.cookies[uri.host]; exists { for cookie in cookies { @@ -210,21 +209,21 @@ get_cookies :: proc(jar: ^Cookie_Jar, url_str: string) -> ([]Cookie, os.Errno) { if !(cookie.expires._nsec == 0) && time.now()._nsec >= cookie.expires._nsec { continue } - + // Check path if cookie.path != "" && !strings.has_prefix(uri.path, cookie.path) { continue } - + // Check secure if cookie.secure && uri.scheme != "https" { continue } - + append(&result, cookie) } } - + return result[:], os.ERROR_NONE } @@ -232,36 +231,36 @@ get_cookies :: proc(jar: ^Cookie_Jar, url_str: string) -> ([]Cookie, os.Errno) { parse_cookies :: proc(jar: ^Cookie_Jar, header: string, from_url: string) { lines := strings.split(header, ";") defer delete(lines) - + if len(lines) == 0 { return } - + // Parse name=value from first line name_value := strings.split(lines[0], "=") defer delete(name_value) - + if len(name_value) != 2 { return } - + cookie: Cookie cookie.name = strings.trim_space(name_value[0]) cookie.value = strings.trim_space(name_value[1]) - + // Parse attributes uri, ok := uri_parse(from_url) if ok { cookie.domain = uri.host cookie.path = uri.path } - + for i := 1; i < len(lines); i += 1 { attr := strings.trim_space(lines[i]) attr = strings.trim_space(attr) - + attr_name, attr_value := parse_cookie_attribute(attr) - + switch strings.to_lower(attr_name) { case "domain": cookie.domain = attr_value @@ -276,7 +275,7 @@ parse_cookies :: proc(jar: ^Cookie_Jar, header: string, from_url: string) { cookie.expires = parse_rfc1123_date(attr_value) } } - + set_cookie(jar, &cookie) } @@ -286,7 +285,7 @@ parse_cookie_attribute :: proc(attr: string) -> (name, value: string) { if idx == -1 { return strings.to_lower(attr), "" } - return strings.to_lower(attr[:idx]), attr[idx+1:] + return strings.to_lower(attr[:idx]), attr[idx + 1:] } // Helper functions @@ -295,20 +294,20 @@ parse_cookie_attribute :: proc(attr: string) -> (name, value: string) { is_redirect_status :: proc(status_code: int) -> bool { switch status_code { case Status_MovedPermanently, // 301 - Status_Found, // 302 - Status_SeeOther, // 303 - Status_TemporaryRedirect,// 307 - Status_PermanentRedirect: // 308 + Status_Found, // 302 + Status_SeeOther, // 303 + Status_TemporaryRedirect, // 307 + Status_PermanentRedirect: + // 308 return true } return false } // Should not retry certain errors -should_not_retry :: proc(err: os.Errno, method: Method) -> bool { +should_not_retry :: proc(err: os.Error, method: Method) -> bool { switch err { - case os.EINVAL, - os.ENOENT: + case invalid_parameter_error(), not_found_error(): return true } return false @@ -320,19 +319,19 @@ resolve_redirect :: proc(base_url, location: string) -> string { if strings.has_prefix(location, "http://") || strings.has_prefix(location, "https://") { return location } - + // Resolve relative to base URL // Simple implementation - just prepend base if strings.has_prefix(location, "/") { idx := strings.index(base_url, "//") if idx != -1 { - scheme_end := strings.index(base_url[idx+2:], "/") + scheme_end := strings.index(base_url[idx + 2:], "/") if scheme_end != -1 { - return fmt.tprintf("%s%s", base_url[:idx+2+scheme_end], location) + return fmt.tprintf("%s%s", base_url[:idx + 2 + scheme_end], location) } } } - + return location } @@ -340,14 +339,14 @@ resolve_redirect :: proc(base_url, location: string) -> string { copy_headers :: proc(src, dst: ^Header) { sync.mutex_lock(&src.mutex) defer sync.mutex_unlock(&src.mutex) - + sync.mutex_lock(&dst.mutex) defer sync.mutex_unlock(&dst.mutex) - + if dst.data == nil { dst.data = make(map[string][dynamic]string) } - + for key, value in src.data { dst.data[key] = value } @@ -366,15 +365,15 @@ add_cookies_to_request :: proc(jar: ^Cookie_Jar, url: string, headers: ^Header) if err != os.ERROR_NONE { return } - + if len(cookies) == 0 { return } - + // Build Cookie header cookie_value: strings.Builder defer strings.builder_destroy(&cookie_value) - + for i := 0; i < len(cookies); i += 1 { cookie := cookies[i] if i > 0 { @@ -382,7 +381,7 @@ add_cookies_to_request :: proc(jar: ^Cookie_Jar, url: string, headers: ^Header) } fmt.sbprintf(&cookie_value, "%s=%s", cookie.name, cookie.value) } - + header_set(headers, "Cookie", strings.to_string(cookie_value)) } @@ -390,7 +389,7 @@ add_cookies_to_request :: proc(jar: ^Cookie_Jar, url: string, headers: ^Header) add_received_cookies :: proc(jar: ^Cookie_Jar, headers: ^Header, url: string) { // Get all Set-Cookie header values set_cookies := header_get_all(headers, "Set-Cookie") - + // Parse each Set-Cookie header value for cookie_value in set_cookies { parse_cookies(jar, cookie_value, url) @@ -401,30 +400,30 @@ add_received_cookies :: proc(jar: ^Cookie_Jar, headers: ^Header, url: string) { set_cookie_header :: proc(headers: ^Header, cookie: ^Cookie) { value: strings.Builder defer strings.builder_destroy(&value) - + fmt.sbprintf(&value, "%s=%s", cookie.name, cookie.value) - + if cookie.path != "" { fmt.sbprintf(&value, "; Path=%s", cookie.path) } - + if cookie.domain != "" { fmt.sbprintf(&value, "; Domain=%s", cookie.domain) } - + if cookie.expires._nsec != 0 { // Format time - use simple format for now fmt.sbprintf(&value, "; Expires=%v", cookie.expires) } - + if cookie.secure { fmt.sbprintf(&value, "; Secure") } - + if cookie.http_only { fmt.sbprintf(&value, "; HttpOnly") } - + header_set(headers, "Set-Cookie", strings.to_string(value)) } @@ -433,18 +432,30 @@ set_cookie_header :: proc(headers: ^Header, cookie: ^Cookie) { month_to_number :: proc(month: string) -> int { month_lower := strings.to_lower(month) switch month_lower { - case "jan": return 1 - case "feb": return 2 - case "mar": return 3 - case "apr": return 4 - case "may": return 5 - case "jun": return 6 - case "jul": return 7 - case "aug": return 8 - case "sep": return 9 - case "oct": return 10 - case "nov": return 11 - case "dec": return 12 + case "jan": + return 1 + case "feb": + return 2 + case "mar": + return 3 + case "apr": + return 4 + case "may": + return 5 + case "jun": + return 6 + case "jul": + return 7 + case "aug": + return 8 + case "sep": + return 9 + case "oct": + return 10 + case "nov": + return 11 + case "dec": + return 12 } return 0 } @@ -456,24 +467,24 @@ parse_rfc1123_date :: proc(date_str: string) -> time.Time { if trimmed == "" { return {} } - + // Try RFC 1123 format: "Sun, 06 Nov 1994 08:49:37 GMT" // Format: Wdy, DD Mon YYYY HH:MM:SS GMT - + // Remove day name if present (e.g., "Sun, ") after_day := trimmed if idx := strings.index(trimmed, ","); idx != -1 { - after_day = strings.trim_space(trimmed[idx+1:]) + after_day = strings.trim_space(trimmed[idx + 1:]) } - + // Now should be: "06 Nov 1994 08:49:37 GMT" parts := strings.split(after_day, " ") defer delete(parts) - + if len(parts) < 4 { return {} } - + // Parse day day: int if d, ok := strconv.parse_int(strings.trim_space(parts[0]), 10); ok { @@ -481,7 +492,7 @@ parse_rfc1123_date :: proc(date_str: string) -> time.Time { } else { return {} } - + // Parse month month := 1 month_name := strings.to_lower(strings.trim_space(parts[1])) @@ -490,7 +501,7 @@ parse_rfc1123_date :: proc(date_str: string) -> time.Time { } else { return {} } - + // Parse year year: int if y, ok := strconv.parse_int(strings.trim_space(parts[2]), 10); ok { @@ -498,11 +509,11 @@ parse_rfc1123_date :: proc(date_str: string) -> time.Time { } else { return {} } - + // Parse time (HH:MM:SS) time_parts := strings.split(strings.trim_space(parts[3]), ":") defer delete(time_parts) - + hour, minute, second: int if len(time_parts) >= 3 { if h, ok := strconv.parse_int(time_parts[0], 10); ok { @@ -515,18 +526,19 @@ parse_rfc1123_date :: proc(date_str: string) -> time.Time { second = int(s) } } - + // Create time using time.from_utc_components if available // Otherwise construct manually // This is a simplified implementation t: time.Time - t._nsec = i64(year-1970) * 365 * 24 * 60 * 60 * 1_000_000_000 + - i64(month) * 30 * 24 * 60 * 60 * 1_000_000_000 + - i64(day) * 24 * 60 * 60 * 1_000_000_000 + - i64(hour) * 60 * 60 * 1_000_000_000 + - i64(minute) * 60 * 1_000_000_000 + - i64(second) * 1_000_000_000 - + t._nsec = + i64(year - 1970) * 365 * 24 * 60 * 60 * 1_000_000_000 + + i64(month) * 30 * 24 * 60 * 60 * 1_000_000_000 + + i64(day) * 24 * 60 * 60 * 1_000_000_000 + + i64(hour) * 60 * 60 * 1_000_000_000 + + i64(minute) * 60 * 1_000_000_000 + + i64(second) * 1_000_000_000 + return t } @@ -535,9 +547,9 @@ cookie_jar_destroy :: proc(jar: ^Cookie_Jar) { if jar == nil { return } - + // Free all cookie arrays - for key, cookies in jar.cookies { + for _, cookies in jar.cookies { delete(cookies) } delete(jar.cookies) diff --git a/hypersock_http/errors.odin b/hypersock_http/errors.odin new file mode 100644 index 0000000..b548841 --- /dev/null +++ b/hypersock_http/errors.odin @@ -0,0 +1,82 @@ +package hypersock_http + +import "core:io" +import "core:os" +import win32 "core:sys/windows" + +invalid_parameter_error :: proc() -> os.Error { + when ODIN_OS == .Linux { + err: os.Error = os.Platform_Error.EINVAL + return err + } else when ODIN_OS == .Windows { + return os.Platform_Error(win32.ERROR_INVALID_PARAMETER) + } else { + return os.General_Error.Invalid_Command + } +} + +connection_refused_error :: proc() -> os.Error { + when ODIN_OS == .Linux { + err: os.Error = os.Platform_Error.ECONNREFUSED + return err + } else when ODIN_OS == .Windows { + return os.Platform_Error(win32.WSAECONNREFUSED) + } else { + return os.General_Error.Invalid_Command + } +} + +connection_reset_error :: proc() -> os.Error { + when ODIN_OS == .Linux { + err: os.Error = os.Platform_Error.ECONNRESET + return err + } else when ODIN_OS == .Windows { + return os.Platform_Error(win32.WSAECONNRESET) + } else { + return os.General_Error.Broken_Pipe + } +} + +would_block_error :: proc() -> os.Error { + when ODIN_OS == .Linux { + err: os.Error = os.Platform_Error.EAGAIN + return err + } else when ODIN_OS == .Windows { + return os.Platform_Error(win32.WSAEWOULDBLOCK) + } else { + return os.General_Error.Timeout + } +} + +io_error :: proc() -> os.Error { + when ODIN_OS == .Linux { + err: os.Error = os.Platform_Error.EIO + return err + } else when ODIN_OS == .Windows { + return os.Platform_Error(win32.ERROR_INVALID_HANDLE) + } else { + return io.Error.Unknown + } +} + +not_found_error :: proc() -> os.Error { + when ODIN_OS == .Linux { + err: os.Error = os.Platform_Error.ENOENT + return err + } else when ODIN_OS == .Windows { + return os.Platform_Error(win32.ERROR_FILE_NOT_FOUND) + } else { + return os.General_Error.Not_Exist + } +} + +not_connected_error :: proc() -> os.Error { + when ODIN_OS == .Linux { + err: os.Error = os.Platform_Error.ENOTCONN + return err + } else when ODIN_OS == .Windows { + return os.Platform_Error(win32.WSAENOTCONN) + } else { + return os.General_Error.Invalid_Command + } +} diff --git a/hypersock_http/http.odin b/hypersock_http/http.odin index bd8d1f8..1ff518a 100644 --- a/hypersock_http/http.odin +++ b/hypersock_http/http.odin @@ -12,14 +12,12 @@ package hypersock_http * - Concurrent-safe operations */ -import "core:net" -import "core:os" import "core:fmt" -import "core:strings" +import "core:net" import "core:strconv" -import "core:time" +import "core:strings" import "core:sync" -import "core:mem" +import "core:time" // HTTP Methods Method :: enum { @@ -34,60 +32,67 @@ Method :: enum { method_to_string :: proc(m: Method) -> string { switch m { - case .GET: return "GET" - case .POST: return "POST" - case .PUT: return "PUT" - case .DELETE: return "DELETE" - case .HEAD: return "HEAD" - case .OPTIONS: return "OPTIONS" - case .PATCH: return "PATCH" + case .GET: + return "GET" + case .POST: + return "POST" + case .PUT: + return "PUT" + case .DELETE: + return "DELETE" + case .HEAD: + return "HEAD" + case .OPTIONS: + return "OPTIONS" + case .PATCH: + return "PATCH" } return "GET" } // HTTP Status codes -Status_OK :: 200 -Status_Created :: 201 -Status_Accepted :: 202 -Status_NoContent :: 204 -Status_MovedPermanently :: 301 -Status_Found :: 302 -Status_SeeOther :: 303 -Status_NotModified :: 304 -Status_TemporaryRedirect :: 307 -Status_PermanentRedirect :: 308 -Status_BadRequest :: 400 -Status_Unauthorized :: 401 -Status_Forbidden :: 403 -Status_NotFound :: 404 -Status_MethodNotAllowed :: 405 -Status_RequestTimeout :: 408 -Status_Conflict :: 409 -Status_Gone :: 410 -Status_LengthRequired :: 411 -Status_PayloadTooLarge :: 413 -Status_URITooLong :: 414 +Status_OK :: 200 +Status_Created :: 201 +Status_Accepted :: 202 +Status_NoContent :: 204 +Status_MovedPermanently :: 301 +Status_Found :: 302 +Status_SeeOther :: 303 +Status_NotModified :: 304 +Status_TemporaryRedirect :: 307 +Status_PermanentRedirect :: 308 +Status_BadRequest :: 400 +Status_Unauthorized :: 401 +Status_Forbidden :: 403 +Status_NotFound :: 404 +Status_MethodNotAllowed :: 405 +Status_RequestTimeout :: 408 +Status_Conflict :: 409 +Status_Gone :: 410 +Status_LengthRequired :: 411 +Status_PayloadTooLarge :: 413 +Status_URITooLong :: 414 Status_UnsupportedMediaType :: 415 -Status_TooManyRequests :: 429 -Status_InternalServerError :: 500 -Status_NotImplemented :: 501 -Status_BadGateway :: 502 -Status_ServiceUnavailable :: 503 -Status_GatewayTimeout :: 504 +Status_TooManyRequests :: 429 +Status_InternalServerError :: 500 +Status_NotImplemented :: 501 +Status_BadGateway :: 502 +Status_ServiceUnavailable :: 503 +Status_GatewayTimeout :: 504 // Header represents HTTP headers (supports multiple values per key) Header :: struct { - data: map[string][dynamic]string, + data: map[string][dynamic]string, mutex: sync.Mutex, } // Request represents an HTTP request Request :: struct { - method: Method, - uri: URI, - header: Header, - body: []byte, - timeout: time.Duration, + method: Method, + uri: URI, + header: Header, + body: []byte, + timeout: time.Duration, // User data for passing values between handlers user_data: map[string]any, } @@ -97,17 +102,17 @@ Response :: struct { status_code: int, header: Header, body: []byte, - keep_body: bool, // Don't release body buffer after use + keep_body: bool, // Don't release body buffer after use } // URI represents a parsed URL URI :: struct { - scheme: string, - host: string, - port: int, - path: string, - query: string, - fragment: string, + scheme: string, + host: string, + port: int, + path: string, + query: string, + fragment: string, // Parsed query args query_args: Args, } @@ -126,46 +131,46 @@ RequestHandler :: proc(ctx: ^RequestCtx) // Client implements high-performance HTTP client Client :: struct { - host_clients: map[string]^HostClient, - mutex: sync.RW_Mutex, - max_conns_per_host: int, + host_clients: map[string]^HostClient, + mutex: sync.RW_Mutex, + max_conns_per_host: int, max_idle_conn_duration: time.Duration, - read_buffer_size: int, - write_buffer_size: int, - read_timeout: time.Duration, - write_timeout: time.Duration, + read_buffer_size: int, + write_buffer_size: int, + read_timeout: time.Duration, + write_timeout: time.Duration, max_response_body_size: int, - tls_config: ^TLS_Config, - name: string, + tls_config: ^TLS_Config, + name: string, } // TLS_Config is defined in tls.odin // HostClient manages connections to a specific host HostClient :: struct { - addr: string, - is_tls: bool, - max_conns: int, - conns: [dynamic]^clientConn, - mutex: sync.Mutex, - conns_count: int, - pending_requests: int, + addr: string, + is_tls: bool, + max_conns: int, + conns: [dynamic]^clientConn, + mutex: sync.Mutex, + conns_count: int, + pending_requests: int, max_idle_conn_duration: time.Duration, - read_buffer_size: int, - write_buffer_size: int, - read_timeout: time.Duration, - write_timeout: time.Duration, + read_buffer_size: int, + write_buffer_size: int, + read_timeout: time.Duration, + write_timeout: time.Duration, max_response_body_size: int, - client: ^Client, + client: ^Client, } // clientConn represents a pooled connection clientConn :: struct { - conn: net.TCP_Socket, - tls_socket: ^TLS_Socket, // TLS wrapper (nil for plain TCP) - is_tls: bool, // Whether this is a TLS connection - created: time.Time, - last_use: time.Time, + conn: net.TCP_Socket, + tls_socket: ^TLS_Socket, // TLS wrapper (nil for plain TCP) + is_tls: bool, // Whether this is a TLS connection + created: time.Time, + last_use: time.Time, } // Default configuration values @@ -192,12 +197,12 @@ _default_client_initialized: bool // Initialize client with default settings client_default :: proc() -> ^Client { if !_default_client_initialized { - _default_client = Client{ - max_conns_per_host = Default_Max_Conns_Per_Host, + _default_client = Client { + max_conns_per_host = Default_Max_Conns_Per_Host, max_idle_conn_duration = Default_Max_Idle_Conn_Duration, - read_buffer_size = Default_Read_Buffer_Size, - write_buffer_size = Default_Write_Buffer_Size, - name = "odin-http-client", + read_buffer_size = Default_Read_Buffer_Size, + write_buffer_size = Default_Write_Buffer_Size, + name = "odin-http-client", } _default_client_initialized = true } @@ -218,14 +223,14 @@ client_new :: proc() -> ^Client { // Clean up client client_destroy :: proc(c: ^Client) { if c == nil do return - + // Close all host clients for _, hc in c.host_clients { host_client_close(hc) free(hc) } delete(c.host_clients) - + if c != &_default_client { free(c) } @@ -234,30 +239,30 @@ client_destroy :: proc(c: ^Client) { // Parse URL string into URI structure uri_parse :: proc(url_str: string) -> (URI, bool) { uri: URI - + // Simple URL parser rest := url_str - + // Extract scheme if idx := strings.index(rest, "://"); idx != -1 { uri.scheme = strings.to_lower(rest[:idx]) - rest = rest[idx+3:] + rest = rest[idx + 3:] } - + // Extract fragment if idx := strings.index(rest, "#"); idx != -1 { - uri.fragment = rest[idx+1:] + uri.fragment = rest[idx + 1:] rest = rest[:idx] } - + // Extract query if idx := strings.index(rest, "?"); idx != -1 { - uri.query = rest[idx+1:] + uri.query = rest[idx + 1:] rest = rest[:idx] // Parse query args parse_args(&uri.query_args, uri.query) } - + // Extract host and port path_idx := strings.index(rest, "/") if path_idx == -1 { @@ -267,10 +272,10 @@ uri_parse :: proc(url_str: string) -> (URI, bool) { uri.host = rest[:path_idx] uri.path = rest[path_idx:] } - + // Check for port if idx := strings.last_index(uri.host, ":"); idx != -1 { - port_str := uri.host[idx+1:] + port_str := uri.host[idx + 1:] // Parse port using strconv if parsed_port, ok := strconv.parse_int(port_str, 10); ok { uri.port = int(parsed_port) @@ -291,7 +296,7 @@ uri_parse :: proc(url_str: string) -> (URI, bool) { uri.port = 80 } } - + return uri, true } @@ -300,14 +305,14 @@ url_decode :: proc(s: string) -> string { if !strings.contains(s, "%") { return s } - + result := strings.builder_make() defer strings.builder_destroy(&result) - + for i := 0; i < len(s); i += 1 { if s[i] == '%' && i + 2 < len(s) { // Parse hex value - hex_str := s[i+1:i+3] + hex_str := s[i + 1:i + 3] if val, ok := strconv.parse_int(hex_str, 16); ok { fmt.sbprintf(&result, "%c", byte(val)) i += 2 @@ -320,17 +325,17 @@ url_decode :: proc(s: string) -> string { strings.write_byte(&result, s[i]) } } - + return strings.to_string(result) } // Parse query string into Args with URL decoding parse_args :: proc(args: ^Args, query: string) { args.data = make(map[string][dynamic]string) - + pairs := strings.split(query, "&") defer delete(pairs) - + for pair in pairs { kv := strings.split(pair, "=") if len(kv) == 2 { @@ -357,11 +362,11 @@ args_get :: proc(args: ^Args, key: string) -> string { header_set :: proc(h: ^Header, key, value: string) { sync.mutex_lock(&h.mutex) defer sync.mutex_unlock(&h.mutex) - + if h.data == nil { h.data = make(map[string][dynamic]string) } - + lower_key := strings.to_lower(key) if lower_key not_in h.data { h.data[lower_key] = make([dynamic]string) @@ -373,11 +378,11 @@ header_set :: proc(h: ^Header, key, value: string) { header_replace :: proc(h: ^Header, key, value: string) { sync.mutex_lock(&h.mutex) defer sync.mutex_unlock(&h.mutex) - + if h.data == nil { h.data = make(map[string][dynamic]string) } - + lower_key := strings.to_lower(key) // Clear existing values if any if lower_key in h.data { @@ -392,11 +397,11 @@ header_replace :: proc(h: ^Header, key, value: string) { header_get :: proc(h: ^Header, key: string) -> string { sync.mutex_lock(&h.mutex) defer sync.mutex_unlock(&h.mutex) - + if h.data == nil { return "" } - + lower_key := strings.to_lower(key) if values, ok := h.data[lower_key]; ok && len(values) > 0 { return values[0] @@ -408,11 +413,11 @@ header_get :: proc(h: ^Header, key: string) -> string { header_get_all :: proc(h: ^Header, key: string) -> []string { sync.mutex_lock(&h.mutex) defer sync.mutex_unlock(&h.mutex) - + if h.data == nil { return nil } - + lower_key := strings.to_lower(key) if values, ok := h.data[lower_key]; ok { return values[:] @@ -423,11 +428,11 @@ header_get_all :: proc(h: ^Header, key: string) -> []string { header_has :: proc(h: ^Header, key: string) -> bool { sync.mutex_lock(&h.mutex) defer sync.mutex_unlock(&h.mutex) - + if h.data == nil { return false } - + lower_key := strings.to_lower(key) values, ok := h.data[lower_key] return ok && len(values) > 0 @@ -474,7 +479,7 @@ URI_Builder :: struct { host: string, port: int, path: string, - query: strings.Builder, + query: strings.Builder, fragment: string, } @@ -491,7 +496,7 @@ uri_builder_new :: proc() -> URI_Builder { // uri_builder_set_scheme sets the scheme (http, https, etc.) uri_builder_set_scheme :: proc(b: ^URI_Builder, scheme: string) { b.scheme = scheme - + // Set default port based on scheme switch scheme { case "http": @@ -540,9 +545,9 @@ uri_builder_build :: proc(b: ^URI_Builder) -> string { // Start with scheme://host result := strings.builder_make() defer strings.builder_destroy(&result) - + fmt.sbprintf(&result, "%s://%s", b.scheme, b.host) - + // Add port if non-default needs_port := false switch b.scheme { @@ -555,25 +560,25 @@ uri_builder_build :: proc(b: ^URI_Builder) -> string { needs_port = true } } - + if needs_port { fmt.sbprintf(&result, ":%d", b.port) } - + // Add path fmt.sbprintf(&result, "%s", b.path) - + // Add query string if present query_str := strings.to_string(b.query) if query_str != "" { fmt.sbprintf(&result, "?%s", query_str) } - + // Add fragment if present if b.fragment != "" { fmt.sbprintf(&result, "#%s", b.fragment) } - + return strings.to_string(result) } @@ -583,20 +588,26 @@ uri_builder_destroy :: proc(b: ^URI_Builder) { } // Convenience function: build a URI from components -uri_build :: proc(scheme, host: string, port: int, path: string, query_map: map[string]string, fragment: string) -> string { +uri_build :: proc( + scheme, host: string, + port: int, + path: string, + query_map: map[string]string, + fragment: string, +) -> string { builder := uri_builder_new() defer uri_builder_destroy(&builder) - + uri_builder_set_scheme(&builder, scheme) uri_builder_set_host(&builder, host) uri_builder_set_port(&builder, port) uri_builder_set_path(&builder, path) - + for key, value in query_map { uri_builder_add_query(&builder, key, value) } - + uri_builder_set_fragment(&builder, fragment) - + return uri_builder_build(&builder) } diff --git a/hypersock_http/server.odin b/hypersock_http/server.odin index 927a546..b96f2df 100644 --- a/hypersock_http/server.odin +++ b/hypersock_http/server.odin @@ -12,22 +12,21 @@ package hypersock_http * - Keep-alive connections */ +import "core:fmt" import "core:net" import "core:os" -import "core:fmt" -import "core:strings" import "core:strconv" -import "core:time" +import "core:strings" import "core:sync" import "core:thread" -import "core:mem" +import "core:time" // ConnectionQueue is a thread-safe queue for accepted connections ConnectionQueue :: struct { - mutex: sync.Mutex, - cond: sync.Cond, - items: [dynamic]net.TCP_Socket, - closed: bool, + mutex: sync.Mutex, + cond: sync.Cond, + items: [dynamic]net.TCP_Socket, + closed: bool, } // Queue initialization @@ -91,22 +90,22 @@ queue_is_closed :: proc(q: ^ConnectionQueue) -> bool { // Server implements HTTP server Server :: struct { - handler: RequestHandler, - name: string, - read_buffer_size: int, + handler: RequestHandler, + name: string, + read_buffer_size: int, write_buffer_size: int, - read_timeout: time.Duration, - write_timeout: time.Duration, - idle_timeout: time.Duration, - max_body_size: int, - concurrency: int, - + read_timeout: time.Duration, + write_timeout: time.Duration, + idle_timeout: time.Duration, + max_body_size: int, + concurrency: int, + // Internal fields - listen: net.TCP_Socket, - accept_queue: ConnectionQueue, - wg: sync.Wait_Group, - running_mutex: sync.Mutex, - running: bool, + listen: net.TCP_Socket, + accept_queue: ConnectionQueue, + wg: sync.Wait_Group, + running_mutex: sync.Mutex, + running: bool, } // Server thread context data @@ -120,23 +119,23 @@ WorkerThreadData :: struct { } // HijackHandler is called when a connection is hijacked -Hijack_Handler :: proc(^RequestCtx) -> (net.TCP_Socket, os.Errno) +Hijack_Handler :: proc(_: ^RequestCtx) -> (net.TCP_Socket, os.Error) // RequestCtx contains incoming request and manages outgoing response RequestCtx :: struct { - request: Request, - response: Response, - conn: net.TCP_Socket, - conn_time: time.Time, - request_num: u64, - remote_addr: net.Address, - local_addr: net.Address, - + request: Request, + response: Response, + conn: net.TCP_Socket, + conn_time: time.Time, + request_num: u64, + remote_addr: net.Address, + local_addr: net.Address, + // User data storage - user_data: map[string]any, - + user_data: map[string]any, + // Connection hijacking - hijacked: bool, + hijacked: bool, } // Create new HTTP server @@ -149,12 +148,12 @@ server_new :: proc(handler: RequestHandler) -> ^Server { s.read_timeout = 30 * time.Second s.write_timeout = 30 * time.Second s.idle_timeout = 10 * time.Second - s.max_body_size = 4 * 1024 * 1024 // 4MB + s.max_body_size = 4 * 1024 * 1024 // 4MB s.concurrency = 256 s.running = false - + queue_init(&s.accept_queue) - + return s } @@ -163,39 +162,39 @@ server_destroy :: proc(s: ^Server) { if s == nil { return } - + // Shutdown if running if server_is_running(s) { shutdown(s) } - + // Clean up queue queue_destroy(&s.accept_queue) - + // Free server free(s) } // ListenAndServe starts HTTP server on addr -listen_and_serve :: proc(s: ^Server, addr: string) -> os.Errno { +listen_and_serve :: proc(s: ^Server, addr: string) -> os.Error { // Create listener endpoint, endpoint_err := net.parse_endpoint(addr) if endpoint_err { - return os.EINVAL + return invalid_parameter_error() } - + listen_socket, listen_err := net.listen_tcp(endpoint) if listen_err != nil { - return os.ECONNREFUSED + return connection_refused_error() } s.listen = listen_socket - + sync.mutex_lock(&s.running_mutex) s.running = true sync.mutex_unlock(&s.running_mutex) - + fmt.printf("Server listening on %s\n", addr) - + // Start accept thread accept_data := new(AcceptThreadData) accept_data.server = s @@ -207,7 +206,7 @@ listen_and_serve :: proc(s: ^Server, addr: string) -> os.Errno { }) accept_thread.data = accept_data thread.start(accept_thread) - + // Start worker pool for i := 0; i < s.concurrency; i += 1 { worker_data := new(WorkerThreadData) @@ -222,13 +221,13 @@ listen_and_serve :: proc(s: ^Server, addr: string) -> os.Errno { worker_thread.data = worker_data thread.start(worker_thread) } - + // Wait for shutdown signal (block main thread) sync.wait_group_wait(&s.wg) - + // Close listener net.close(s.listen) - + return os.ERROR_NONE } @@ -243,7 +242,7 @@ server_is_running :: proc(s: ^Server) -> bool { // Accept incoming connections server_accept :: proc(s: ^Server) { defer sync.wait_group_done(&s.wg) - + for server_is_running(s) { conn, _, accept_err := net.accept_tcp(s.listen) if accept_err != nil { @@ -252,7 +251,7 @@ server_accept :: proc(s: ^Server) { } continue } - + // Try to push to queue, exit if shutdown if !queue_push(&s.accept_queue, conn) { net.close(conn) @@ -264,14 +263,14 @@ server_accept :: proc(s: ^Server) { // Worker handles connections server_worker :: proc(s: ^Server, worker_id: int) { defer sync.wait_group_done(&s.wg) - + for { conn, ok := queue_pop(&s.accept_queue) if !ok { // Queue closed, exit worker return } - + server_handle_connection(s, conn) } } @@ -280,13 +279,13 @@ server_worker :: proc(s: ^Server, worker_id: int) { server_handle_connection :: proc(s: ^Server, conn: net.TCP_Socket) { // Check if connection was hijacked before closing hijacked := false - + defer { if !hijacked { net.close(conn) } } - + // Set timeouts using socket options if s.read_timeout > 0 { net.set_option(conn, .Receive_Timeout, int(s.read_timeout)) @@ -294,54 +293,53 @@ server_handle_connection :: proc(s: ^Server, conn: net.TCP_Socket) { if s.write_timeout > 0 { net.set_option(conn, .Send_Timeout, int(s.write_timeout)) } - + ctx: RequestCtx ctx.conn = conn ctx.conn_time = time.now() ctx.request_num = 1 ctx.remote_addr = {} ctx.local_addr = {} - + for server_is_running(s) { // Read request err := read_request(conn, &ctx.request) if err != os.ERROR_NONE { - if err != os.ECONNRESET { + if err != connection_reset_error() { fmt.println("Read error:", err) } break } - + // Reset response response_reset(&ctx.response) - + // Set default headers header_set(&ctx.response.header, "Server", s.name) - header_set(&ctx.response.header, "Date", fmt.tprintf("%v", time.now())) - + header_set(&ctx.response.header, "Date", fmt.tprintf("%v", time.now())) + // Call handler s.handler(&ctx) - + // Write response err = write_response(conn, &ctx.response) if err != os.ERROR_NONE { fmt.println("Write error:", err) break } - + // Check if connection was hijacked if ctx.hijacked { hijacked = true break } - + // Check if connection should be closed connection := header_get(&ctx.request.header, "Connection") - if strings.to_lower(connection) == "close" || - ctx.response.status_code >= 400 { + if strings.to_lower(connection) == "close" || ctx.response.status_code >= 400 { break } - + ctx.request_num += 1 } } @@ -351,10 +349,10 @@ shutdown :: proc(s: ^Server) { sync.mutex_lock(&s.running_mutex) s.running = false sync.mutex_unlock(&s.running_mutex) - + // Close the accept queue to signal workers to exit queue_close(&s.accept_queue) - + // Close the listener to unblock accept() if s.listen != {} { net.close(s.listen) @@ -502,32 +500,33 @@ form_value :: proc(ctx: ^RequestCtx, key: string) -> string { return query_value } } - + // Then check POST body if content-type is form data content_type := header_get(&ctx.request.header, "Content-Type") - if strings.contains(content_type, "application/x-www-form-urlencoded") && len(ctx.request.body) > 0 { + if strings.contains(content_type, "application/x-www-form-urlencoded") && + len(ctx.request.body) > 0 { post_value := parse_form_value(string(ctx.request.body), key) if post_value != "" { return post_value } } - + return "" } // PostArgs parses and returns POST form arguments post_args :: proc(ctx: ^RequestCtx) -> map[string]string { args: map[string]string - + if len(ctx.request.body) == 0 { return args } - + content_type := header_get(&ctx.request.header, "Content-Type") if !strings.contains(content_type, "application/x-www-form-urlencoded") { return args } - + return parse_form(string(ctx.request.body)) } @@ -535,23 +534,23 @@ post_args :: proc(ctx: ^RequestCtx) -> map[string]string { parse_query_value :: proc(query_str, key: string) -> string { pairs := strings.split(query_str, "&") defer delete(pairs) - + for pair in pairs { parts := strings.split(pair, "=") defer delete(parts) - + if len(parts) >= 1 { url_decode, _ := strings.replace(parts[0], "+", " ", -1) if url_decode == key { if len(parts) >= 2 { result, _ := strings.replace(parts[1], "+", " ", -1) - return result + return result } return "" } } } - + return "" } @@ -559,11 +558,11 @@ parse_query_value :: proc(query_str, key: string) -> string { parse_form_value :: proc(form_str, key: string) -> string { pairs := strings.split(form_str, "&") defer delete(pairs) - + for pair in pairs { parts := strings.split(pair, "=") defer delete(parts) - + if len(parts) >= 1 { if parts[0] == key { if len(parts) >= 2 { @@ -573,21 +572,21 @@ parse_form_value :: proc(form_str, key: string) -> string { } } } - + return "" } // parse_form parses a form string into a map parse_form :: proc(form_str: string) -> map[string]string { result := make(map[string]string) - + pairs := strings.split(form_str, "&") defer delete(pairs) - + for pair in pairs { parts := strings.split(pair, "=") defer delete(parts) - + if len(parts) >= 1 { key := parts[0] if len(parts) >= 2 { @@ -597,16 +596,16 @@ parse_form :: proc(form_str: string) -> map[string]string { } } } - + return result } // Hijack takes over the connection from the server // Returns the underlying TCP socket // After hijacking, the server will not close the connection -hijack :: proc(ctx: ^RequestCtx) -> (net.TCP_Socket, os.Errno) { +hijack :: proc(ctx: ^RequestCtx) -> (net.TCP_Socket, os.Error) { if ctx.hijacked { - return {}, os.EINVAL + return {}, invalid_parameter_error() } ctx.hijacked = true return ctx.conn, os.ERROR_NONE @@ -614,10 +613,12 @@ hijack :: proc(ctx: ^RequestCtx) -> (net.TCP_Socket, os.Errno) { // String returns a string representation of the context string_ctx :: proc(ctx: ^RequestCtx) -> string { - return fmt.tprintf("[#%d %s<->%s %s %s]", + return fmt.tprintf( + "[#%d %s<->%s %s %s]", ctx.request_num, ctx.local_addr, ctx.remote_addr, method_to_string(ctx.request.method), - ctx.request.uri.path) + ctx.request.uri.path, + ) } diff --git a/hypersock_http/server_io.odin b/hypersock_http/server_io.odin index 7f5b022..447a8e2 100644 --- a/hypersock_http/server_io.odin +++ b/hypersock_http/server_io.odin @@ -6,26 +6,25 @@ package hypersock_http * Based on fasthttp patterns */ +import "core:fmt" import "core:net" import "core:os" -import "core:fmt" import "core:strings" -import "core:time" import "core:sync" // read_request reads and parses an HTTP request from the connection -read_request :: proc(conn: net.TCP_Socket, req: ^Request) -> os.Errno { +read_request :: proc(conn: net.TCP_Socket, req: ^Request) -> os.Error { // Clear request request_reset(req) - + // Create buffered reader buf: [dynamic]byte defer delete(buf) - + // Read the request line (max 4096 bytes) chunk := make([]byte, 4096) defer delete(chunk) - + n, recv_err := net.recv_tcp(conn, chunk) if recv_err != nil { return nil @@ -33,19 +32,19 @@ read_request :: proc(conn: net.TCP_Socket, req: ^Request) -> os.Errno { if n == 0 { return nil } - + // Copy to buffer for easier parsing append(&buf, ..chunk[:n]) - + // Find end of request line line_end := -1 for i := 0; i < len(buf); i += 1 { - if buf[i] == '\r' && i + 1 < len(buf) && buf[i+1] == '\n' { + if buf[i] == '\r' && i + 1 < len(buf) && buf[i + 1] == '\n' { line_end = i break } } - + if line_end == -1 { // No CRLF found, might be LF-only for i := 0; i < len(buf); i += 1 { @@ -55,100 +54,107 @@ read_request :: proc(conn: net.TCP_Socket, req: ^Request) -> os.Errno { } } } - + if line_end == -1 { return nil } - + line := string(buf[:line_end]) - + // Parse request line: METHOD PATH HTTP/VERSION parts := strings.split(line, " ") defer delete(parts) - + if len(parts) < 3 { return nil } - + // Parse method method_str := strings.trim_space(parts[0]) switch method_str { - case "GET": req.method = .GET - case "POST": req.method = .POST - case "PUT": req.method = .PUT - case "DELETE": req.method = .DELETE - case "HEAD": req.method = .HEAD - case "OPTIONS": req.method = .OPTIONS - case "PATCH": req.method = .PATCH + case "GET": + req.method = .GET + case "POST": + req.method = .POST + case "PUT": + req.method = .PUT + case "DELETE": + req.method = .DELETE + case "HEAD": + req.method = .HEAD + case "OPTIONS": + req.method = .OPTIONS + case "PATCH": + req.method = .PATCH } - + // Parse path (URI) uri_str := strings.trim_space(parts[1]) - + // Include query string if present // Extract query string from original buffer - query_start := line_end + 2 // Skip CRLF - + query_start := line_end + 2 // Skip CRLF + // Update buffer position and continue parsing headers cursor := query_start - + // Check if there's more data (headers) before parsing URI // For now, just parse the path part - + // Parse and set URI req.uri, _ = uri_parse(uri_str) - + // Parse headers until empty line for { // Find next CRLF header_len := -1 for i := cursor; i < len(buf); i += 1 { - if buf[i] == '\r' && i + 1 < len(buf) && buf[i+1] == '\n' { + if buf[i] == '\r' && i + 1 < len(buf) && buf[i + 1] == '\n' { header_len = i break } } - + if header_len == -1 { // Need to read more data // For simplicity, assume we have all data break } - + // Skip CRLF cursor = header_len + 2 - + if cursor == line_end + 2 { // Empty line - headers done break } - + // Extract header next_header_end := -1 for i := cursor; i < len(buf); i += 1 { - if buf[i] == '\r' && i + 1 < len(buf) && buf[i+1] == '\n' { + if buf[i] == '\r' && i + 1 < len(buf) && buf[i + 1] == '\n' { next_header_end = i break } } - + if next_header_end == -1 { break } - + header_line := string(buf[cursor:next_header_end]) - + // Parse header: Name: Value colon_idx := strings.index(header_line, ":") if colon_idx != -1 { name := strings.to_lower(strings.trim_space(header_line[:colon_idx])) - value := strings.trim_space(header_line[colon_idx+1:]) + value := strings.trim_space(header_line[colon_idx + 1:]) header_set(&req.header, name, value) } - + cursor = next_header_end + 2 } - + // Check for Content-Length and set body read size if header_has(&req.header, "content-length") { len_str := header_get(&req.header, "content-length") @@ -164,55 +170,55 @@ read_request :: proc(conn: net.TCP_Socket, req: ^Request) -> os.Errno { } } } - + return nil } // write_response writes HTTP response to the connection -write_response :: proc(conn: net.TCP_Socket, resp: ^Response) -> os.Errno { +write_response :: proc(conn: net.TCP_Socket, resp: ^Response) -> os.Error { // Build response response_data: strings.Builder defer strings.builder_destroy(&response_data) - + // Status line with complete status text mappings status_text := get_status_text(resp.status_code) - + // Determine protocol version protocol := "HTTP/1.1" // In real implementation, check request for HTTP/1.0 or HTTP/1.1 - + fmt.sbprintf(&response_data, "%s %d %s\r\n", protocol, resp.status_code, status_text) - + // Write headers sync.mutex_lock(&resp.header.mutex) defer sync.mutex_unlock(&resp.header.mutex) - + if resp.header.data != nil { for key, value in resp.header.data { fmt.sbprintf(&response_data, "%s: %s\r\n", key, value) } } - + // Content-Length header (if not set) if !header_has(&resp.header, "content-length") && len(resp.body) > 0 { fmt.sbprintf(&response_data, "Content-Length: %d\r\n", len(resp.body)) } - + // Empty line fmt.sbprintf(&response_data, "\r\n") - + // Body if len(resp.body) > 0 { fmt.sbprintf(&response_data, "%s", string(resp.body)) } - + // Send response data := transmute([]byte)strings.to_string(response_data) _, send_err := net.send_tcp(conn, data) if send_err != nil { return nil } - + return nil } @@ -221,77 +227,139 @@ write_response :: proc(conn: net.TCP_Socket, resp: ^Response) -> os.Errno { get_status_text :: proc(code: int) -> string { switch code { // 1xx Informational - case 100: return "Continue" - case 101: return "Switching Protocols" - case 102: return "Processing" - case 103: return "Early Hints" - + case 100: + return "Continue" + case 101: + return "Switching Protocols" + case 102: + return "Processing" + case 103: + return "Early Hints" + // 2xx Success - case 200: return "OK" - case 201: return "Created" - case 202: return "Accepted" - case 203: return "Non-Authoritative Information" - case 204: return "No Content" - case 205: return "Reset Content" - case 206: return "Partial Content" - case 207: return "Multi-Status" - case 208: return "Already Reported" - case 226: return "IM Used" - + case 200: + return "OK" + case 201: + return "Created" + case 202: + return "Accepted" + case 203: + return "Non-Authoritative Information" + case 204: + return "No Content" + case 205: + return "Reset Content" + case 206: + return "Partial Content" + case 207: + return "Multi-Status" + case 208: + return "Already Reported" + case 226: + return "IM Used" + // 3xx Redirection - case 300: return "Multiple Choices" - case 301: return "Moved Permanently" - case 302: return "Found" - case 303: return "See Other" - case 304: return "Not Modified" - case 305: return "Use Proxy" - case 307: return "Temporary Redirect" - case 308: return "Permanent Redirect" - + case 300: + return "Multiple Choices" + case 301: + return "Moved Permanently" + case 302: + return "Found" + case 303: + return "See Other" + case 304: + return "Not Modified" + case 305: + return "Use Proxy" + case 307: + return "Temporary Redirect" + case 308: + return "Permanent Redirect" + // 4xx Client Error - case 400: return "Bad Request" - case 401: return "Unauthorized" - case 402: return "Payment Required" - case 403: return "Forbidden" - case 404: return "Not Found" - case 405: return "Method Not Allowed" - case 406: return "Not Acceptable" - case 407: return "Proxy Authentication Required" - case 408: return "Request Timeout" - case 409: return "Conflict" - case 410: return "Gone" - case 411: return "Length Required" - case 412: return "Precondition Failed" - case 413: return "Payload Too Large" - case 414: return "URI Too Long" - case 415: return "Unsupported Media Type" - case 416: return "Range Not Satisfiable" - case 417: return "Expectation Failed" - case 418: return "I'm a teapot" - case 421: return "Misdirected Request" - case 422: return "Unprocessable Entity" - case 423: return "Locked" - case 424: return "Failed Dependency" - case 425: return "Too Early" - case 426: return "Upgrade Required" - case 428: return "Precondition Required" - case 429: return "Too Many Requests" - case 431: return "Request Header Fields Too Large" - case 451: return "Unavailable For Legal Reasons" - + case 400: + return "Bad Request" + case 401: + return "Unauthorized" + case 402: + return "Payment Required" + case 403: + return "Forbidden" + case 404: + return "Not Found" + case 405: + return "Method Not Allowed" + case 406: + return "Not Acceptable" + case 407: + return "Proxy Authentication Required" + case 408: + return "Request Timeout" + case 409: + return "Conflict" + case 410: + return "Gone" + case 411: + return "Length Required" + case 412: + return "Precondition Failed" + case 413: + return "Payload Too Large" + case 414: + return "URI Too Long" + case 415: + return "Unsupported Media Type" + case 416: + return "Range Not Satisfiable" + case 417: + return "Expectation Failed" + case 418: + return "I'm a teapot" + case 421: + return "Misdirected Request" + case 422: + return "Unprocessable Entity" + case 423: + return "Locked" + case 424: + return "Failed Dependency" + case 425: + return "Too Early" + case 426: + return "Upgrade Required" + case 428: + return "Precondition Required" + case 429: + return "Too Many Requests" + case 431: + return "Request Header Fields Too Large" + case 451: + return "Unavailable For Legal Reasons" + // 5xx Server Error - case 500: return "Internal Server Error" - case 501: return "Not Implemented" - case 502: return "Bad Gateway" - case 503: return "Service Unavailable" - case 504: return "Gateway Timeout" - case 505: return "HTTP Version Not Supported" - case 506: return "Variant Also Negotiates" - case 507: return "Insufficient Storage" - case 508: return "Loop Detected" - case 510: return "Not Extended" - case 511: return "Network Authentication Required" - + case 500: + return "Internal Server Error" + case 501: + return "Not Implemented" + case 502: + return "Bad Gateway" + case 503: + return "Service Unavailable" + case 504: + return "Gateway Timeout" + case 505: + return "HTTP Version Not Supported" + case 506: + return "Variant Also Negotiates" + case 507: + return "Insufficient Storage" + case 508: + return "Loop Detected" + case 510: + return "Not Extended" + case 511: + return "Network Authentication Required" + case: if code >= 100 && code < 200 { return "Informational" diff --git a/hypersock_http/tls.odin b/hypersock_http/tls.odin index 6896d5e..c0a7a69 100644 --- a/hypersock_http/tls.odin +++ b/hypersock_http/tls.odin @@ -14,20 +14,20 @@ package hypersock_http * or external libraries for secure connections. */ -import "core:net" import "core:c" +import "core:fmt" +import "core:net" import "core:os" import "core:strings" import "core:time" -import "core:fmt" // TLS Protocol versions TLS_Version :: enum { None, - Version_1_0, // TLS 1.0 (deprecated, not recommended) - Version_1_1, // TLS 1.1 (deprecated) - Version_1_2, // TLS 1.2 (recommended for production) - Version_1_3, // TLS 1.3 (latest) + Version_1_0, // TLS 1.0 (deprecated, not recommended) + Version_1_1, // TLS 1.1 (deprecated) + Version_1_2, // TLS 1.2 (recommended for production) + Version_1_3, // TLS 1.3 (latest) } // TLS Connection state @@ -41,104 +41,104 @@ TLS_State :: enum { // Certificate information Certificate :: struct { // Certificate subject - subject_common_name: string, + subject_common_name: string, subject_organization: string, - subject_country: string, - + subject_country: string, + // Certificate issuer - issuer_common_name: string, - issuer_organization: string, - + issuer_common_name: string, + issuer_organization: string, + // Validity period - not_before: time.Time, - not_after: time.Time, - + not_before: time.Time, + not_after: time.Time, + // Certificate details - serial_number: string, - version: u32, - public_key_alg: string, - signature_alg: string, - + serial_number: string, + version: u32, + public_key_alg: string, + signature_alg: string, + // Extended validation - dns_names: []string, - email_addresses: []string, - ip_addresses: []string, - + dns_names: []string, + email_addresses: []string, + ip_addresses: []string, + // Validation status - is_valid: bool, - is_trusted: bool, - error_message: string, + is_valid: bool, + is_trusted: bool, + error_message: string, } // TLS Configuration TLS_Config :: struct { // Server name for SNI (Server Name Indication) - server_name: string, - + server_name: string, + // Protocol version - min_version: TLS_Version, - max_version: TLS_Version, - + min_version: TLS_Version, + max_version: TLS_Version, + // Certificate verification - insecure_skip_verify: bool, - ca_certificates: []byte, // PEM encoded CA certs - client_certificates: []byte, // PEM encoded client certs - client_private_key: []byte, // PEM encoded private key - + insecure_skip_verify: bool, + ca_certificates: []byte, // PEM encoded CA certs + client_certificates: []byte, // PEM encoded client certs + client_private_key: []byte, // PEM encoded private key + // Cipher suites - cipher_suites: []string, - + cipher_suites: []string, + // Session resumption session_tickets_enabled: bool, client_session_cache: bool, - + // ALPN (Application-Layer Protocol Negotiation) - next_protos: []string, - + next_protos: []string, + // Handshake timeout - handshake_timeout: time.Duration, - + handshake_timeout: time.Duration, + // Verification callback - verify_callback: proc(cert: ^Certificate) -> bool, - + verify_callback: proc(cert: ^Certificate) -> bool, + // Connection state (internal) - state: TLS_State, - last_error: string, - peer_certificates: []Certificate, - selected_protocol: string, + state: TLS_State, + last_error: string, + peer_certificates: []Certificate, + selected_protocol: string, } // TLS Socket wrapper TLS_Socket :: struct { // Underlying TCP socket - tcp_conn: net.TCP_Socket, - + tcp_conn: net.TCP_Socket, + // TLS configuration - config: ^TLS_Config, - + config: ^TLS_Config, + // Connection state - state: TLS_State, - is_server: bool, - + state: TLS_State, + is_server: bool, + // Handshake data - local_random: [32]byte, - remote_random: [32]byte, - master_secret: [48]byte, - session_id: [32]byte, - + local_random: [32]byte, + remote_random: [32]byte, + master_secret: [48]byte, + session_id: [32]byte, + // Protocol info - version: TLS_Version, - cipher_suite: string, - + version: TLS_Version, + cipher_suite: string, + // Connection statistics - bytes_sent: u64, - bytes_received: u64, + bytes_sent: u64, + bytes_received: u64, handshake_start: time.Time, handshake_end: time.Time, - + // OpenSSL handles - openssl_ctx: ^SSL_CTX, - openssl_sock: ^OpenSSL_Socket, + openssl_ctx: ^SSL_CTX, + openssl_sock: ^OpenSSL_Socket, } // Convert OpenSSL version string to TLS_Version @@ -152,7 +152,7 @@ tls_version_from_string :: proc(version_str: string) -> TLS_Version { } else if strings.contains(version_str, "TLSv1.0") { return .Version_1_0 } - return .Version_1_2 // Default + return .Version_1_2 // Default } // Create default TLS configuration @@ -166,7 +166,7 @@ tls_config_default :: proc() -> TLS_Config { config.client_session_cache = false config.handshake_timeout = 10 * time.Second config.state = .Idle - + return config } @@ -174,7 +174,13 @@ tls_config_default :: proc() -> TLS_Config { tls_socket_new :: proc(tcp_conn: net.TCP_Socket, config: ^TLS_Config) -> ^TLS_Socket { tls := new(TLS_Socket) tls.tcp_conn = tcp_conn - tls.config = config + tls_config := new(TLS_Config) + if config != nil { + tls_config^ = config^ + } else { + tls_config^ = tls_config_default() + } + tls.config = tls_config tls.state = .Idle tls.is_server = false return tls @@ -182,18 +188,18 @@ tls_socket_new :: proc(tcp_conn: net.TCP_Socket, config: ^TLS_Config) -> ^TLS_So // TLS Handshake - performs TLS handshake // Uses OpenSSL foreign bindings for actual encryption -tls_handshake :: proc(tls: ^TLS_Socket) -> os.Errno { +tls_handshake :: proc(tls: ^TLS_Socket) -> os.Error { if tls.state == .Connected { return os.ERROR_NONE } - + if tls.state != .Idle && tls.state != .Handshaking { - return os.EINVAL + return invalid_parameter_error() } - + tls.state = .Handshaking tls.handshake_start = time.now() - + // Create OpenSSL context if not exists if tls.openssl_ctx == nil { ctx, err := openssl_client_context_new() @@ -203,7 +209,7 @@ tls_handshake :: proc(tls: ^TLS_Socket) -> os.Errno { tls.config.last_error = "Failed to create OpenSSL context" return err } - + // Configure verification if tls.config.insecure_skip_verify { openssl_set_verify_mode(ctx, SSL_VERIFY_NONE) @@ -216,13 +222,18 @@ tls_handshake :: proc(tls: ^TLS_Socket) -> os.Errno { _ = openssl_load_ca_file(ctx, "/etc/ssl/certs/ca-certificates.crt") } } - + tls.openssl_ctx = ctx } - + // Create OpenSSL socket wrapper if tls.openssl_sock == nil { - sock, err := openssl_socket_new(tls.openssl_ctx, cast(os.Socket)tls.tcp_conn, tls.is_server) + sock, err := openssl_socket_new( + tls.openssl_ctx, + tls.tcp_conn, + tls.is_server, + tls.config.server_name, + ) if err != os.ERROR_NONE { tls.state = .Failed tls.config.state = .Failed @@ -231,44 +242,44 @@ tls_handshake :: proc(tls: ^TLS_Socket) -> os.Errno { } tls.openssl_sock = sock } - + // Perform handshake - err: os.Errno + err: os.Error if tls.is_server { err = openssl_accept(tls.openssl_sock) } else { err = openssl_connect(tls.openssl_sock) } - + if err == os.ERROR_NONE { tls.state = .Connected tls.config.state = .Connected tls.handshake_end = time.now() - + // Get TLS version and cipher tls.version = tls_version_from_string(openssl_get_version(tls.openssl_sock)) tls.cipher_suite = openssl_get_cipher(tls.openssl_sock) - - fmt.println("TLS Handshake completed:", tls.version, tls.cipher_suite) - } else if err == os.EAGAIN { + + // fmt.println("TLS Handshake completed:", tls.version, tls.cipher_suite) + } else if err == would_block_error() { // Handshake in progress - return os.EAGAIN + return would_block_error() } else { tls.state = .Failed tls.config.state = .Failed tls.config.last_error = openssl_get_error_string() return err } - + return os.ERROR_NONE } // Read from TLS connection -tls_read :: proc(tls: ^TLS_Socket, p: []byte) -> (n: int, err: os.Errno) { +tls_read :: proc(tls: ^TLS_Socket, p: []byte) -> (n: int, err: os.Error) { if tls.openssl_sock == nil || !tls.openssl_sock.connected { - return 0, os.ENOTCONN + return 0, not_connected_error() } - + n, err = openssl_read(tls.openssl_sock, p) if n > 0 { tls.bytes_received += u64(n) @@ -280,11 +291,11 @@ tls_read :: proc(tls: ^TLS_Socket, p: []byte) -> (n: int, err: os.Errno) { } // Write to TLS connection -tls_write :: proc(tls: ^TLS_Socket, p: []byte) -> (n: int, err: os.Errno) { +tls_write :: proc(tls: ^TLS_Socket, p: []byte) -> (n: int, err: os.Error) { if tls.openssl_sock == nil || !tls.openssl_sock.connected { - return 0, os.ENOTCONN + return 0, not_connected_error() } - + n, err = openssl_write(tls.openssl_sock, p) if n > 0 { tls.bytes_sent += u64(n) @@ -296,27 +307,32 @@ tls_write :: proc(tls: ^TLS_Socket, p: []byte) -> (n: int, err: os.Errno) { } // Close TLS connection -tls_close :: proc(tls: ^TLS_Socket) -> os.Errno { +tls_close :: proc(tls: ^TLS_Socket) -> os.Error { if tls.state == .Idle { return os.ERROR_NONE } - + // Close OpenSSL connection if tls.openssl_sock != nil { openssl_close(tls.openssl_sock) tls.openssl_sock = nil } - + // Free OpenSSL context if tls.openssl_ctx != nil { SSL_CTX_free(tls.openssl_ctx) tls.openssl_ctx = nil } - + + if tls.config != nil { + free(tls.config) + tls.config = nil + } + // Close underlying TCP net.close(tls.tcp_conn) tls.state = .Idle - + return os.ERROR_NONE } @@ -331,7 +347,12 @@ tls_get_peer_certificates :: proc(tls: ^TLS_Socket) -> []Certificate { } // Get connection statistics -tls_get_stats :: proc(tls: ^TLS_Socket) -> (bytes_sent, bytes_received: u64, handshake_duration: time.Duration) { +tls_get_stats :: proc( + tls: ^TLS_Socket, +) -> ( + bytes_sent, bytes_received: u64, + handshake_duration: time.Duration, +) { // Framework placeholder - actual timing requires real TLS return tls.bytes_sent, tls.bytes_received, 0 } @@ -347,7 +368,7 @@ verify_certificate :: proc(config: ^TLS_Config, cert: ^Certificate) -> bool { config.last_error = "Certificate not trusted" return false } - + // Call verify callback if provided if config.verify_callback != nil { if !config.verify_callback(cert) { @@ -356,7 +377,7 @@ verify_certificate :: proc(config: ^TLS_Config, cert: ^Certificate) -> bool { return false } } - + cert.is_valid = true return true } @@ -367,7 +388,7 @@ verify_hostname :: proc(cert: ^Certificate, hostname: string) -> bool { if cert.subject_common_name == hostname { return true } - + // Check Subject Alternative Names (SAN) for dns in cert.dns_names { if dns == hostname { @@ -386,14 +407,14 @@ verify_hostname :: proc(cert: ^Certificate, hostname: string) -> bool { } } } - + // Check IP addresses for ip in cert.ip_addresses { if ip == hostname { return true } } - + return false } @@ -401,12 +422,12 @@ verify_hostname :: proc(cert: ^Certificate, hostname: string) -> bool { // Uses OpenSSL for actual parsing parse_x509_certificate :: proc(pem_data: []byte) -> (Certificate, bool) { cert: Certificate - + if len(pem_data) == 0 { cert.error_message = "Empty certificate data" return cert, false } - + // Create BIO from memory bio := BIO_new_mem_buf(rawptr(raw_data(pem_data)), c.int(len(pem_data))) if bio == nil { @@ -414,7 +435,7 @@ parse_x509_certificate :: proc(pem_data: []byte) -> (Certificate, bool) { return cert, false } defer BIO_free(bio) - + // Read X509 certificate from BIO x509_cert := PEM_read_bio_X509(bio, nil, nil, nil) if x509_cert == nil { @@ -422,7 +443,7 @@ parse_x509_certificate :: proc(pem_data: []byte) -> (Certificate, bool) { return cert, false } defer X509_free(x509_cert) - + // Extract subject name subject_name := X509_get_subject_name(x509_cert) if subject_name != nil { @@ -432,7 +453,7 @@ parse_x509_certificate :: proc(pem_data: []byte) -> (Certificate, bool) { cert.subject_common_name = string(subject_str) } } - + // Extract issuer name issuer_name := X509_get_issuer_name(x509_cert) if issuer_name != nil { @@ -442,7 +463,7 @@ parse_x509_certificate :: proc(pem_data: []byte) -> (Certificate, bool) { cert.issuer_common_name = string(issuer_str) } } - + // Extract validity periods not_before := X509_get_notBefore(x509_cert) not_after := X509_get_notAfter(x509_cert) @@ -450,13 +471,13 @@ parse_x509_certificate :: proc(pem_data: []byte) -> (Certificate, bool) { cert.not_before = time.now() // Simplified - would need ASN1_TIME parsing } if not_after != nil { - cert.not_after = time.now() // Simplified - would need ASN1_TIME parsing + cert.not_after = time.now() // Simplified - would need ASN1_TIME parsing } - + // Mark as valid if we got this far cert.is_valid = true - cert.is_trusted = true // Would need verification against CA store - + cert.is_trusted = true // Would need verification against CA store + return cert, true } @@ -465,9 +486,9 @@ get_peer_certificate :: proc(tls: ^TLS_Socket) -> (Certificate, bool) { if tls == nil || tls.openssl_sock == nil || tls.openssl_sock.ssl == nil { return {}, false } - + cert: Certificate - + // Get peer certificate x509 := SSL_get_peer_certificate(tls.openssl_sock.ssl) if x509 == nil { @@ -475,7 +496,7 @@ get_peer_certificate :: proc(tls: ^TLS_Socket) -> (Certificate, bool) { return cert, false } defer X509_free(x509) - + // Extract subject name subject_name := X509_get_subject_name(x509) if subject_name != nil { @@ -485,7 +506,7 @@ get_peer_certificate :: proc(tls: ^TLS_Socket) -> (Certificate, bool) { cert.subject_common_name = string(subject_str) } } - + // Extract issuer name issuer_name := X509_get_issuer_name(x509) if issuer_name != nil { @@ -495,7 +516,7 @@ get_peer_certificate :: proc(tls: ^TLS_Socket) -> (Certificate, bool) { cert.issuer_common_name = string(issuer_str) } } - + // Extract validity periods (simplified) not_before := X509_get_notBefore(x509) not_after := X509_get_notAfter(x509) @@ -503,10 +524,10 @@ get_peer_certificate :: proc(tls: ^TLS_Socket) -> (Certificate, bool) { cert.not_before = time.now() cert.not_after = time.now() } - + cert.is_valid = true cert.is_trusted = !tls.config.insecure_skip_verify - + return cert, true } @@ -525,15 +546,20 @@ tls_dump_config :: proc(config: ^TLS_Config) -> string { builder: strings.Builder builder = strings.builder_make() defer strings.builder_destroy(&builder) - + fmt.sbprintf(&builder, "\\n=== TLS Configuration ===\\n") fmt.sbprintf(&builder, "Server Name: %s\\n", config.server_name) - fmt.sbprintf(&builder, "Protocol: v%d.%d to v%d.%d\\n", - 1, int(config.min_version) - 1, - 1, int(config.max_version) - 1) + fmt.sbprintf( + &builder, + "Protocol: v%d.%d to v%d.%d\\n", + 1, + int(config.min_version) - 1, + 1, + int(config.max_version) - 1, + ) fmt.sbprintf(&builder, "Skip Verify: %v\\n", config.insecure_skip_verify) fmt.sbprintf(&builder, "Handshake Timeout: %v\\n", config.handshake_timeout) - + if len(config.ca_certificates) > 0 { fmt.sbprintf(&builder, "CA Certificates: %d bytes\\n", len(config.ca_certificates)) } @@ -543,10 +569,10 @@ tls_dump_config :: proc(config: ^TLS_Config) -> string { if len(config.next_protos) > 0 { fmt.sbprintf(&builder, "Next Protocols: %v\\n", config.next_protos) } - + fmt.sbprintf(&builder, "State: %v\\n", config.state) fmt.sbprintf(&builder, "===========================\\n") - + return strings.to_string(builder) } @@ -554,29 +580,34 @@ tls_connection_info :: proc(tls: ^TLS_Socket) -> string { builder: strings.Builder builder = strings.builder_make() defer strings.builder_destroy(&builder) - + fmt.sbprintf(&builder, "\\n=== TLS Connection Info ===\\n") fmt.sbprintf(&builder, "State: %v\\n", tls.state) fmt.sbprintf(&builder, "Version: %v\\n", tls.version) fmt.sbprintf(&builder, "Cipher: %s\\n", tls.cipher_suite) fmt.sbprintf(&builder, "Selected Protocol: %s\\n", tls.config.selected_protocol) - + sent, recv, dur := tls_get_stats(tls) fmt.sbprintf(&builder, "Bytes Sent: %d\\n", sent) fmt.sbprintf(&builder, "Bytes Received: %d\\n", recv) fmt.sbprintf(&builder, "Handshake Duration: %v\\n", dur) - + certs := tls_get_peer_certificates(tls) if len(certs) > 0 { fmt.sbprintf(&builder, "Peer Certificates: %d\\n", len(certs)) - for cert in certs { - fmt.sbprintf(&builder, " [%d] CN: %s, Org: %s, Valid: %v\\n", - cert.subject_common_name, cert.subject_organization, cert.is_valid) - } + for cert in certs { + fmt.sbprintf( + &builder, + " [%d] CN: %s, Org: %s, Valid: %v\\n", + cert.subject_common_name, + cert.subject_organization, + cert.is_valid, + ) + } } - + fmt.sbprintf(&builder, "=============================\\n") - + return strings.to_string(builder) } @@ -584,28 +615,34 @@ tls_connection_info :: proc(tls: ^TLS_Socket) -> string { // perform_tls_handshake_on_socket performs TLS handshake on an existing TCP socket // This is the main entry point for HTTP client TLS support -perform_tls_handshake_on_socket :: proc(tcp_conn: net.TCP_Socket, server_name: string, - insecure_skip: bool) -> (^TLS_Socket, os.Errno) { +perform_tls_handshake_on_socket :: proc( + tcp_conn: net.TCP_Socket, + server_name: string, + insecure_skip: bool, +) -> ( + ^TLS_Socket, + os.Error, +) { // Create TLS config config := tls_config_default() config.server_name = server_name config.insecure_skip_verify = insecure_skip - + if insecure_skip { fmt.println("TLS WARNING: Insecure - skipping certificate verification") } - + // Wrap TCP socket in TLS socket tls_socket := tls_socket_new(tcp_conn, &config) - + // Perform handshake err := tls_handshake(tls_socket) if err != 0 { return nil, err } - + // Log connection info - fmt.println(tls_connection_info(tls_socket)) - + // fmt.println(tls_connection_info(tls_socket)) + return tls_socket, os.ERROR_NONE } diff --git a/hypersock_http/tls_openssl.odin b/hypersock_http/tls_openssl.odin index 7300725..b6b9274 100644 --- a/hypersock_http/tls_openssl.odin +++ b/hypersock_http/tls_openssl.odin @@ -14,7 +14,17 @@ package hypersock_http */ import "core:c" +import "core:net" import "core:os" +import "core:strings" + +when ODIN_OS == .Linux { + foreign import libssl "system:ssl" + foreign import libcrypto "system:crypto" +} else when ODIN_OS == .Windows { + foreign import libssl "windows/libssl.lib" + foreign import libcrypto "windows/libcrypto.lib" +} // OpenSSL Constants SSL_FILETYPE_PEM :: 1 @@ -36,16 +46,12 @@ SSL_ERROR_SYSCALL :: 5 SSL_ERROR_ZERO_RETURN :: 6 SSL_ERROR_WANT_CONNECT :: 7 SSL_ERROR_WANT_ACCEPT :: 8 +SSL_CTRL_SET_TLSEXT_HOSTNAME :: 55 +TLSEXT_NAMETYPE_host_name :: 0 // Foreign OpenSSL bindings @(default_calling_convention = "c") -foreign { - // SSL Library init - SSL_library_init :: proc() -> c.int --- - SSL_load_error_strings :: proc() --- - ERR_load_crypto_strings :: proc() --- - ERR_load_SSL_strings :: proc() --- - +foreign libssl { // SSL Context SSL_CTX_new :: proc(method: ^SSL_METHOD) -> ^SSL_CTX --- SSL_CTX_free :: proc(ctx: ^SSL_CTX) --- @@ -56,11 +62,12 @@ foreign { SSL_CTX_use_PrivateKey_file :: proc(ctx: ^SSL_CTX, file: cstring, type: c.int) -> c.int --- SSL_CTX_check_private_key :: proc(ctx: ^SSL_CTX) -> c.int --- SSL_CTX_set_default_verify_paths :: proc(ctx: ^SSL_CTX) -> c.int --- - + // SSL Connection SSL_new :: proc(ctx: ^SSL_CTX) -> ^SSL --- SSL_free :: proc(ssl: ^SSL) --- SSL_set_fd :: proc(ssl: ^SSL, fd: c.int) -> c.int --- + SSL_ctrl :: proc(ssl: ^SSL, cmd: c.int, larg: c.long, parg: rawptr) -> c.long --- SSL_connect :: proc(ssl: ^SSL) -> c.int --- SSL_accept :: proc(ssl: ^SSL) -> c.int --- SSL_shutdown :: proc(ssl: ^SSL) -> c.int --- @@ -71,19 +78,23 @@ foreign { SSL_set_accept_state :: proc(ssl: ^SSL) --- SSL_do_handshake :: proc(ssl: ^SSL) -> c.int --- SSL_get_version :: proc(ssl: ^SSL) -> cstring --- - SSL_get_cipher :: proc(ssl: ^SSL) -> cstring --- + SSL_get_current_cipher :: proc(ssl: ^SSL) -> ^SSL_CIPHER --- + SSL_CIPHER_get_name :: proc(cipher: ^SSL_CIPHER) -> cstring --- SSL_get_peer_certificate :: proc(ssl: ^SSL) -> ^X509 --- - + // SSL Methods TLS_client_method :: proc() -> ^SSL_METHOD --- TLS_server_method :: proc() -> ^SSL_METHOD --- - - // Error handling +} + +// Error handling +@(default_calling_convention = "c") +foreign libcrypto { ERR_get_error :: proc() -> c.ulong --- ERR_error_string :: proc(e: c.ulong, buf: [^]c.char) -> cstring --- ERR_error_string_n :: proc(e: c.ulong, buf: [^]c.char, len: c.int) --- ERR_clear_error :: proc() --- - + // X509 Certificate X509_free :: proc(cert: ^X509) --- X509_get_subject_name :: proc(cert: ^X509) -> ^X509_NAME --- @@ -91,7 +102,7 @@ foreign { X509_get_notBefore :: proc(cert: ^X509) -> ^ASN1_TIME --- X509_get_notAfter :: proc(cert: ^X509) -> ^ASN1_TIME --- X509_NAME_oneline :: proc(name: ^X509_NAME, buf: [^]c.char, size: c.int) -> cstring --- - + // BIO (Basic I/O) for memory buffers BIO_new_mem_buf :: proc(buf: rawptr, len: c.int) -> ^BIO --- BIO_free :: proc(bio: ^BIO) --- @@ -102,6 +113,7 @@ foreign { SSL_METHOD :: struct {} SSL_CTX :: struct {} SSL :: struct {} +SSL_CIPHER :: struct {} X509 :: struct {} X509_NAME :: struct {} ASN1_TIME :: struct {} @@ -110,7 +122,12 @@ BIO :: struct {} // Callback types SSL_Verify_Callback :: proc "c" (preverify_ok: c.int, ctx: ^X509_STORE_CTX) -> c.int X509_STORE_CTX :: struct {} -PEM_Password_Callback :: proc "c" (buf: [^]c.char, size: c.int, rwflag: c.int, userdata: rawptr) -> c.int +PEM_Password_Callback :: proc "c" ( + buf: [^]c.char, + size: c.int, + rwflag: c.int, + userdata: rawptr, +) -> c.int // OpenSSL initialization flag openssl_initialized := false @@ -120,12 +137,7 @@ openssl_init :: proc() { if openssl_initialized { return } - - SSL_library_init() - SSL_load_error_strings() - ERR_load_crypto_strings() - ERR_load_SSL_strings() - + openssl_initialized = true } @@ -135,7 +147,7 @@ openssl_get_error_string :: proc() -> string { if err_code == 0 { return "Unknown error" } - + buf: [256]c.char ERR_error_string_n(err_code, &buf[0], 256) return string(cstring(&buf[0])) @@ -150,211 +162,239 @@ openssl_clear_errors :: proc() { OpenSSL_Socket :: struct { ssl: ^SSL, ctx: ^SSL_CTX, - socket_fd: os.Socket, + socket_fd: net.TCP_Socket, is_server: bool, connected: bool, } // Create OpenSSL client context -openssl_client_context_new :: proc() -> (^SSL_CTX, os.Errno) { +openssl_client_context_new :: proc() -> (^SSL_CTX, os.Error) { openssl_init() - + method := TLS_client_method() if method == nil { - return nil, os.EINVAL + return nil, invalid_parameter_error() } - + ctx := SSL_CTX_new(method) if ctx == nil { - return nil, os.EINVAL + return nil, invalid_parameter_error() } - + // Set default verify mode SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, nil) SSL_CTX_set_verify_depth(ctx, 4) - + // Try to load default CA certificates SSL_CTX_set_default_verify_paths(ctx) - + return ctx, os.ERROR_NONE } // Create OpenSSL server context -openssl_server_context_new :: proc(cert_file, key_file: string) -> (^SSL_CTX, os.Errno) { +openssl_server_context_new :: proc(cert_file, key_file: string) -> (^SSL_CTX, os.Error) { openssl_init() - + method := TLS_server_method() if method == nil { - return nil, os.EINVAL + return nil, invalid_parameter_error() } - + ctx := SSL_CTX_new(method) if ctx == nil { - return nil, os.EINVAL + return nil, invalid_parameter_error() } - + // Load certificate cert_cstr := cstring(raw_data(cert_file)) if SSL_CTX_use_certificate_file(ctx, cert_cstr, SSL_FILETYPE_PEM) != 1 { SSL_CTX_free(ctx) - return nil, os.EINVAL + return nil, invalid_parameter_error() } - + // Load private key key_cstr := cstring(raw_data(key_file)) if SSL_CTX_use_PrivateKey_file(ctx, key_cstr, SSL_FILETYPE_PEM) != 1 { SSL_CTX_free(ctx) - return nil, os.EINVAL + return nil, invalid_parameter_error() } - + // Verify private key if SSL_CTX_check_private_key(ctx) != 1 { SSL_CTX_free(ctx) - return nil, os.EINVAL + return nil, invalid_parameter_error() } - + return ctx, os.ERROR_NONE } // Create OpenSSL socket wrapper -openssl_socket_new :: proc(ctx: ^SSL_CTX, socket_fd: os.Socket, is_server: bool) -> (^OpenSSL_Socket, os.Errno) { +openssl_socket_new :: proc( + ctx: ^SSL_CTX, + socket_fd: net.TCP_Socket, + is_server: bool, + server_name: string = "", +) -> ( + ^OpenSSL_Socket, + os.Error, +) { ssl := SSL_new(ctx) if ssl == nil { - return nil, os.EINVAL + return nil, invalid_parameter_error() } - + + if !is_server && server_name != "" { + server_name_cstr, alloc_error := strings.clone_to_cstring(server_name) + if alloc_error != nil { + SSL_free(ssl) + return nil, invalid_parameter_error() + } + defer delete(server_name_cstr) + + if SSL_ctrl( + ssl, + SSL_CTRL_SET_TLSEXT_HOSTNAME, + TLSEXT_NAMETYPE_host_name, + rawptr(server_name_cstr), + ) != + 1 { + SSL_free(ssl) + return nil, invalid_parameter_error() + } + } + // Set the socket file descriptor if SSL_set_fd(ssl, c.int(socket_fd)) != 1 { SSL_free(ssl) - return nil, os.EINVAL + return nil, invalid_parameter_error() } - + sock := new(OpenSSL_Socket) sock.ssl = ssl - sock.ctx = ctx // Note: we don't own the context, don't free it + sock.ctx = ctx // Note: we don't own the context, don't free it sock.socket_fd = socket_fd sock.is_server = is_server sock.connected = false - + if is_server { SSL_set_accept_state(ssl) } else { SSL_set_connect_state(ssl) } - + return sock, os.ERROR_NONE } // Perform TLS handshake -openssl_handshake :: proc(sock: ^OpenSSL_Socket) -> os.Errno { +openssl_handshake :: proc(sock: ^OpenSSL_Socket) -> os.Error { if sock == nil || sock.ssl == nil { - return os.EINVAL + return invalid_parameter_error() } - + result := SSL_do_handshake(sock.ssl) if result == 1 { sock.connected = true return os.ERROR_NONE } - + err := SSL_get_error(sock.ssl, result) if err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE { // Would block, need to retry - return os.EAGAIN + return would_block_error() } - - return os.EIO + + return io_error() } // Connect (client handshake) -openssl_connect :: proc(sock: ^OpenSSL_Socket) -> os.Errno { +openssl_connect :: proc(sock: ^OpenSSL_Socket) -> os.Error { if sock == nil || sock.ssl == nil { - return os.EINVAL + return invalid_parameter_error() } - + result := SSL_connect(sock.ssl) if result == 1 { sock.connected = true return os.ERROR_NONE } - + err := SSL_get_error(sock.ssl, result) if err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE { - return os.EAGAIN + return would_block_error() } - - return os.EIO + + return io_error() } // Accept (server handshake) -openssl_accept :: proc(sock: ^OpenSSL_Socket) -> os.Errno { +openssl_accept :: proc(sock: ^OpenSSL_Socket) -> os.Error { if sock == nil || sock.ssl == nil { - return os.EINVAL + return invalid_parameter_error() } - + result := SSL_accept(sock.ssl) if result == 1 { sock.connected = true return os.ERROR_NONE } - + err := SSL_get_error(sock.ssl, result) if err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE { - return os.EAGAIN + return would_block_error() } - - return os.EIO + + return io_error() } // Read from TLS connection -openssl_read :: proc(sock: ^OpenSSL_Socket, buf: []byte) -> (int, os.Errno) { +openssl_read :: proc(sock: ^OpenSSL_Socket, buf: []byte) -> (int, os.Error) { if sock == nil || sock.ssl == nil || len(buf) == 0 { - return 0, os.EINVAL + return 0, invalid_parameter_error() } - + result := SSL_read(sock.ssl, rawptr(raw_data(buf)), c.size_t(len(buf))) if result > 0 { return int(result), os.ERROR_NONE } - + err := SSL_get_error(sock.ssl, result) switch err { case SSL_ERROR_ZERO_RETURN: // Connection closed return 0, os.ERROR_NONE case SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE: - return 0, os.EAGAIN + return 0, would_block_error() case SSL_ERROR_SYSCALL: - return 0, os.EIO + return 0, io_error() case SSL_ERROR_SSL: - return 0, os.EIO + return 0, io_error() } - - return 0, os.EIO + + return 0, io_error() } // Write to TLS connection -openssl_write :: proc(sock: ^OpenSSL_Socket, buf: []byte) -> (int, os.Errno) { +openssl_write :: proc(sock: ^OpenSSL_Socket, buf: []byte) -> (int, os.Error) { if sock == nil || sock.ssl == nil || len(buf) == 0 { - return 0, os.EINVAL + return 0, invalid_parameter_error() } - + result := SSL_write(sock.ssl, rawptr(raw_data(buf)), c.size_t(len(buf))) if result > 0 { return int(result), os.ERROR_NONE } - + err := SSL_get_error(sock.ssl, result) switch err { case SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE: - return 0, os.EAGAIN + return 0, would_block_error() case SSL_ERROR_SYSCALL: - return 0, os.EIO + return 0, io_error() case SSL_ERROR_SSL: - return 0, os.EIO + return 0, io_error() } - - return 0, os.EIO + + return 0, io_error() } // Close TLS connection @@ -362,13 +402,13 @@ openssl_close :: proc(sock: ^OpenSSL_Socket) { if sock == nil { return } - + if sock.ssl != nil { SSL_shutdown(sock.ssl) SSL_free(sock.ssl) sock.ssl = nil } - + free(sock) } @@ -377,12 +417,12 @@ openssl_get_version :: proc(sock: ^OpenSSL_Socket) -> string { if sock == nil || sock.ssl == nil { return "" } - + version := SSL_get_version(sock.ssl) if version == nil { return "" } - + return string(version) } @@ -391,13 +431,18 @@ openssl_get_cipher :: proc(sock: ^OpenSSL_Socket) -> string { if sock == nil || sock.ssl == nil { return "" } - - cipher := SSL_get_cipher(sock.ssl) + + cipher := SSL_get_current_cipher(sock.ssl) if cipher == nil { return "" } - - return string(cipher) + + cipher_name := SSL_CIPHER_get_name(cipher) + if cipher_name == nil { + return "" + } + + return string(cipher_name) } // Load CA certificates from file @@ -405,7 +450,7 @@ openssl_load_ca_file :: proc(ctx: ^SSL_CTX, ca_file: string) -> bool { if ctx == nil || len(ca_file) == 0 { return false } - + ca_cstr := cstring(raw_data(ca_file)) result := SSL_CTX_load_verify_locations(ctx, ca_cstr, nil) return result == 1 diff --git a/hypersock_http/windows/.gitkeep b/hypersock_http/windows/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/hypersock_websocket/client.odin b/hypersock_websocket/client.odin index ec7fb89..9d4fa05 100644 --- a/hypersock_websocket/client.odin +++ b/hypersock_websocket/client.odin @@ -5,131 +5,155 @@ package hypersock_websocket * Based on gorilla/websocket client.go patterns */ +import http "../hypersock_http" +import "core:fmt" import "core:net" import "core:os" -import "core:fmt" -import "core:strings" import "core:strconv" +import "core:strings" import "core:time" -import http "../hypersock_http" // Dial connects to a WebSocket server // Returns the connection and any error -dial :: proc(url: string, dialer_arg: ^Dialer) -> (^Conn, os.Errno) { +dial :: proc(url: string, dialer_arg: ^Dialer) -> (^Conn, os.Error) { dialer := dialer_arg if dialer == nil { d := dialer_default() dialer = &d } - + // Parse URL scheme, host, path, port, ok := parse_ws_url(url) if !ok { - return nil, os.EINVAL + return nil, invalid_parameter_error() } - + // Determine if TLS is_tls := scheme == "wss" - + // Build address addr := fmt.tprintf("%s:%d", host, port) - + // Dial TCP connection socket: net.TCP_Socket - err: os.Errno - + err: os.Error + if dialer.net_dial != nil { socket, err = dialer.net_dial("tcp", addr) } else { - socket, net_err := net.dial_tcp(addr) + net_err: net.Network_Error + socket, net_err = net.dial_tcp(addr) if net_err != nil { - return nil, os.EINVAL + return nil, invalid_parameter_error() } } - + if err != os.ERROR_NONE { return nil, err } - + // Perform TLS handshake if needed _tls_socket: ^http.TLS_Socket if is_tls { // Perform TLS handshake - _tls_socket, tls_err := http.perform_tls_handshake_on_socket(socket, host, false) + tls_err: os.Error + _tls_socket, tls_err = http.perform_tls_handshake_on_socket(socket, host, false) if tls_err != os.ERROR_NONE { net.close(socket) fmt.println("WebSocket TLS handshake failed:", tls_err) return nil, tls_err } - - fmt.println("WebSocket TLS handshake completed successfully") - - // The TLS socket now wraps the TCP socket - // For the WebSocket handshake, we still use the underlying TCP socket - // because the TLS handshake is already completed - // After the WebSocket handshake, we would use tls_write/tls_read - - // Free the TLS socket for now (framework - would need actual TLS library for encryption) - // delete(_tls_socket) // Can't delete custom types like this + + // fmt.println("WebSocket TLS handshake completed successfully") + } - + // Generate challenge key challenge_key := generate_challenge_key() - + // Build HTTP upgrade request - request := build_handshake_request(host, path, challenge_key, dialer.subprotocols, dialer.enable_compression) + request := build_handshake_request( + host, + path, + challenge_key, + dialer.subprotocols, + dialer.enable_compression, + ) defer delete(request) - + // Set handshake deadline if dialer.handshake_timeout > 0 { - deadline := time.time_add(time.now(), dialer.handshake_timeout) - timeout_ms := int(time.duration_milliseconds(time.since(deadline))) - if timeout_ms < 0 { timeout_ms = 0 } - net.set_option(socket, net.Socket_Option.Send_Timeout, timeout_ms) - net.set_option(socket, net.Socket_Option.Receive_Timeout, timeout_ms) + net.set_option(socket, net.Socket_Option.Send_Timeout, dialer.handshake_timeout) + net.set_option(socket, net.Socket_Option.Receive_Timeout, dialer.handshake_timeout) } - + // Send request - _, send_err := net.send_tcp(socket, transmute([]byte)request) - if send_err != nil { - net.close(socket) - return nil, os.EINVAL + if is_tls && _tls_socket != nil { + _, send_err := http.tls_write(_tls_socket, transmute([]byte)request) + if send_err != os.ERROR_NONE { + http.tls_close(_tls_socket) + free(_tls_socket) + return nil, send_err + } + } else { + _, send_err := net.send_tcp(socket, transmute([]byte)request) + if send_err != nil { + net.close(socket) + return nil, invalid_parameter_error() + } } - + // Read response response_buf := make([]byte, 4096) defer delete(response_buf) - - n, recv_err := net.recv_tcp(socket, response_buf) - if recv_err != nil { - net.close(socket) - return nil, os.EINVAL + + n := 0 + if is_tls && _tls_socket != nil { + recv_n, recv_err := http.tls_read(_tls_socket, response_buf) + if recv_err != os.ERROR_NONE { + http.tls_close(_tls_socket) + free(_tls_socket) + return nil, invalid_parameter_error() + } + n = recv_n + } else { + recv_n, recv_err := net.recv_tcp(socket, response_buf) + if recv_err != nil { + net.close(socket) + return nil, invalid_parameter_error() + } + n = recv_n } - + response := string(response_buf[:n]) - + // Parse and validate response status_code, accept_key, subprotocol, ok2 := parse_handshake_response(response) if !ok2 || status_code != 101 { net.close(socket) - return nil, os.ECONNREFUSED + return nil, connection_refused_error() } - + // Validate accept key expected_accept := compute_accept_key(challenge_key) if accept_key != expected_accept { - net.close(socket) - return nil, os.EINVAL + if _tls_socket != nil { + http.tls_close(_tls_socket) + free(_tls_socket) + } else { + net.close(socket) + } + return nil, invalid_parameter_error() } - + // Create WebSocket connection - conn := new_conn(socket, false, dialer.read_buffer_size, dialer.write_buffer_size) + conn := new_conn(socket, _tls_socket, false, dialer.read_buffer_size, dialer.write_buffer_size) conn.subprotocol = subprotocol - + // Clear deadlines - net.set_option(socket, net.Socket_Option.Send_Timeout, 0) - net.set_option(socket, net.Socket_Option.Receive_Timeout, 0) - + net.set_option(socket, net.Socket_Option.Send_Timeout, time.Duration(0)) + net.set_option(socket, net.Socket_Option.Receive_Timeout, time.Duration(0)) + return conn, os.ERROR_NONE } @@ -138,9 +162,9 @@ dial :: proc(url: string, dialer_arg: ^Dialer) -> (^Conn, os.Errno) { parse_ws_url :: proc(url_str: string) -> (scheme, host, path: string, port: int, ok: bool) { // Simple URL parser for ws:// and wss:// schemes ok = false - + url := url_str - + if strings.has_prefix(url, "ws://") { scheme = "ws" url = url[5:] @@ -152,7 +176,7 @@ parse_ws_url :: proc(url_str: string) -> (scheme, host, path: string, port: int, } else { return } - + // Find path path_idx := strings.index(url, "/") if path_idx == -1 { @@ -162,94 +186,107 @@ parse_ws_url :: proc(url_str: string) -> (scheme, host, path: string, port: int, host = url[:path_idx] path = url[path_idx:] } - + // Check for port in host port_idx := strings.last_index(host, ":") if port_idx != -1 { - port_str := host[port_idx+1:] + port_str := host[port_idx + 1:] // Parse port number using strconv - if parsed_port, ok := strconv.parse_int(port_str, 10); ok { + if parsed_port, parse_ok := strconv.parse_int(port_str, 10); parse_ok { port = int(parsed_port) } host = host[:port_idx] } - + ok = true return } // build_handshake_request builds the HTTP upgrade request -build_handshake_request :: proc(host, path, challenge_key: string, subprotocols: []string, enable_compression: bool) -> string { +build_handshake_request :: proc( + host, path, challenge_key: string, + subprotocols: []string, + enable_compression: bool, +) -> string { request: strings.Builder // strings.Builder is zero-initialized, no init needed - + fmt.sbprintf(&request, "GET %s HTTP/1.1\r\n", path) fmt.sbprintf(&request, "Host: %s\r\n", host) fmt.sbprintf(&request, "Upgrade: websocket\r\n") fmt.sbprintf(&request, "Connection: Upgrade\r\n") fmt.sbprintf(&request, "Sec-WebSocket-Key: %s\r\n", challenge_key) fmt.sbprintf(&request, "Sec-WebSocket-Version: 13\r\n") - + // Add subprotocols if specified if len(subprotocols) > 0 { protocols := strings.join(subprotocols, ", ") defer delete(protocols) fmt.sbprintf(&request, "Sec-WebSocket-Protocol: %s\r\n", protocols) } - + // Add compression extension if enabled if enable_compression { - fmt.sbprintf(&request, "Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover\r\n") + fmt.sbprintf( + &request, + "Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover\r\n", + ) } - + fmt.sbprintf(&request, "\r\n") - + return strings.to_string(request) } // parse_handshake_response parses the HTTP upgrade response // Returns: status_code, accept_key, subprotocol, ok -parse_handshake_response :: proc(response: string) -> (status_code: int, accept_key, subprotocol: string, ok: bool) { +parse_handshake_response :: proc( + response: string, +) -> ( + status_code: int, + accept_key, subprotocol: string, + ok: bool, +) { ok = false - + // Split response into lines lines := strings.split(response, "\r\n") defer delete(lines) - + if len(lines) < 1 { return } - + // Parse status line status_parts := strings.split(lines[0], " ") defer delete(status_parts) - + if len(status_parts) < 2 { return } - + // Parse status code using strconv - if code, ok := strconv.parse_int(status_parts[1], 10); ok { + if code, parse_ok := strconv.parse_int(status_parts[1], 10); parse_ok { status_code = int(code) } else { return } - + // Parse headers for i := 1; i < len(lines); i += 1 { line := lines[i] if line == "" { break } - + colon_idx := strings.index(line, ":") if colon_idx == -1 { continue } - + key := strings.to_lower(strings.trim_space(line[:colon_idx])) - value := strings.trim_space(line[colon_idx+1:]) - + value := strings.trim_space(line[colon_idx + 1:]) + switch key { case "sec-websocket-accept": accept_key = value @@ -257,11 +294,11 @@ parse_handshake_response :: proc(response: string) -> (status_code: int, accept_ subprotocol = value } } - + if accept_key == "" { return } - + ok = true return } diff --git a/hypersock_websocket/conn.odin b/hypersock_websocket/conn.odin index 5210248..4e7231e 100644 --- a/hypersock_websocket/conn.odin +++ b/hypersock_websocket/conn.odin @@ -5,14 +5,14 @@ package hypersock_websocket * Based on gorilla/websocket conn.go patterns */ +import http "../hypersock_http" +import "core:encoding/endian" +import "core:fmt" import "core:net" import "core:os" -import "core:fmt" -import "core:mem" -import "core:time" -import "core:sync" import "core:strings" -import "core:encoding/endian" +import "core:sync" +import "core:time" import "core:unicode/utf8" // Frame size limits @@ -29,12 +29,18 @@ Text_Message :: 1 Binary_Message :: 2 // new_conn creates a new WebSocket connection -new_conn :: proc(socket: net.TCP_Socket, is_server: bool, read_buf_size, write_buf_size: int) -> ^Conn { +new_conn :: proc( + socket: net.TCP_Socket, + tls_socket: ^http.TLS_Socket, + is_server: bool, + read_buf_size, write_buf_size: int, +) -> ^Conn { c := new(Conn) c.conn = socket + c.tls_socket = tls_socket c.is_server = is_server c.state = .Open - + // Initialize read buffer actual_read_buf_size := read_buf_size if actual_read_buf_size == 0 { @@ -43,7 +49,7 @@ new_conn :: proc(socket: net.TCP_Socket, is_server: bool, read_buf_size, write_b c.read_buf = make([dynamic]byte, 0, actual_read_buf_size) c.read_limit = 0 c.read_final = true - + // Initialize write buffer actual_write_buf_size := write_buf_size if actual_write_buf_size == 0 { @@ -52,68 +58,147 @@ new_conn :: proc(socket: net.TCP_Socket, is_server: bool, read_buf_size, write_b c.write_buf = make([dynamic]byte, 0, actual_write_buf_size + MAX_FRAME_HEADER_SIZE) c.write_deadline = time.Time{} c.read_deadline = time.Time{} + c.close_code = 0 + c.close_text = make([dynamic]byte) c.is_writing = false - + // Set default handlers - c.handle_ping = proc(data: string) -> os.Errno { + c.handle_ping = proc(data: string) -> os.Error { return os.ERROR_NONE } - c.handle_pong = proc(data: string) -> os.Errno { + c.handle_pong = proc(data: string) -> os.Error { return os.ERROR_NONE } - c.handle_close = proc(code: u16, text: string) -> os.Errno { + c.handle_close = proc(code: u16, text: string) -> os.Error { return os.ERROR_NONE } - + return c } +conn_read :: proc(c: ^Conn, p: []byte) -> (n: int, err: os.Error) { + if len(p) == 0 { + return 0, os.ERROR_NONE + } + + total := 0 + for total < len(p) { + if c.tls_socket != nil { + read_n, read_err := http.tls_read(c.tls_socket, p[total:]) + if read_err == would_block_error() { + continue + } + if read_err != os.ERROR_NONE { + return total, read_err + } + if read_n <= 0 { + return total, connection_reset_error() + } + total += read_n + continue + } + + recv_n, recv_err := net.recv_tcp(c.conn, p[total:]) + if recv_err != nil { + return total, invalid_parameter_error() + } + if recv_n <= 0 { + return total, connection_reset_error() + } + total += int(recv_n) + } + + return total, os.ERROR_NONE +} + +conn_write :: proc(c: ^Conn, p: []byte) -> os.Error { + if len(p) == 0 { + return os.ERROR_NONE + } + + total := 0 + for total < len(p) { + if c.tls_socket != nil { + send_n, send_err := http.tls_write(c.tls_socket, p[total:]) + if send_err == would_block_error() { + continue + } + if send_err != os.ERROR_NONE { + return send_err + } + if send_n <= 0 { + return connection_reset_error() + } + total += send_n + continue + } + + send_n, send_err := net.send_tcp(c.conn, p[total:]) + if send_err != nil { + return invalid_parameter_error() + } + if send_n <= 0 { + return connection_reset_error() + } + total += int(send_n) + } + + return os.ERROR_NONE +} + // destroy_conn cleans up a WebSocket connection destroy_conn :: proc(c: ^Conn) { if c == nil do return - + if c.state != .Closed { close_connection(c) } - + delete(c.read_buf) delete(c.write_buf) + delete(c.close_text) free(c) } // close_connection closes the underlying network connection -close_connection :: proc(c: ^Conn) -> os.Errno { +close_connection :: proc(c: ^Conn) -> os.Error { if c.state == .Closed { return os.ERROR_NONE } - + c.state = .Closed - net.close(c.conn) + if c.tls_socket != nil { + _ = http.tls_close(c.tls_socket) + free(c.tls_socket) + c.tls_socket = nil + } else { + net.close(c.conn) + } return os.ERROR_NONE } // advance_frame reads and parses the next frame header -advance_frame :: proc(c: ^Conn) -> (opcode: Opcode, err: os.Errno) { +advance_frame :: proc(c: ^Conn) -> (opcode: Opcode, err: os.Error) { // Skip remainder of previous frame if c.read_remaining > 0 { to_skip := make([]byte, c.read_remaining) defer delete(to_skip) - _, recv_err := net.recv_tcp(c.conn, to_skip) - if recv_err != nil { - return .Continuation, os.EINVAL + _, recv_err := conn_read(c, to_skip) + if recv_err != os.ERROR_NONE { + return .Continuation, invalid_parameter_error() } } - + // Read first two bytes of frame header header_buf: [2]byte - _, recv_err := net.recv_tcp(c.conn, header_buf[:]) - if recv_err != nil { - return .Continuation, os.EINVAL + _, recv_err := conn_read(c, header_buf[:]) + if recv_err != os.ERROR_NONE { + return .Continuation, invalid_parameter_error() } - + b0 := header_buf[0] b1 := header_buf[1] - + fin := (b0 & 0x80) != 0 rsv1 := (b0 & 0x40) != 0 rsv2 := (b0 & 0x20) != 0 @@ -121,15 +206,15 @@ advance_frame :: proc(c: ^Conn) -> (opcode: Opcode, err: os.Errno) { opcode_val := Opcode(b0 & 0x0F) masked := (b1 & 0x80) != 0 payload_len := i64(b1 & 0x7F) - + errors: [dynamic]string defer delete(errors) - + // Validate RSV bits if rsv1 || rsv2 || rsv3 { append(&errors, "RSV bits set") } - + // Validate opcode #partial switch opcode_val { case .Close, .Ping, .Pong: @@ -143,6 +228,7 @@ advance_frame :: proc(c: ^Conn) -> (opcode: Opcode, err: os.Errno) { if !c.read_final { append(&errors, "data before FIN") } + c.read_msg_type = opcode_val c.read_final = fin case .Continuation: if c.read_final { @@ -152,121 +238,129 @@ advance_frame :: proc(c: ^Conn) -> (opcode: Opcode, err: os.Errno) { case: append(&errors, fmt.tprintf("unknown opcode: %d", opcode_val)) } - + // Validate mask bit if masked != c.is_server { append(&errors, "bad MASK bit") } - + if len(errors) > 0 { handle_protocol_error(c, "protocol error") - return .Continuation, os.EINVAL + return .Continuation, invalid_parameter_error() } - + // Read extended payload length if payload_len == 126 { len_buf: [2]byte - _, recv_err := net.recv_tcp(c.conn, len_buf[:]) - if recv_err != nil { - return .Continuation, os.EINVAL + _, recv_err = conn_read(c, len_buf[:]) + if recv_err != os.ERROR_NONE { + return .Continuation, invalid_parameter_error() } val, _ := endian.get_u16(len_buf[:], .Big) payload_len = i64(val) } else if payload_len == 127 { len_buf: [8]byte - _, recv_err := net.recv_tcp(c.conn, len_buf[:]) - if recv_err != nil { - return .Continuation, os.EINVAL + _, recv_err = conn_read(c, len_buf[:]) + if recv_err != os.ERROR_NONE { + return .Continuation, invalid_parameter_error() } val64, _ := endian.get_u64(len_buf[:], .Big) payload_len = i64(val64) } - + c.read_remaining = payload_len - + // Read mask key if present if masked { - _, recv_err := net.recv_tcp(c.conn, c.read_mask_key[:]) - if recv_err != nil { - return .Continuation, os.EINVAL + _, recv_err = conn_read(c, c.read_mask_key[:]) + if recv_err != os.ERROR_NONE { + return .Continuation, invalid_parameter_error() } c.read_mask_pos = 0 } - + // For data frames, enforce read limit if is_data(opcode_val) || opcode_val == .Continuation { if c.read_limit > 0 && c.read_remaining > c.read_limit { deadline := time.time_add(time.now(), WRITE_WAIT) _ = write_control(c, .Close, format_close_message(.MessageTooBig, ""), deadline) - return .Continuation, os.EINVAL + return .Continuation, invalid_parameter_error() } return opcode_val, os.ERROR_NONE } - + // Handle control frames + payload := make([]byte, payload_len) + defer delete(payload) if payload_len > 0 { - payload := make([]byte, payload_len) - defer delete(payload) - _, recv_err := net.recv_tcp(c.conn, payload) - if recv_err != nil { - return .Continuation, os.EINVAL + _, recv_err = conn_read(c, payload) + if recv_err != os.ERROR_NONE { + return .Continuation, invalid_parameter_error() } if c.is_server { mask_bytes(c.read_mask_key, 0, payload) } - - #partial switch opcode_val { - case .Pong: - _ = c.handle_pong(string(payload)) - case .Ping: - _ = c.handle_ping(string(payload)) - case .Close: - close_code := u16(1005) - close_text := "" - if len(payload) >= 2 { - cc, _ := endian.get_u16(payload[0:2], .Big) + } + c.read_remaining = 0 + + #partial switch opcode_val { + case .Pong: + _ = c.handle_pong(string(payload)) + case .Ping: + deadline := time.time_add(time.now(), WRITE_WAIT) + if write_err := write_control(c, .Pong, payload, deadline); write_err != os.ERROR_NONE { + return .Continuation, invalid_parameter_error() + } + _ = c.handle_ping(string(payload)) + case .Close: + close_code := u16(1005) + close_text := "" + if len(payload) >= 2 { + cc, _ := endian.get_u16(payload[0:2], .Big) close_code = cc - if len(payload) > 2 { - close_text = string(payload[2:]) - if !utf8.valid_string(close_text) { - handle_protocol_error(c, "invalid UTF-8 in close frame") - return .Continuation, os.EINVAL - } + if len(payload) > 2 { + close_text = string(payload[2:]) + if !utf8.valid_string(close_text) { + handle_protocol_error(c, "invalid UTF-8 in close frame") + return .Continuation, invalid_parameter_error() } } - _ = c.handle_close(close_code, close_text) - c.state = .Closing - return .Continuation, os.ECONNREFUSED } + c.close_code = close_code + clear(&c.close_text) + append(&c.close_text, ..transmute([]byte)close_text) + _ = c.handle_close(close_code, close_text) + c.state = .Closing + return .Continuation, connection_refused_error() } - + return opcode_val, os.ERROR_NONE } // write_control writes a control message -write_control :: proc(c: ^Conn, opcode: Opcode, data: []byte, deadline: time.Time) -> os.Errno { +write_control :: proc(c: ^Conn, opcode: Opcode, data: []byte, deadline: time.Time) -> os.Error { if !is_control(opcode) { - return os.EINVAL + return invalid_parameter_error() } if len(data) > MAX_CONTROL_FRAME_PAYLOAD_SIZE { - return os.EINVAL + return invalid_parameter_error() } - + sync.mutex_lock(&c.write_mutex) defer sync.mutex_unlock(&c.write_mutex) - + // Build frame header b0 := byte(opcode) | 0x80 b1 := byte(len(data)) if !c.is_server { b1 |= 0x80 } - + buf: [MAX_FRAME_HEADER_SIZE + MAX_CONTROL_FRAME_PAYLOAD_SIZE]byte buf[0] = b0 buf[1] = b1 pos := 2 - + if c.is_server { copy(buf[pos:], data) pos += len(data) @@ -275,54 +369,56 @@ write_control :: proc(c: ^Conn, opcode: Opcode, data: []byte, deadline: time.Tim copy(buf[pos:], mask[:]) pos += 4 copy(buf[pos:], data) - mask_bytes(mask, 0, buf[pos:pos+len(data)]) + mask_bytes(mask, 0, buf[pos:pos + len(data)]) pos += len(data) } - + // Set deadline if time.time_to_unix_nano(deadline) != 0 { - timeout_ms := int(time.duration_milliseconds(time.since(deadline))) - if timeout_ms < 0 { timeout_ms = 0 } - net.set_option(c.conn, net.Socket_Option.Send_Timeout, timeout_ms) + remaining := -time.since(deadline) + if remaining < 0 { + remaining = time.Duration(0) + } + net.set_option(c.conn, net.Socket_Option.Send_Timeout, remaining) } - - _, send_err := net.send_tcp(c.conn, buf[:pos]) - if send_err != nil { - return os.EINVAL + + send_err := conn_write(c, buf[:pos]) + if send_err != os.ERROR_NONE { + return invalid_parameter_error() } return os.ERROR_NONE } // write_message writes a data message -write_message :: proc(c: ^Conn, opcode: Opcode, data: []byte) -> os.Errno { +write_message :: proc(c: ^Conn, opcode: Opcode, data: []byte) -> os.Error { if c.state != .Open { - return os.ECONNREFUSED + return connection_refused_error() } - + if !is_data(opcode) { - return os.EINVAL + return invalid_parameter_error() } - + sync.mutex_lock(&c.write_mutex) defer sync.mutex_unlock(&c.write_mutex) - + // Clear write buffer clear(&c.write_buf) - + // Reserve space for frame header header_start := len(c.write_buf) resize(&c.write_buf, len(c.write_buf) + MAX_FRAME_HEADER_SIZE) - + // Write payload payload_start := len(c.write_buf) append(&c.write_buf, ..data) - + payload_len := len(data) - + // Determine frame header size header_len: int b0 := byte(opcode) | 0x80 - + if payload_len < 126 { c.write_buf[header_start] = b0 c.write_buf[header_start + 1] = byte(payload_len) @@ -330,15 +426,15 @@ write_message :: proc(c: ^Conn, opcode: Opcode, data: []byte) -> os.Errno { } else if payload_len < 65536 { c.write_buf[header_start] = b0 c.write_buf[header_start + 1] = 126 - endian.put_u16(c.write_buf[header_start+2:], .Big, u16(payload_len)) + endian.put_u16(c.write_buf[header_start + 2:], .Big, u16(payload_len)) header_len = 4 } else { c.write_buf[header_start] = b0 c.write_buf[header_start + 1] = 127 - endian.put_u64(c.write_buf[header_start+2:], .Big, u64(payload_len)) + endian.put_u64(c.write_buf[header_start + 2:], .Big, u64(payload_len)) header_len = 10 } - + // Apply mask if client if !c.is_server { c.write_buf[header_start + 1] |= 0x80 @@ -347,60 +443,61 @@ write_message :: proc(c: ^Conn, opcode: Opcode, data: []byte) -> os.Errno { header_len += 4 mask_bytes(mask, 0, c.write_buf[payload_start:]) } - + // Send frame - frame := c.write_buf[header_start:payload_start+payload_len] - _, send_err := net.send_tcp(c.conn, frame) - if send_err != nil { - return os.EINVAL + frame := c.write_buf[header_start:payload_start + payload_len] + send_err := conn_write(c, frame) + if send_err != os.ERROR_NONE { + return invalid_parameter_error() } - + return os.ERROR_NONE } // read_message reads a complete message -read_message :: proc(c: ^Conn) -> (opcode: Opcode, data: []byte, err: os.Errno) { +read_message :: proc(c: ^Conn) -> (opcode: Opcode, data: []byte, err: os.Error) { if c.state != .Open { - return .Continuation, nil, os.ECONNREFUSED + return .Continuation, nil, connection_refused_error() } - + clear(&c.read_buf) - + for { frame_opcode, frame_err := advance_frame(c) if frame_err != os.ERROR_NONE { return .Continuation, nil, frame_err } - + // Skip control frames if is_control(frame_opcode) { continue } - + // Read frame payload if c.read_remaining > 0 { payload := make([]byte, c.read_remaining) - _, recv_err := net.recv_tcp(c.conn, payload) - if recv_err != nil { + _, recv_err := conn_read(c, payload) + if recv_err != os.ERROR_NONE { delete(payload) - return .Continuation, nil, os.EINVAL + return .Continuation, nil, invalid_parameter_error() } - + c.read_remaining = 0 + // Unmask if server if c.is_server { mask_bytes(c.read_mask_key, c.read_mask_pos, payload) } - + append(&c.read_buf, ..payload) delete(payload) } - + // Check if this is the final frame if c.read_final { break } } - + return c.read_msg_type, c.read_buf[:], os.ERROR_NONE } @@ -434,9 +531,9 @@ set_read_limit :: proc(c: ^Conn, limit: i64) { } // set_ping_handler sets the handler for ping messages -set_ping_handler :: proc(c: ^Conn, handler: proc(data: string) -> os.Errno) { +set_ping_handler :: proc(c: ^Conn, handler: proc(data: string) -> os.Error) { if handler == nil { - c.handle_ping = proc(data: string) -> os.Errno { + c.handle_ping = proc(data: string) -> os.Error { return os.ERROR_NONE } } else { @@ -445,9 +542,9 @@ set_ping_handler :: proc(c: ^Conn, handler: proc(data: string) -> os.Errno) { } // set_pong_handler sets the handler for pong messages -set_pong_handler :: proc(c: ^Conn, handler: proc(data: string) -> os.Errno) { +set_pong_handler :: proc(c: ^Conn, handler: proc(data: string) -> os.Error) { if handler == nil { - c.handle_pong = proc(data: string) -> os.Errno { + c.handle_pong = proc(data: string) -> os.Error { return os.ERROR_NONE } } else { @@ -456,9 +553,9 @@ set_pong_handler :: proc(c: ^Conn, handler: proc(data: string) -> os.Errno) { } // set_close_handler sets the handler for close messages -set_close_handler :: proc(c: ^Conn, handler: proc(code: u16, text: string) -> os.Errno) { +set_close_handler :: proc(c: ^Conn, handler: proc(code: u16, text: string) -> os.Error) { if handler == nil { - c.handle_close = proc(code: u16, text: string) -> os.Errno { + c.handle_close = proc(code: u16, text: string) -> os.Error { return os.ERROR_NONE } } else { @@ -470,9 +567,11 @@ set_close_handler :: proc(c: ^Conn, handler: proc(code: u16, text: string) -> os set_write_deadline :: proc(c: ^Conn, deadline: time.Time) { c.write_deadline = deadline if time.time_to_unix_nano(deadline) != 0 { - timeout_ms := int(time.duration_milliseconds(time.since(deadline))) - if timeout_ms < 0 { timeout_ms = 0 } - net.set_option(c.conn, net.Socket_Option.Send_Timeout, timeout_ms) + remaining := -time.since(deadline) + if remaining < 0 { + remaining = time.Duration(0) + } + net.set_option(c.conn, net.Socket_Option.Send_Timeout, remaining) } } @@ -480,14 +579,16 @@ set_write_deadline :: proc(c: ^Conn, deadline: time.Time) { set_read_deadline :: proc(c: ^Conn, deadline: time.Time) { c.read_deadline = deadline if time.time_to_unix_nano(deadline) != 0 { - timeout_ms := int(time.duration_milliseconds(time.since(deadline))) - if timeout_ms < 0 { timeout_ms = 0 } - net.set_option(c.conn, net.Socket_Option.Receive_Timeout, timeout_ms) + remaining := -time.since(deadline) + if remaining < 0 { + remaining = time.Duration(0) + } + net.set_option(c.conn, net.Socket_Option.Receive_Timeout, remaining) } } // WriteMessage writes a message to the connection -write_message_public :: proc(c: ^Conn, message_type: int, data: []byte) -> os.Errno { +write_message_public :: proc(c: ^Conn, message_type: int, data: []byte) -> os.Error { opcode: Opcode switch message_type { case 1: @@ -495,40 +596,40 @@ write_message_public :: proc(c: ^Conn, message_type: int, data: []byte) -> os.Er case 2: opcode = .Binary case: - return os.EINVAL + return invalid_parameter_error() } - + return write_message(c, opcode, data) } // ReadMessage reads a message from the connection -read_message_public :: proc(c: ^Conn) -> (message_type: int, data: []byte, err: os.Errno) { +read_message_public :: proc(c: ^Conn) -> (message_type: int, data: []byte, err: os.Error) { opcode, data0, err0 := read_message(c) - + if err0 != os.ERROR_NONE { return 0, nil, err0 } - + #partial switch opcode { case .Text: return Text_Message, data0, os.ERROR_NONE case .Binary: return Binary_Message, data0, os.ERROR_NONE case: - return 0, nil, os.EINVAL + return 0, nil, invalid_parameter_error() } } // Close sends a close frame and closes the connection -close_connection_public :: proc(c: ^Conn) -> os.Errno { +close_connection_public :: proc(c: ^Conn) -> os.Error { if c.state != .Open { - return os.ECONNREFUSED + return connection_refused_error() } - + // Send close frame deadline := time.time_add(time.now(), WRITE_WAIT) _ = write_control(c, .Close, format_close_message(.NormalClosure, ""), deadline) - + return close_connection(c) } @@ -536,116 +637,116 @@ close_connection_public :: proc(c: ^Conn) -> os.Errno { // MessageReader provides streaming read interface for WebSocket messages MessageReader :: struct { - conn: ^Conn, - opcode: Opcode, - reader: strings.Reader, - remaining: int, - first: bool, + conn: ^Conn, + opcode: Opcode, + reader: strings.Reader, + remaining: int, + first: bool, next_read_count: int, } // NextReader returns a reader for the next message -next_reader :: proc(c: ^Conn) -> (r: ^MessageReader, err: os.Errno) { +next_reader :: proc(c: ^Conn) -> (r: ^MessageReader, err: os.Error) { // Check connection state if c.state != .Open { - return nil, os.ECONNREFUSED + return nil, connection_refused_error() } - + // Read the next frame header frame_opcode, frame_err := advance_frame(c) if frame_err != os.ERROR_NONE { return nil, frame_err } - + // Skip control frames (handled internally) if is_control(frame_opcode) { // Read ahead for data frame reader, reader_err := next_reader(c) return reader, reader_err } - + // Check if it's a data frame if !is_data(frame_opcode) { - return nil, os.EINVAL + return nil, invalid_parameter_error() } - + // For continuation frames, track message type if frame_opcode == .Continuation { c.read_msg_type = c.read_msg_type } else { c.read_msg_type = frame_opcode } - + // Create reader mr := new(MessageReader) mr.conn = c mr.opcode = frame_opcode mr.remaining = int(c.read_remaining) mr.first = true - + return mr, os.ERROR_NONE } // Read reads from the message reader -message_reader_read :: proc(r: ^MessageReader, p: []byte) -> (n: int, err: os.Errno) { +message_reader_read :: proc(r: ^MessageReader, p: []byte) -> (n: int, err: os.Error) { if r.remaining <= 0 { return 0, os.ERROR_NONE } - + // Calculate read size to_read := r.remaining if to_read > len(p) { to_read = len(p) } - + // Read from connection - actual_read, recv_err := net.recv_tcp(r.conn.conn, p[:to_read]) - if recv_err != nil { - return int(actual_read), os.EINVAL + actual_read, recv_err := conn_read(r.conn, p[:to_read]) + if recv_err != os.ERROR_NONE { + return int(actual_read), invalid_parameter_error() } - + // Unmask if server if r.conn.is_server && actual_read > 0 { mask_bytes(r.conn.read_mask_key, r.conn.read_mask_pos, p[:actual_read]) r.conn.read_mask_pos = (r.conn.read_mask_pos + int(actual_read)) % 4 } - + r.remaining -= int(actual_read) r.next_read_count += 1 - + return int(actual_read), os.ERROR_NONE } // Close closes the message reader -message_reader_close :: proc(r: ^MessageReader) -> os.Errno { +message_reader_close :: proc(r: ^MessageReader) -> os.Error { // Skip any remaining bytes in this frame if r.remaining > 0 { to_skip := make([]byte, r.remaining) defer delete(to_skip) - _, recv_err := net.recv_tcp(r.conn.conn, to_skip) - if recv_err != nil { - return os.EINVAL + _, recv_err := conn_read(r.conn, to_skip) + if recv_err != os.ERROR_NONE { + return invalid_parameter_error() } } - + return os.ERROR_NONE } // MessageWriter provides streaming write interface for WebSocket messages MessageWriter :: struct { - conn: ^Conn, - opcode: Opcode, - pos: int, - closed: bool, + conn: ^Conn, + opcode: Opcode, + pos: int, + closed: bool, payload: [dynamic]byte, } // NextWriter returns a writer for the next message -next_writer :: proc(c: ^Conn, message_type: int) -> (w: ^MessageWriter, err: os.Errno) { +next_writer :: proc(c: ^Conn, message_type: int) -> (w: ^MessageWriter, err: os.Error) { if c.state != .Open { - return nil, os.ECONNREFUSED + return nil, connection_refused_error() } - + // Map int message type to opcode opcode: Opcode switch message_type { @@ -654,53 +755,53 @@ next_writer :: proc(c: ^Conn, message_type: int) -> (w: ^MessageWriter, err: os. case 2: opcode = .Binary case: - return nil, os.EINVAL + return nil, invalid_parameter_error() } - + // Clear write buffer clear(&c.write_buf) - + // Reserve space for frame header header_start := len(c.write_buf) resize(&c.write_buf, len(c.write_buf) + MAX_FRAME_HEADER_SIZE) - + // Create writer mw := new(MessageWriter) mw.conn = c mw.opcode = opcode mw.pos = header_start + MAX_FRAME_HEADER_SIZE mw.closed = false - + return mw, os.ERROR_NONE } // Write writes to the message writer -message_writer_write :: proc(w: ^MessageWriter, p: []byte) -> (n: int, err: os.Errno) { +message_writer_write :: proc(w: ^MessageWriter, p: []byte) -> (n: int, err: os.Error) { if w.closed { - return 0, os.ECONNREFUSED + return 0, connection_refused_error() } - + // Append to payload append(&w.payload, ..p) return len(p), os.ERROR_NONE } // Close flushes the message writer and sends the message -message_writer_close :: proc(w: ^MessageWriter) -> os.Errno { +message_writer_close :: proc(w: ^MessageWriter) -> os.Error { if w.closed { return os.ERROR_NONE } w.closed = true - + // Append payload to connection's write buffer payload_start := len(w.conn.write_buf) - for b in w.payload { append(&w.conn.write_buf, b) } + for b in w.payload {append(&w.conn.write_buf, b)} payload_len := len(w.payload) - + // Now build the frame header header_start := w.pos - MAX_FRAME_HEADER_SIZE b0 := byte(w.opcode) | 0x80 - + header_len: int if payload_len < 126 { w.conn.write_buf[header_start] = b0 @@ -709,15 +810,15 @@ message_writer_close :: proc(w: ^MessageWriter) -> os.Errno { } else if payload_len < 65536 { w.conn.write_buf[header_start] = b0 w.conn.write_buf[header_start + 1] = 126 - endian.put_u16(w.conn.write_buf[header_start+2:], .Big, u16(payload_len)) + endian.put_u16(w.conn.write_buf[header_start + 2:], .Big, u16(payload_len)) header_len = 4 } else { w.conn.write_buf[header_start] = b0 w.conn.write_buf[header_start + 1] = 127 - endian.put_u64(w.conn.write_buf[header_start+2:], .Big, u64(payload_len)) + endian.put_u64(w.conn.write_buf[header_start + 2:], .Big, u64(payload_len)) header_len = 10 } - + // Apply mask if client if !w.conn.is_server { w.conn.write_buf[header_start + 1] |= 0x80 @@ -726,17 +827,17 @@ message_writer_close :: proc(w: ^MessageWriter) -> os.Errno { header_len += 4 mask_bytes(mask, 0, w.conn.write_buf[payload_start:]) } - + // Send frame - frame := w.conn.write_buf[header_start:payload_start+payload_len] - _, send_err := net.send_tcp(w.conn.conn, frame) - if send_err != nil { + frame := w.conn.write_buf[header_start:payload_start + payload_len] + send_err := conn_write(w.conn, frame) + if send_err != os.ERROR_NONE { delete(w.payload) - return os.EINVAL + return invalid_parameter_error() } - + // Clean up delete(w.payload) - + return os.ERROR_NONE } diff --git a/hypersock_websocket/errors.odin b/hypersock_websocket/errors.odin new file mode 100644 index 0000000..e6e3f62 --- /dev/null +++ b/hypersock_websocket/errors.odin @@ -0,0 +1,71 @@ +package hypersock_websocket + +import "core:io" +import "core:os" +import win32 "core:sys/windows" + +invalid_parameter_error :: proc() -> os.Error { + when ODIN_OS == .Linux { + err: os.Error = os.Platform_Error.EINVAL + return err + } else when ODIN_OS == .Windows { + return os.Platform_Error(win32.ERROR_INVALID_PARAMETER) + } else { + return os.General_Error.Invalid_Command + } +} + +access_denied_error :: proc() -> os.Error { + when ODIN_OS == .Linux { + err: os.Error = os.Platform_Error.EACCES + return err + } else when ODIN_OS == .Windows { + return os.Platform_Error(win32.ERROR_ACCESS_DENIED) + } else { + return io.Error.Permission_Denied + } +} + +not_supported_error :: proc() -> os.Error { + when ODIN_OS == .Linux { + err: os.Error = os.Platform_Error.ENOSYS + return err + } else when ODIN_OS == .Windows { + return os.Platform_Error(win32.ERROR_NOT_SUPPORTED) + } else { + return io.Error.Unsupported + } +} + +connection_reset_error :: proc() -> os.Error { + when ODIN_OS == .Linux { + err: os.Error = os.Platform_Error.ECONNRESET + return err + } else when ODIN_OS == .Windows { + return os.Platform_Error(win32.WSAECONNRESET) + } else { + return os.General_Error.Broken_Pipe + } +} + +connection_refused_error :: proc() -> os.Error { + when ODIN_OS == .Linux { + err: os.Error = os.Platform_Error.ECONNREFUSED + return err + } else when ODIN_OS == .Windows { + return os.Platform_Error(win32.WSAECONNREFUSED) + } else { + return os.General_Error.Invalid_Command + } +} + +would_block_error :: proc() -> os.Error { + when ODIN_OS == .Linux { + err: os.Error = os.Platform_Error.EAGAIN + return err + } else when ODIN_OS == .Windows { + return os.Platform_Error(win32.WSAEWOULDBLOCK) + } else { + return os.General_Error.Timeout + } +} diff --git a/hypersock_websocket/upgrader.odin b/hypersock_websocket/upgrader.odin index c9a14d6..5fa25b4 100644 --- a/hypersock_websocket/upgrader.odin +++ b/hypersock_websocket/upgrader.odin @@ -7,66 +7,71 @@ package hypersock_websocket * Handles HTTP to WebSocket upgrade requests */ -import "core:net" +import http "../hypersock_http" import "core:os" -import "core:fmt" import "core:strings" -import "core:time" -import http "../hypersock_http" // UpgradeResponse contains the response from an upgrade UpgradeResponse :: struct { - conn: ^Conn, - subprotocol: string, + conn: ^Conn, + subprotocol: string, handshake_complete: bool, } // Upgrade upgrades the HTTP connection to WebSocket protocol -upgrade :: proc(u: ^Upgrader, w: ^http.Response, r: ^http.Request) -> (^Conn, http.Header, os.Errno) { +upgrade :: proc( + u: ^Upgrader, + w: ^http.Response, + r: ^http.Request, +) -> ( + ^Conn, + http.Header, + os.Error, +) { // Check if connection is already hijacked // For simplicity, we'll assume the response writer has the underlying connection - + // Validate request method if r.method != .GET { - return nil, http.Header{}, os.EINVAL + return nil, http.Header{}, invalid_parameter_error() } - + // Check for Upgrade header upgrade_header := http.header_get(&r.header, "Upgrade") if strings.to_lower(upgrade_header) != "websocket" { - return nil, http.Header{}, os.EINVAL + return nil, http.Header{}, invalid_parameter_error() } - + // Check for Connection header connection_header := http.header_get(&r.header, "Connection") if !strings.contains(strings.to_lower(connection_header), "upgrade") { - return nil, http.Header{}, os.EINVAL + return nil, http.Header{}, invalid_parameter_error() } - + // Check for Sec-WebSocket-Version version_header := http.header_get(&r.header, "Sec-WebSocket-Version") if version_header != "13" { - return nil, http.Header{}, os.EINVAL + return nil, http.Header{}, invalid_parameter_error() } - + // Check for Sec-WebSocket-Key key_header := http.header_get(&r.header, "Sec-WebSocket-Key") if len(key_header) == 0 { - return nil, http.Header{}, os.EINVAL + return nil, http.Header{}, invalid_parameter_error() } - + // Validate key format (must be base64) - if len(key_header) != 24 { // Base64 encoded 16 bytes - return nil, http.Header{}, os.EINVAL + if len(key_header) != 24 { // Base64 encoded 16 bytes + return nil, http.Header{}, invalid_parameter_error() } - + // Handle optional origin check if u.check_origin != nil { if !u.check_origin(r) { - return nil, http.Header{}, os.EACCES + return nil, http.Header{}, access_denied_error() } } - + // Handle subprotocol negotiation selected_subprotocol: string if len(u.subprotocols) > 0 { @@ -74,7 +79,7 @@ upgrade :: proc(u: ^Upgrader, w: ^http.Response, r: ^http.Request) -> (^Conn, ht if client_protocols_header != "" { client_protocols := strings.split(client_protocols_header, ",") defer delete(client_protocols) - + for cp in client_protocols { client_proto := strings.trim_space(strings.trim(cp, "\"")) for sp in u.subprotocols { @@ -89,38 +94,38 @@ upgrade :: proc(u: ^Upgrader, w: ^http.Response, r: ^http.Request) -> (^Conn, ht } } } - + // Hijack the connection (in real implementation, this would get the underlying TCP socket) // For now, we'll need the underlying connection from the response // This is a simplified implementation - in practice, you'd need access to the raw TCP socket - + // Compute accept key accept_key := compute_accept_key(key_header) - + // Build response headers response_headers: http.Header // Headers are zero-initialized, no explicit init needed http.header_set(&response_headers, "Upgrade", "websocket") http.header_set(&response_headers, "Connection", "Upgrade") http.header_set(&response_headers, "Sec-WebSocket-Accept", accept_key) - + if selected_subprotocol != "" { http.header_set(&response_headers, "Sec-WebSocket-Protocol", selected_subprotocol) } - + // In real implementation, you would: // 1. Get the raw TCP socket from the HTTP response writer // 2. Send the HTTP 101 Switching Protocols response // 3. Create a WebSocket connection from the socket - + // For now, return an error indicating the connection needs to be hijacked // In production, integrate with http package to properly hijack connection - + // Placeholder: would normally create WebSocket connection like this: // conn := new_conn(socket, true, u.read_buffer_size, u.write_buffer_size) // conn.subprotocol = selected_subprotocol - - return nil, response_headers, os.ENOSYS + + return nil, response_headers, not_supported_error() } // IsWebSocketUpgrade checks if the request is a WebSocket upgrade request @@ -128,17 +133,17 @@ is_websocket_upgrade :: proc(ctx: ^http.RequestCtx) -> bool { if ctx.request.method != .GET { return false } - + upgrade_header := http.header_get(&ctx.request.header, "Upgrade") if strings.to_lower(upgrade_header) != "websocket" { return false } - + connection_header := http.header_get(&ctx.request.header, "Connection") if !strings.contains(strings.to_lower(connection_header), "upgrade") { return false } - + return true } @@ -147,19 +152,19 @@ select_subprotocol :: proc(proto_client, proto_server: string) -> string { if proto_client == "" || proto_server == "" { return "" } - + client_protocols := strings.split(proto_client, ",") defer delete(client_protocols) - + server_protocols := strings.split(proto_server, ",") defer delete(server_protocols) - + for cp in client_protocols { client_proto := strings.trim_space(strings.trim(cp, "\"")) if client_proto == "" { continue } - + for sp in server_protocols { server_proto := strings.trim_space(strings.trim(sp, "\"")) if client_proto == server_proto { @@ -167,7 +172,7 @@ select_subprotocol :: proc(proto_client, proto_server: string) -> string { } } } - + return "" } diff --git a/hypersock_websocket/websocket.odin b/hypersock_websocket/websocket.odin index be34662..231cc69 100644 --- a/hypersock_websocket/websocket.odin +++ b/hypersock_websocket/websocket.odin @@ -5,18 +5,15 @@ package hypersock_websocket * Based on gorilla/websocket patterns */ -import "core:net" -import "core:os" -import "core:fmt" -import "core:strings" -import "core:time" +import http "../hypersock_http" import "core:crypto" import "core:crypto/legacy/sha1" - import "core:encoding/base64" -import "core:encoding/endian" +import "core:net" +import "core:os" +import "core:strings" import "core:sync" -import http "../hypersock_http" +import "core:time" // WebSocket frame opcodes (RFC 6455) Opcode :: enum u8 { @@ -76,52 +73,55 @@ Conn_State :: enum { // Conn represents a WebSocket connection Conn :: struct { - conn: net.TCP_Socket, - is_server: bool, - state: Conn_State, - subprotocol: string, - + conn: net.TCP_Socket, + tls_socket: ^http.TLS_Socket, + is_server: bool, + state: Conn_State, + subprotocol: string, + // Read fields - read_buf: [dynamic]byte, - read_remaining: i64, - read_final: bool, - read_msg_type: Opcode, - read_limit: i64, - read_mask_key: [4]byte, - read_mask_pos: int, - read_err: os.Errno, - + read_buf: [dynamic]byte, + read_remaining: i64, + read_final: bool, + read_msg_type: Opcode, + read_limit: i64, + read_mask_key: [4]byte, + read_mask_pos: int, + read_err: os.Error, + // Write fields - write_buf: [dynamic]byte, - write_mutex: sync.Mutex, - write_deadline: time.Time, - is_writing: bool, - + write_buf: [dynamic]byte, + write_mutex: sync.Mutex, + write_deadline: time.Time, + is_writing: bool, + // Read deadline - read_deadline: time.Time, + read_deadline: time.Time, + close_code: u16, + close_text: [dynamic]byte, // Handlers - handle_ping: proc(data: string) -> os.Errno, - handle_pong: proc(data: string) -> os.Errno, - handle_close: proc(code: u16, text: string) -> os.Errno, + handle_ping: proc(data: string) -> os.Error, + handle_pong: proc(data: string) -> os.Error, + handle_close: proc(code: u16, text: string) -> os.Error, } // Upgrader handles HTTP to WebSocket upgrade Upgrader :: struct { - read_buffer_size: int, - write_buffer_size: int, - handshake_timeout: time.Duration, - subprotocols: []string, - check_origin: proc(req: ^http.Request) -> bool, + read_buffer_size: int, + write_buffer_size: int, + handshake_timeout: time.Duration, + subprotocols: []string, + check_origin: proc(req: ^http.Request) -> bool, enable_compression: bool, } // Dialer creates client WebSocket connections Dialer :: struct { - net_dial: proc(network, addr: string) -> (net.TCP_Socket, os.Errno), - read_buffer_size: int, - write_buffer_size: int, - handshake_timeout: time.Duration, - subprotocols: []string, + net_dial: proc(network, addr: string) -> (net.TCP_Socket, os.Error), + read_buffer_size: int, + write_buffer_size: int, + handshake_timeout: time.Duration, + subprotocols: []string, enable_compression: bool, } @@ -137,8 +137,8 @@ WebSocket_Error :: enum { // Initialize default upgrader upgrader_default :: proc() -> Upgrader { - return Upgrader{ - read_buffer_size = 4096, + return Upgrader { + read_buffer_size = 4096, write_buffer_size = 4096, handshake_timeout = 45 * time.Second, } @@ -146,8 +146,8 @@ upgrader_default :: proc() -> Upgrader { // Initialize default dialer dialer_default :: proc() -> Dialer { - return Dialer{ - read_buffer_size = 4096, + return Dialer { + read_buffer_size = 4096, write_buffer_size = 4096, handshake_timeout = 45 * time.Second, } @@ -159,19 +159,19 @@ compute_accept_key :: proc(challenge: string) -> string { magic := "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" combined := strings.concatenate([]string{challenge, magic}) defer delete(combined) - + // Initialize SHA1 context ctx: sha1.Context sha1.init(&ctx) - + // Hash the combined string combined_bytes := transmute([]byte)combined sha1.update(&ctx, combined_bytes) - + // Get the final hash hash: [sha1.DIGEST_SIZE]byte sha1.final(&ctx, hash[:]) - + return base64.encode(hash[:]) }