Skip to content

Commit e2024e7

Browse files
committed
Change set_auth_data_hook to be implemented in C
This doesn't expose the global function PQsetAuthDataHook to ruby, but only the one per connection. The conversion of "PGconn *" -> "PG::Connection object" is no longer done per WeakMap, but per st_table. That should make it easier to get Ractor compatible.
1 parent 08f40a5 commit e2024e7

6 files changed

Lines changed: 120 additions & 105 deletions

File tree

ext/pg.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ typedef struct {
113113
VALUE encoder_for_put_copy_data;
114114
/* Kind of PG::Coder object for casting COPY rows to ruby values */
115115
VALUE decoder_for_get_copy_data;
116+
#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE
117+
/* Callback for retrieval of OAuth token */
118+
VALUE auth_data_hook;
119+
#endif
116120
/* Ruby encoding index of the client/internal encoding */
117121
int enc_idx : PG_ENC_IDX_BITS;
118122
/* flags controlling Symbol/String field names */

ext/pg_auth_hooks.c

Lines changed: 20 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -8,49 +8,10 @@
88

99
#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE
1010

11-
#ifdef TRUFFLERUBY
12-
static VALUE auth_data_hook;
13-
#else
1411
/*
15-
* On Ruby verisons which support Ractors we store the global callback once
16-
* per Ractor.
12+
* We store the pgconn pointers in a register to retrieve the PG::Connection VALUE in the oauth hook.
1713
*/
18-
#include "ruby/ractor.h"
19-
static rb_ractor_local_key_t auth_data_hook_key;
20-
#endif
21-
22-
static void
23-
auth_data_hook_init(void)
24-
{
25-
#ifdef TRUFFLERUBY
26-
auth_data_hook = Qnil;
27-
rb_gc_register_address(&auth_data_hook);
28-
#else
29-
auth_data_hook_key = rb_ractor_local_storage_value_newkey();
30-
#endif
31-
}
32-
33-
static VALUE
34-
auth_data_hook_get(void)
35-
{
36-
#ifdef TRUFFLERUBY
37-
return auth_data_hook;
38-
#else
39-
VALUE hook = Qnil;
40-
rb_ractor_local_storage_value_lookup(auth_data_hook_key, &hook);
41-
return hook;
42-
#endif
43-
}
44-
45-
static void
46-
auth_data_hook_set(VALUE hook)
47-
{
48-
#ifdef TRUFFLERUBY
49-
auth_data_hook = hook;
50-
#else
51-
rb_ractor_local_storage_value_set(auth_data_hook_key, hook);
52-
#endif
53-
}
14+
struct st_table *pgconn2value;
5415

5516
static VALUE rb_cPromptOAuthDevice;
5617
static VALUE rb_cOAuthBearerRequest;
@@ -298,16 +259,20 @@ oauth_bearer_request_hook_cleanup(VALUE self, VALUE ex)
298259
* currently-registered Ruby auth_data_hook object.
299260
*/
300261
int
301-
auth_data_hook_proxy(PGauthData type, PGconn *conn, void *data)
262+
auth_data_hook_proxy(PGauthData type, PGconn *pgconn, void *data)
302263
{
303-
VALUE proc = auth_data_hook_get(), ret = Qnil;
264+
VALUE rb_conn = Qnil;
265+
VALUE ret = Qnil;
266+
267+
if ( st_lookup(pgconn2value, (st_data_t)pgconn, (st_data_t*)&rb_conn) ) {
268+
t_pg_connection *this = pg_get_connection( rb_conn );
269+
VALUE proc = this->auth_data_hook;
304270

305-
if (proc != Qnil) {
306271
if (type == PQAUTHDATA_PROMPT_OAUTH_DEVICE) {
307272
t_pg_prompt_oauth_device *prompt;
308273

309274
VALUE v_prompt = TypedData_Make_Struct(rb_cPromptOAuthDevice, t_pg_prompt_oauth_device, &pg_prompt_oauth_device_type, prompt);
310-
VALUE args[] = { proc, PTR2NUM(conn), v_prompt };
275+
VALUE args[] = { proc, rb_conn, v_prompt };
311276

312277
prompt->prompt = data;
313278

@@ -318,7 +283,7 @@ auth_data_hook_proxy(PGauthData type, PGconn *conn, void *data)
318283
t_pg_oauth_bearer_request *request;
319284

320285
VALUE v_request = TypedData_Make_Struct(rb_cOAuthBearerRequest, t_pg_oauth_bearer_request, &pg_oauth_bearer_request_type, request);
321-
VALUE args[] = { proc, PTR2NUM(conn), v_request };
286+
VALUE args[] = { proc, rb_conn, v_request };
322287

323288
request->request = data;
324289
request->request->cleanup = oauth_bearer_request_cleanup;
@@ -329,44 +294,30 @@ auth_data_hook_proxy(PGauthData type, PGconn *conn, void *data)
329294
}
330295
}
331296

297+
/* TODO: a hook can return 1, 0 or -1 */
332298
return RTEST(ret);
333299
}
334300

335301
/*
336-
* Document-method: PG.set_auth_data_hook
337-
*
338302
* call-seq:
339-
* PG.set_auth_data_hook {|data| ... } -> Proc
340-
*
341-
* If you pass no arguments, it will reset the handler to the default.
303+
* PG.pgconn2value_size -> Integer
342304
*/
343305
static VALUE
344-
pg_s_set_auth_data_hook(VALUE _self)
306+
pg_oauth_pgconn2value_size_get(VALUE self)
345307
{
346-
PQsetAuthDataHook(gvl_auth_data_hook_proxy); // TODO: Add some safeguards?
347-
348-
VALUE old_proc = auth_data_hook_get(), proc;
349-
350-
if (rb_block_given_p()) {
351-
proc = rb_block_proc();
352-
} else {
353-
/* if no block is given, set back to default */
354-
proc = Qnil;
355-
}
356-
357-
auth_data_hook_set(proc);
358-
359-
return old_proc;
308+
return SIZET2NUM(rb_st_table_size(pgconn2value));
360309
}
361310

311+
362312
void
363313
init_pg_auth_hooks(void)
364314
{
365-
auth_data_hook_init();
315+
pgconn2value = st_init_numtable();
366316

367-
/* rb_mPG = rb_define_module("PG") */
317+
PQsetAuthDataHook(gvl_auth_data_hook_proxy); // TODO: Add some safeguards?
368318

369-
rb_define_singleton_method(rb_mPG, "set_auth_data_hook", pg_s_set_auth_data_hook, 0);
319+
/* rb_mPG = rb_define_module("PG") */
320+
rb_define_private_method(rb_singleton_class(rb_mPG), "pgconn2value_size", pg_oauth_pgconn2value_size_get, 0);
370321

371322
rb_cPromptOAuthDevice = rb_define_class_under(rb_mPG, "PromptOAuthDevice", rb_cObject);
372323
rb_undef_alloc_func(rb_cPromptOAuthDevice);

ext/pg_connection.c

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ static VALUE pgconn_wait_for_flush( VALUE self );
2222
static void pgconn_set_internal_encoding_index( VALUE );
2323
static const rb_data_type_t pg_connection_type;
2424
static VALUE pgconn_async_flush(VALUE self);
25+
#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE
26+
extern struct st_table *pgconn2value;
27+
#endif
2528

2629
/*
2730
* Global functions
@@ -179,11 +182,17 @@ pgconn_gc_mark( void *_this )
179182
rb_gc_mark_movable( this->trace_stream );
180183
rb_gc_mark_movable( this->encoder_for_put_copy_data );
181184
rb_gc_mark_movable( this->decoder_for_get_copy_data );
185+
#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE
186+
rb_gc_mark_movable( this->auth_data_hook );
187+
#endif
182188
}
183189

184190
static void
185191
pgconn_gc_compact( void *_this )
186192
{
193+
#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE
194+
VALUE old_rb_conn;
195+
#endif
187196
t_pg_connection *this = (t_pg_connection *)_this;
188197
pg_gc_location( this->socket_io );
189198
pg_gc_location( this->notice_receiver );
@@ -193,6 +202,15 @@ pgconn_gc_compact( void *_this )
193202
pg_gc_location( this->trace_stream );
194203
pg_gc_location( this->encoder_for_put_copy_data );
195204
pg_gc_location( this->decoder_for_get_copy_data );
205+
206+
#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE
207+
pg_gc_location( this->auth_data_hook );
208+
/* update the PG::Connection object which is maybe stored in pgconn2value */
209+
if ( st_lookup(pgconn2value, (st_data_t)this->pgconn, (st_data_t*)&old_rb_conn) ) {
210+
VALUE new_rb_conn = rb_gc_location(old_rb_conn);
211+
st_insert( pgconn2value, (st_data_t)this->pgconn, (st_data_t)new_rb_conn );
212+
}
213+
#endif
196214
}
197215

198216

@@ -210,9 +228,15 @@ pgconn_gc_free( void *_this )
210228
}
211229
}
212230
#endif
213-
if (this->pgconn != NULL)
231+
if (this->pgconn != NULL) {
214232
PQfinish( this->pgconn );
215233

234+
#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE
235+
/* Remove from auth hook callback table */
236+
st_delete(pgconn2value, (st_data_t *)&this->pgconn, NULL );
237+
#endif
238+
}
239+
216240
xfree(this);
217241
}
218242

