diff --git a/script/import_scripts/smf2.rb b/script/import_scripts/smf2.rb index d2c41fd16..b33e81a3e 100644 --- a/script/import_scripts/smf2.rb +++ b/script/import_scripts/smf2.rb @@ -54,8 +54,7 @@ class ImportScripts::Smf2 < ImportScripts::Base options.password = HighLine.new.ask('') {|q| q.echo = false } end - @db = Mysql2::Client.new(host: options.host, username: options.username, - password: options.password, database: options.database) + @default_db_connection = create_db_connection end def execute @@ -284,15 +283,21 @@ class ImportScripts::Smf2 < ImportScripts::Base private - def query(sql, **opts, &block) - return __query(sql).to_a if opts[:as] == :array - return __query(sql, as: :array).first[0] if opts[:as] == :single - return __query(sql, stream: true).each(&block) if block_given? - return __query(sql, stream: true) + def create_db_connection + Mysql2::Client.new(host: options.host, username: options.username, + password: options.password, database: options.database) end - def __query(sql, **opts) - @db.query(sql.gsub('{prefix}', options.prefix), + def query(sql, **opts, &block) + db = opts[:connection] || @default_db_connection + return __query(db, sql).to_a if opts[:as] == :array + return __query(db, sql, as: :array).first[0] if opts[:as] == :single + return __query(db, sql, stream: true).each(&block) if block_given? + return __query(db, sql, stream: true) + end + + def __query(db, sql, **opts) + db.query(sql.gsub('{prefix}', options.prefix), {symbolize_keys: true, cache_rows: false}.merge(opts)) end