Skip to content

Commit 08f40a5

Browse files
committed
Change the OAuth hook to be per connection
The per connection hook works only in async mode, but sync API in ruby-pg is documented as for testing only and they have several flaws already. So I don't think there is any need to support them with the OAuth hook. Therefore on sync API the only supported option is non-hooked OAuth. This moves common code to helpers and hooked OAuth to async specs. Connecting with a OAuth hook within a Ractor is currently also not possible, but might be changed in future.
1 parent dc92807 commit 08f40a5

File tree

5 files changed

+224
-120
lines changed

5 files changed

+224
-120
lines changed

lib/pg.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def self.version_string( include_buildnum=nil )
8484

8585

8686
### Convenience alias for PG::Connection.new.
87-
def self.connect( *args, &block )
88-
Connection.new( *args, &block )
87+
def self.connect( *args, **kwargs, &block )
88+
Connection.new( *args, **kwargs, &block )
8989
end
9090

9191
if defined?(Ractor.make_shareable)

lib/pg/connection.rb

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -867,8 +867,8 @@ class << self
867867
# It's still possible to do load balancing with +load_balance_hosts+ set to +random+ and to increase the number of connections a node gets, when the hostname is provided multiple times in the host string.
868868
# This is because in non-timeout cases the host is tried multiple times.
869869
#
870-
def new(*args)
871-
conn = connect_to_hosts(*args)
870+
def new(*args, **kwargs)
871+
conn = connect_to_hosts(*args, **kwargs)
872872

873873
if block_given?
874874
begin
@@ -919,8 +919,8 @@ def new(*args)
919919
port: dests.map{|d| d[2] }.join(","))
920920
end
921921

922-
private def connect_to_hosts(*args)
923-
option_string = parse_connect_args(*args)
922+
private def connect_to_hosts(*args, set_auth_data_hook: nil, **kwargs)
923+
option_string = parse_connect_args(*args, **kwargs)
924924
iopts = PG::Connection.conninfo_parse(option_string).each_with_object({}){|h, o| o[h[:keyword].to_sym] = h[:val] if h[:val] }
925925
iopts = PG::Connection.conndefaults.each_with_object({}){|h, o| o[h[:keyword].to_sym] = h[:val] if h[:val] }.merge(iopts)
926926

@@ -943,6 +943,12 @@ def new(*args)
943943
end
944944
conn = self.connect_start(iopts) or
945945
raise(PG::Error, "Unable to create a new connection")
946+
if set_auth_data_hook
947+
@@auth_mutex.synchronize do
948+
@@pgconn_map[conn.send(:pgconn_address)] = conn
949+
end
950+
conn.instance_variable_set(:@auth_data_hook, set_auth_data_hook)
951+
end
946952

947953
raise PG::ConnectionBad, conn.error_message if conn.status == PG::CONNECTION_BAD
948954

@@ -1090,5 +1096,23 @@ def async_api=(enable)
10901096
end
10911097
end
10921098

1099+
if PG.respond_to?(:set_auth_data_hook)
1100+
def call_auth_data_hook(data)
1101+
@auth_data_hook.call(self, data)
1102+
end
1103+
1104+
@@auth_mutex = Mutex.new
1105+
@@pgconn_map = ObjectSpace::WeakMap.new
1106+
1107+
PG.set_auth_data_hook do |conn_num, data|
1108+
@@auth_mutex.synchronize do
1109+
conn = @@pgconn_map[conn_num]
1110+
if conn
1111+
conn.call_auth_data_hook(data)
1112+
end
1113+
end
1114+
end
1115+
end
1116+
10931117
self.async_api = true
10941118
end # class PG::Connection