@@ -264,6 +288,9 @@ pgconn_s_allocate( VALUE klass )
264288
RB_OBJ_WRITE(self, &this->type_map_for_results, pg_typemap_all_strings);
265289
RB_OBJ_WRITE(self, &this->encoder_for_put_copy_data, Qnil);
266290
RB_OBJ_WRITE(self, &this->decoder_for_get_copy_data, Qnil);
291+
#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE
292+
RB_OBJ_WRITE(self, &this->auth_data_hook, Qnil);
293+
#endif
267294
RB_OBJ_WRITE(self, &this->trace_stream, Qnil);
268295
rb_ivar_set(self, rb_intern("@calls_to_put_copy_data"), INT2FIX(0));
269296
rb_ivar_set(self, rb_intern("@iopts_for_reset"), Qnil);
@@ -623,22 +650,45 @@ pgconn_reset_poll(VALUE self)
623650
return INT2FIX((int)status);
624651
}
625652

653+
#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE
654+
626655
/*
627656
* call-seq:
628-
* conn.pgconn_address()
629-
*
630-
* Returns the PGconn address.
657+
* conn.auth_data_hook(&block)
631658
*
632-
* It can be used to compare with the address received by the OAuth hook:
659+
* Set a auth data hook.
660+
*/
661+
static VALUE
662+
pgconn_auth_data_hook_set(VALUE self, VALUE proc)
663+
{
664+
t_pg_connection *this = pg_get_connection( self );
665+
666+
if (rb_obj_is_proc(proc)) {
667+
/* set proc */
668+
st_insert( pgconn2value, (st_data_t)this->pgconn, (st_data_t)self );
669+
} else if (NIL_P(proc)) {
670+
/* if nil is given, set back to default */
671+
st_delete( pgconn2value, (st_data_t *)&this->pgconn, NULL );
672+
} else {
673+
rb_raise(rb_eArgError, "Proc object expected");
674+
}
675+
RB_OBJ_WRITE(self, &this->auth_data_hook, proc);
676+
return proc;
677+
}
678+
679+
/*
680+
* call-seq:
681+
* conn.auth_data_hook()
633682
*
634-
* PG.set_auth_data_hook do |pgconn_address, data|
683+
* Returns the defined auth data hook.
635684
*/
636685
static VALUE
637-
pgconn_pgconn_address(VALUE self)
686+
pgconn_auth_data_hook_get(VALUE self)
638687
{
639-
return PTR2NUM(pg_get_pgconn(self));
688+
return pg_get_connection(self)->auth_data_hook;
640689
}
641690

