/Users/alexjokela/projects/lattice/src/tls.c
Line | Count | Source |
1 | | #include "tls.h" |
2 | | #include "net.h" |
3 | | |
4 | | #ifdef LATTICE_HAS_TLS |
5 | | |
6 | | #include <openssl/ssl.h> |
7 | | #include <openssl/err.h> |
8 | | #include <openssl/x509v3.h> |
9 | | #include <string.h> |
10 | | #include <stdlib.h> |
11 | | #include <stdio.h> |
12 | | #include <unistd.h> |
13 | | |
14 | | #ifndef FD_SETSIZE |
15 | | #define FD_SETSIZE 1024 |
16 | | #endif |
17 | | |
18 | | /* ── Per-fd SSL session tracking ── */ |
19 | | |
20 | | static SSL *tls_sessions[FD_SETSIZE]; |
21 | | |
22 | | /* ── Lazy-initialised shared context ── */ |
23 | | |
24 | | static SSL_CTX *g_ssl_ctx = NULL; |
25 | | |
26 | 3 | static SSL_CTX *get_ssl_ctx(char **err) { |
27 | 3 | if (g_ssl_ctx) return g_ssl_ctx; |
28 | | |
29 | 3 | const SSL_METHOD *method = TLS_client_method(); |
30 | 3 | g_ssl_ctx = SSL_CTX_new(method); |
31 | 3 | if (!g_ssl_ctx) { |
32 | 0 | *err = strdup("tls: failed to create SSL_CTX"); |
33 | 0 | return NULL; |
34 | 0 | } |
35 | | |
36 | | /* Load system CA certificates and enable peer verification */ |
37 | 3 | SSL_CTX_set_default_verify_paths(g_ssl_ctx); |
38 | 3 | SSL_CTX_set_verify(g_ssl_ctx, SSL_VERIFY_PEER, NULL); |
39 | 3 | SSL_CTX_set_min_proto_version(g_ssl_ctx, TLS1_2_VERSION); |
40 | | |
41 | 3 | return g_ssl_ctx; |
42 | 3 | } |
43 | | |
/*
 * Build a heap-allocated "<prefix>: <reason>" message from the earliest
 * entry in the calling thread's OpenSSL error queue, or
 * "<prefix>: unknown SSL error" when the queue is empty.
 *
 * The remainder of the queue is discarded afterwards. Entries left behind
 * would otherwise be reported as the cause of some later, unrelated failure.
 * The caller owns the returned string (NULL if strdup fails).
 */