spec/helpers.rb

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,38 @@ def with_env_vars(**kwargs)
673673
def set_etc_hosts(hostaddr, hostname)
674674
system "sudo --non-interactive sed -i '/.* #{hostname}$/{h;s/.*/#{hostaddr} #{hostname}/};${x;/^$/{s//#{hostaddr} #{hostname}/;H};x}' /etc/hosts" or skip("unable to change /etc/hosts file")
675675
end
676+
677+
def build_oauth_validator
678+
skip "requires a PostgreSQL 18 cluster" unless $pg_server.version >= 18
679+
680+
system "make", "-s", "-C", (TEST_DIRECTORY + "spec/oauth").to_s
681+
raise "Building OAuth validator library failed!" unless $?.success?
682+
683+
require 'webrick'
684+
685+
PG.connect(@conninfo) do |conn|
686+
conn.exec("DROP USER IF EXISTS testuseroauth")
687+
conn.exec("CREATE USER testuseroauth")
688+
end
689+
end
690+
691+
def start_fake_oauth(port)
692+
server = WEBrick::HTTPServer.new(Port: port, Logger: WEBrick::Log.new(nil, WEBrick::BasicLog::WARN))
693+
server.mount_proc("/.well-known/openid-configuration") do |req, res|
694+
res["Content-Type"] = "application/json"
695+
res.body = %!{"issuer":"http://localhost:#{port}","token_endpoint":"http://localhost:#{port}/token","device_authorization_endpoint":"http://localhost:#{@port + 3}/devauth"}!
696+
end
697+
server.mount_proc("/devauth") do |req, res|
698+
res["Content-Type"] = "application/json"
699+
res.body = %!{"device_code":"42","user_code":"666","verification_uri":"http://localhost:#{port}/verify","expires_in":60}!
700+
end
701+
server.mount_proc("/token") do |req, res|
702+
res["Content-Type"] = "application/json"
703+
res.body = %!{"access_token":"yes","token_type":""}!
704+
end
705+
Thread.new { server.start }
706+
server
707+
end
676708
end
677709

678710
RSpec.configure do |config|