691+
#endif
642692

643693
/*
644694
* call-seq:
@@ -4766,7 +4816,10 @@ init_pg_connection(void)
47664816
rb_define_private_method(rb_cPGconn, "reset_start2", pgconn_reset_start2, 1);
47674817
rb_define_method(rb_cPGconn, "reset_poll", pgconn_reset_poll, 0);
47684818
rb_define_alias(rb_cPGconn, "close", "finish");
4769-
rb_define_private_method(rb_cPGconn, "pgconn_address", pgconn_pgconn_address, 0);
4819+
#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE
4820+
rb_define_method(rb_cPGconn, "auth_data_hook=", pgconn_auth_data_hook_set, 1);
4821+
rb_define_method(rb_cPGconn, "auth_data_hook", pgconn_auth_data_hook_get, 0);
4822+
#endif
47704823

47714824
/****** PG::Connection INSTANCE METHODS: Connection Status ******/
47724825
rb_define_method(rb_cPGconn, "db", pgconn_db, 0);

lib/pg/connection.rb

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -943,11 +943,11 @@ def new(*args, **kwargs)
943943
end
944944
conn = self.connect_start(iopts) or
945945
raise(PG::Error, "Unable to create a new connection")
946-
if set_auth_data_hook
947-
@@auth_mutex.synchronize do
948-
@@pgconn_map[conn.send(:pgconn_address)] = conn
949-
end
950-
conn.instance_variable_set(:@auth_data_hook, set_auth_data_hook)
946+
947+
if conn.respond_to?(:auth_data_hook)
948+
conn.auth_data_hook = set_auth_data_hook
949+
elsif set_auth_data_hook
950+
raise ArgumentError, "invalid option set_auth_data_hook"
951951
end
952952