static char *ssl_error_string(const char *prefix) {
    unsigned long code = ERR_get_error();
    char buf[512];
    if (code) {
        char reason[256];
        ERR_error_string_n(code, reason, sizeof(reason));
        snprintf(buf, sizeof(buf), "%s: %s", prefix, reason);
    } else {
        snprintf(buf, sizeof(buf), "%s: unknown SSL error", prefix);
    }
    ERR_clear_error();
    return strdup(buf);
}
56 | | |
57 | | /* ── tls_connect ── */ |
58 | | |
59 | 3 | int net_tls_connect(const char *host, int port, char **err) { |
60 | 3 | SSL_CTX *ctx = get_ssl_ctx(err); |
61 | 3 | if (!ctx) return -1; |
62 | | |
63 | | /* Use the existing TCP connect for the raw socket */ |
64 | 3 | int fd = net_tcp_connect(host, port, err); |
65 | 3 | if (fd < 0) return -1; |
66 | | |
67 | 3 | SSL *ssl = SSL_new(ctx); |
68 | 3 | if (!ssl) { |
69 | 0 | *err = ssl_error_string("tls_connect: SSL_new"); |
70 | 0 | net_tcp_close(fd); |
71 | 0 | return -1; |
72 | 0 | } |
73 | | |
74 | | /* SNI — required by most HTTPS servers */ |
75 | 3 | SSL_set_tlsext_host_name(ssl, host); |
76 | | |
77 | | /* Enable hostname verification */ |
78 | 3 | SSL_set1_host(ssl, host); |
79 | | |
80 | 3 | SSL_set_fd(ssl, fd); |
81 | | |
82 | 3 | if (SSL_connect(ssl) != 1) { |
83 | 0 | *err = ssl_error_string("tls_connect: SSL_connect"); |
84 | 0 | SSL_free(ssl); |
85 | 0 | net_tcp_close(fd); |
86 | 0 | return -1; |
87 | 0 | } |
88 | | |
89 | 3 | if (fd >= 0 && fd < FD_SETSIZE) |
90 | 3 | tls_sessions[fd] = ssl; |
91 | | |
92 | 3 | return fd; |
93 | 3 | } |
94 | | |
95 | | /* ── tls_read ── */ |
96 | | |
97 | 6 | #define TLS_READ_BUF 8192 |
98 | | |
99 | 6 | char *net_tls_read(int fd, char **err) { |
100 | 6 | if (fd < 0 || fd >= FD_SETSIZE || !tls_sessions[fd]) { |
101 | 3 | *err = strdup("tls_read: not a TLS socket"); |
102 | 3 | return NULL; |
103 | 3 | } |
104 | | |
105 | 3 | char *buf = malloc(TLS_READ_BUF + 1); |
106 | 3 | if (!buf) { *err = strdup("tls_read: out of memory"); return NULL; } |
107 | | |
108 | 3 | int n = SSL_read(tls_sessions[fd], buf, TLS_READ_BUF); |
109 | 3 | if (n < 0) { |
110 | 0 | free(buf); |
111 | 0 | *err = ssl_error_string("tls_read: SSL_read"); |
112 | 0 | return NULL; |
113 | 0 | } |
114 | | |
115 | 3 | buf[n] = '\0'; |
116 | 3 | return buf; |
117 | 3 | } |
118 | | |
119 | | /* ── tls_read_bytes ── */ |
120 | | |
121 | 0 | char *net_tls_read_bytes(int fd, size_t count, char **err) { |
122 | 0 | if (fd < 0 || fd >= FD_SETSIZE || !tls_sessions[fd]) { |
123 | 0 | *err = strdup("tls_read_bytes: not a TLS socket"); |
124 | 0 | return NULL; |
125 | 0 | } |
126 | | |
127 | 0 | char *buf = malloc(count + 1); |
128 | 0 | if (!buf) { *err = strdup("tls_read_bytes: out of memory"); return NULL; } |
129 | | |
130 | 0 | size_t total = 0; |
131 | 0 | while (total < count) { |
132 | 0 | int n = SSL_read(tls_sessions[fd], buf + total, (int)(count - total)); |
133 | 0 | if (n <= 0) break; /* EOF or error */ |
134 | 0 | total += (size_t)n; |
135 | 0 | } |
136 | |
|
137 | 0 | buf[total] = '\0'; |
138 | 0 | return buf; |
139 | 0 | } |
140 | | |
141 | | /* ── tls_write ── */ |
142 | | |
143 | 6 | bool net_tls_write(int fd, const char *data, size_t len, char **err) { |
144 | 6 | if (fd < 0 || fd >= FD_SETSIZE || !tls_sessions[fd]) { |
145 | 3 | *err = strdup("tls_write: not a TLS socket"); |
146 | 3 | return false; |
147 | 3 | } |
148 | | |
149 | 3 | size_t total = 0; |
150 | 6 | while (total < len) { |
151 | 3 | int n = SSL_write(tls_sessions[fd], data + total, (int)(len - total)); |
152 | 3 | if (n <= 0) { |
153 | 0 | *err = ssl_error_string("tls_write: SSL_write"); |
154 | 0 | return false; |
155 | 0 | } |
156 | 3 | total += (size_t)n; |
157 | 3 | } |
158 | 3 | return true; |
159 | 3 | } |
160 | | |
161 | | /* ── tls_close ── */ |
162 | | |
163 | 3 | void net_tls_close(int fd) { |
164 | 3 | if (fd >= 0 && fd < FD_SETSIZE && tls_sessions[fd]) { |
165 | 3 | SSL_shutdown(tls_sessions[fd]); |
166 | 3 | SSL_free(tls_sessions[fd]); |
167 | 3 | tls_sessions[fd] = NULL; |
168 | 3 | } |
169 | 3 | net_tcp_close(fd); |
170 | 3 | } |
171 | | |
172 | | /* ── tls_available ── */ |
173 | | |
/* Reports whether TLS support was compiled in; in this (OpenSSL) build it is. */
bool net_tls_available(void)
{
    const bool compiled_with_openssl = true;
    return compiled_with_openssl;
}
177 | | |
178 | | /* ── tls_cleanup ── */ |
179 | | |
180 | 865 | void net_tls_cleanup(void) { |
181 | 886k | for (int i = 0; i < FD_SETSIZE; i++) { |
182 | 885k | if (tls_sessions[i]) { |
183 | 0 | SSL_shutdown(tls_sessions[i]); |
184 | 0 | SSL_free(tls_sessions[i]); |
185 | 0 | tls_sessions[i] = NULL; |
186 | 0 | } |
187 | 885k | } |
188 | 865 | if (g_ssl_ctx) { |
189 | 1 | SSL_CTX_free(g_ssl_ctx); |
190 | | g_ssl_ctx = NULL; |
191 | 1 | } |
192 | 865 | } |
193 | | |
194 | | #else /* !LATTICE_HAS_TLS */ |
195 | | |
196 | | /* ── Stubs when built without OpenSSL ── */ |
197 | | |
198 | | #include <stdlib.h> |
199 | | #include <string.h> |
200 | | |
/* Fresh heap copy of the "no TLS in this build" message; the caller owns it
 * (NULL if strdup fails). */
static char *no_tls_err(void)
{
    static const char message[] = "TLS not available (built without OpenSSL)";
    return strdup(message);
}

/* Always fails with *err set: this build has no TLS client. */
int net_tls_connect(const char *host, int port, char **err)
{
    (void)port;
    (void)host;
    *err = no_tls_err();
    return -1;
}

/* Always fails with *err set and returns NULL. */
char *net_tls_read(int fd, char **err)
{
    (void)fd;
    *err = no_tls_err();
    return NULL;
}

/* Always fails with *err set and returns NULL. */
char *net_tls_read_bytes(int fd, size_t count, char **err)
{
    (void)count;
    (void)fd;
    *err = no_tls_err();
    return NULL;
}

/* Always fails with *err set; nothing is written. */
bool net_tls_write(int fd, const char *data, size_t len, char **err)
{
    (void)len;
    (void)data;
    (void)fd;
    *err = no_tls_err();
    return false;
}

/* No-op. NOTE(review): the OpenSSL build's net_tls_close also closes the TCP
 * socket; this stub does not. Confirm callers never pass plain sockets here. */
void net_tls_close(int fd)
{
    (void)fd;
}

/* TLS support was not compiled into this build. */
bool net_tls_available(void)
{
    return false;
}

void net_tls_cleanup(void)
{
    /* No TLS state exists in this build, so there is nothing to release. */
}
240 | | |
241 | | #endif /* LATTICE_HAS_TLS */ |