spec/pg/connection_async_spec.rb

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,163 @@ def interrupt_thread(exc=nil)
156156
expect( conn.hostaddr ).to eq( "::1" )
157157
expect( conn.port ).to eq( @port )
158158
end
159+
160+
describe "option set_auth_data_hook", :postgresql_18 do
161+
before :all do
162+
build_oauth_validator
163+
end
164+
165+
before :each do
166+
@old_env, ENV["PGOAUTHDEBUG"] = ENV["PGOAUTHDEBUG"], "UNSAFE"
167+
end
168+
169+
it "should call prompt oauth device hook" do
170+
oauth_server = start_fake_oauth(@port + 3)
171+
172+
verification_uri, user_code, verification_uri_complete, expires_in = nil, nil, nil, nil
173+
conn1, conn2 = nil, nil
174+
175+
hook = proc do |conn, data|
176+
case data
177+
when PG::PromptOAuthDevice
178+
conn1 = conn
179+
verification_uri = data.verification_uri
180+
user_code = data.user_code
181+
verification_uri_complete = data.verification_uri_complete
182+
expires_in = data.expires_in
183+
true
184+
end
185+
end
186+
187+
begin
188+
PG.connect("host=localhost port=#{@port} dbname=test user=testuseroauth oauth_issuer=http://localhost:#{@port + 3} oauth_client_id=foo", set_auth_data_hook: hook) do |conn|
189+
conn.exec("SELECT 1")
190+
conn2 = conn
191+
end
192+
rescue PG::ConnectionBad => e
193+
if e.message =~ /no OAuth flows are available/
194+
skip "requires libpq-oauth to be installed"
195+
end
196+
raise
197+
ensure
198+
oauth_server.shutdown
199+
end
200+
201+
expect(conn1).to eq(conn2)
202+
expect(verification_uri).to eq("http://localhost:#{@port + 3}/verify")
203+
expect(user_code).to eq("666")
204+
expect(verification_uri_complete).to eq(nil)
205+
expect(expires_in).to eq(60)
206+
end
207+
208+
it "should call oauth bearer request hook" do
209+
openid_configuration, scope = nil, nil
210+
conn1, conn2 = nil, nil
211+
212+
hook = proc do |conn, data|
213+
case data
214+
when PG::OAuthBearerRequest
215+
conn1 = conn
216+
openid_configuration = data.openid_configuration
217+
scope = data.scope
218+
data.token = "yes"
219+
true
220+
end
221+
end
222+
223+
PG.connect(host: "localhost", port: @port, dbname: "test", user: "testuseroauth", oauth_issuer: "http://localhost:#{@port + 3}", oauth_client_id: "foo", set_auth_data_hook: hook) do |conn|
224+
conn.exec("SELECT 1")
225+
conn2 = conn
226+
end
227+
228+
expect(conn1).to eq(conn2)
229+
expect(openid_configuration).to eq("http://localhost:#{@port + 3}/.well-known/openid-configuration")
230+
expect(scope).to eq("test")
231+
end
232+
233+
it "shouldn't garbage collect PG::Connection in use" do
234+
conn1 = nil
235+
hook = proc do |conn, data|
236+
case data
237+
when PG::OAuthBearerRequest
238+
data.token = "yes"
239+
conn1 = conn
240+
true
241+
end
242+
end
243+
244+
GC.stress = true
245+
begin
246+
conn = PG.connect(host: "localhost", port: @port, dbname: "test", user: "testuseroauth", oauth_issuer: "http://localhost:#{@port + 3}", oauth_client_id: "foo", set_auth_data_hook: hook)
247+
ensure
248+
GC.stress = false
249+
end
250+
conn.exec("SELECT 1")
251+
252+
expect(conn1).to eq(conn)
253+
end
254+
255+
it "should garbage collect PG::Connection after use" do
256+
hook = proc do |conn, data|
257+
case data
258+
when PG::OAuthBearerRequest
259+
conn1 = conn
260+
openid_configuration = data.openid_configuration
261+
scope = data.scope
262+
data.token = "yes"
263+
true
264+
end
265+
end
266+
267+
10.times do
268+
conn = PG.connect(host: "localhost", port: @port, dbname: "test", user: "testuseroauth", oauth_issuer: "http://localhost:#{@port + 3}", oauth_client_id: "foo", set_auth_data_hook: hook)
269+
conn.exec("SELECT 1")
270+
end
271+
272+
before = PG::Connection.class_variable_get(:@@pgconn_map).keys.size
273+
GC.start
274+
after = PG::Connection.class_variable_get(:@@pgconn_map).keys.size
275+
276+
expect(before - after).to be_between(2, 20)
277+
end
278+
279+
# TODO: Is resetting the global hook still useful, when the hook is per connection?
280+
# it "should reset the hook when called without block" do
281+
# oauth_server = start_fake_oauth(@port + 3)
282+
#
283+
# PG.set_auth_data_hook do |conn_num, data|
284+
# raise "broken hook"
285+
# end
286+
#
287+
# expect do
288+
# PG.connect("host=localhost port=#{@port} dbname=test user=testuseroauth oauth_issuer=http://localhost:#{@port + 3} oauth_client_id=foo") {}
289+
# end.to raise_error("broken hook")
290+
#
291+
# PG.set_auth_data_hook
292+
#
293+
# begin
294+
# PG.connect("host=localhost port=#{@port} dbname=test user=testuseroauth oauth_issuer=http://localhost:#{@port + 3} oauth_client_id=foo") do |conn|
295+
# conn.exec("SELECT 1")
296+
# end
297+
# rescue PG::ConnectionBad => e
298+
# if e.message =~ /no OAuth flows are available/
299+
# skip "requires libpq-oauth to be installed"
300+
# end
301+
# raise
302+
# ensure
303+
# oauth_server.shutdown
304+
# end
305+
# end
306+
307+
# around :example do |ex|
308+
# GC.stress = true
309+
# ex.run
310+
# GC.stress = false
311+
# end
312+
313+
after :each do
314+
# PG.set_auth_data_hook
315+
ENV["PGOAUTHDEBUG"] = @old_env
316+
end
317+
end
159318
end

0 commit comments

Comments
 (0)