953953
raise PG::ConnectionBad, conn.error_message if conn.status == PG::CONNECTION_BAD
@@ -1096,23 +1096,5 @@ def async_api=(enable)
10961096
end
10971097
end
10981098

1099-
if PG.respond_to?(:set_auth_data_hook)
1100-
def call_auth_data_hook(data)
1101-
@auth_data_hook.call(self, data)
1102-
end
1103-
1104-
@@auth_mutex = Mutex.new
1105-
@@pgconn_map = ObjectSpace::WeakMap.new
1106-
1107-
PG.set_auth_data_hook do |conn_num, data|
1108-
@@auth_mutex.synchronize do
1109-
conn = @@pgconn_map[conn_num]
1110-
if conn
1111-
conn.call_auth_data_hook(data)
1112-
end
1113-
end
1114-
end
1115-
end
1116-
11171099
self.async_api = true
11181100
end # class PG::Connection

spec/pg/connection_async_spec.rb

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,16 +264,17 @@ def interrupt_thread(exc=nil)
264264
end
265265
end
266266

267-
10.times do
267+
before = PG.send(:pgconn2value_size)
268+
20.times do
268269
conn = PG.connect(host: "localhost", port: @port, dbname: "test", user: "testuseroauth", oauth_issuer: "http://localhost:#{@port + 3}", oauth_client_id: "foo", set_auth_data_hook: hook)
269270
conn.exec("SELECT 1")
270271
end
271272

272-
before = PG::Connection.class_variable_get(:@@pgconn_map).keys.size
273273
GC.start
274-
after = PG::Connection.class_variable_get(:@@pgconn_map).keys.size
274+
after = PG.send(:pgconn2value_size)
275275

276-
expect(before - after).to be_between(2, 20)
276+
# Number of GC'ed objects
277+
expect(before + 20 - after).to be_between(1, 50)
277278
end
278279

279280
# TODO: Is resetting the global hook still useful, when the hook is per connection?

spec/pg/gc_compact_spec.rb

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
require_relative '../helpers'
2525

2626
describe "GC.compact", if: GC.respond_to?(:compact) do
27+
hook_called = false
28+
2729
before :all do
2830
TM1 = Class.new(PG::TypeMapByClass) do
2931
def conv_array(value)
@@ -57,6 +59,20 @@ def conv_array(value)
5759
CANCON.socket_io
5860
end
5961

62+
if PG::Connection.instance_methods.include?(:auth_data_hook)
63+
build_oauth_validator
64+
@old_env, ENV["PGOAUTHDEBUG"] = ENV["PGOAUTHDEBUG"], "UNSAFE"
65+
HOOKED_CONN = PG::Connection.connect_start( "host=localhost port=#{@port} dbname=test user=testuseroauth oauth_issuer=http://localhost:#{@port + 3} oauth_client_id=foo" )
66+
HOOKED_CONN.auth_data_hook = proc do |conn, data|
67+
case data
68+
when PG::OAuthBearerRequest
69+
data.token = "yes"
70+
hook_called = true
71+
true
72+
end
73+
end
74+
end
75+
6076
begin
6177
# Use GC.verify_compaction_references instead of GC.compact .
6278
# This has the advantage that all movable objects are actually moved.
@@ -111,6 +127,14 @@ def conv_array(value)
111127
expect( CANCON.socket_io ).to be_kind_of( IO )
112128
end
113129

130+
it "should compact PG::Connection in pgconn2value", :postgresql_18 do
131+
wait_for_polling_ok(HOOKED_CONN)
132+
expect( HOOKED_CONN.error_message ).to eq("")
133+
HOOKED_CONN.finish
134+
expect( hook_called ).to be_truthy
135+
ENV["PGOAUTHDEBUG"] = @old_env
136+
end
137+
114138
after :all do
115139
CONN2.close
116140
end

0 commit comments

Comments
 (0)