Skip to content

Commit 74544d8

Browse files
committed
Extract SQLExecDirect, SQLExecute, SQLPrepare implementation
Co-Authored-By: alinalibq <[email protected]>
1 parent 6180908 commit 74544d8

File tree

5 files changed

+156
-7
lines changed

5 files changed

+156
-7
lines changed

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,22 +1005,49 @@ SQLRETURN SQLExecDirect(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER text_len
10051005
ARROW_LOG(DEBUG) << "SQLExecDirectW called with stmt: " << stmt
10061006
<< ", query_text: " << static_cast<const void*>(query_text)
10071007
<< ", text_length: " << text_length;
1008-
// GH-47711 TODO: Implement SQLExecDirect
1009-
return SQL_INVALID_HANDLE;
1008+
1009+
using ODBC::ODBCStatement;
1010+
// The driver is built to handle SELECT statements only.
1011+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
1012+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
1013+
std::string query = ODBC::SqlWcharToString(query_text, text_length);
1014+
1015+
statement->Prepare(query);
1016+
statement->ExecutePrepared();
1017+
1018+
return SQL_SUCCESS;
1019+
});
10101020
}
10111021

10121022
SQLRETURN SQLPrepare(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER text_length) {
10131023
ARROW_LOG(DEBUG) << "SQLPrepareW called with stmt: " << stmt
10141024
<< ", query_text: " << static_cast<const void*>(query_text)
10151025
<< ", text_length: " << text_length;
1016-
// GH-47712 TODO: Implement SQLPrepare
1017-
return SQL_INVALID_HANDLE;
1026+
1027+
using ODBC::ODBCStatement;
1028+
// The driver is built to handle SELECT statements only.
1029+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
1030+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
1031+
std::string query = ODBC::SqlWcharToString(query_text, text_length);
1032+
1033+
statement->Prepare(query);
1034+
1035+
return SQL_SUCCESS;
1036+
});
10181037
}
10191038

10201039
SQLRETURN SQLExecute(SQLHSTMT stmt) {
10211040
ARROW_LOG(DEBUG) << "SQLExecute called with stmt: " << stmt;
1022-
// GH-47712 TODO: Implement SQLExecute
1023-
return SQL_INVALID_HANDLE;
1041+
1042+
using ODBC::ODBCStatement;
1043+
// The driver is built to handle SELECT statements only.
1044+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
1045+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
1046+
1047+
statement->ExecutePrepared();
1048+
1049+
return SQL_SUCCESS;
1050+
});
10241051
}
10251052

10261053
SQLRETURN SQLFetch(SQLHSTMT stmt) {

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ FlightSqlStatement::FlightSqlStatement(const Diagnostics& diagnostics,
6969
call_options_.timeout = TimeoutDuration{-1};
7070
}
7171

72+
FlightSqlStatement::~FlightSqlStatement() {
73+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
74+
}
75+
7276
bool FlightSqlStatement::SetAttribute(StatementAttributeId attribute,
7377
const Attribute& value) {
7478
switch (attribute) {
@@ -119,7 +123,6 @@ bool FlightSqlStatement::ExecutePrepared() {
119123

120124
Result<std::shared_ptr<FlightInfo>> result =
121125
prepared_statement_->Execute(call_options_);
122-
123126
ThrowIfNotOK(result.status());
124127

125128
current_result_set_ = std::make_shared<FlightSqlResultSet>(

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class FlightSqlStatement : public Statement {
4949
FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client,
5050
FlightClientOptions client_options, FlightCallOptions call_options,
5151
const MetadataSettings& metadata_settings);
52+
~FlightSqlStatement();
5253

5354
bool SetAttribute(StatementAttributeId attribute, const Attribute& value) override;
5455

cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ add_arrow_test(flight_sql_odbc_test
3737
connection_attr_test.cc
3838
connection_test.cc
3939
statement_attr_test.cc
40+
statement_test.cc
4041
# Enable Protobuf cleanup after test execution
4142
# GH-46889: move protobuf_test_util to a more common location
4243
../../../../engine/substrait/protobuf_test_util.cc
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h"
18+
19+
#include "arrow/flight/sql/odbc/odbc_impl/platform.h"
20+
21+
#include <sql.h>
22+
#include <sqltypes.h>
23+
#include <sqlucode.h>
24+
25+
#include <limits>
26+
27+
#include <gmock/gmock.h>
28+
#include <gtest/gtest.h>
29+
30+
namespace arrow::flight::sql::odbc {
31+
32+
template <typename T>
33+
class StatementTest : public T {};
34+
35+
class StatementMockTest : public FlightSQLODBCMockTestBase {};
36+
class StatementRemoteTest : public FlightSQLODBCRemoteTestBase {};
37+
using TestTypes = ::testing::Types<StatementMockTest, StatementRemoteTest>;
38+
TYPED_TEST_SUITE(StatementTest, TestTypes);
39+
40+
TYPED_TEST(StatementTest, TestSQLExecDirectSimpleQuery) {
41+
std::wstring wsql = L"SELECT 1;";
42+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
43+
44+
ASSERT_EQ(SQL_SUCCESS,
45+
SQLExecDirect(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
46+
47+
// GH-47713 TODO: Uncomment call to SQLFetch SQLGetData after implementation
48+
/*
49+
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
50+
51+
SQLINTEGER val;
52+
53+
ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
54+
// Verify 1 is returned
55+
EXPECT_EQ(1, val);
56+
57+
ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt));
58+
59+
ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
60+
// Invalid cursor state
61+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000);
62+
*/
63+
}
64+
65+
TYPED_TEST(StatementTest, TestSQLExecDirectInvalidQuery) {
66+
std::wstring wsql = L"SELECT;";
67+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
68+
69+
ASSERT_EQ(SQL_ERROR,
70+
SQLExecDirect(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
71+
// ODBC provides generic error code HY000 to all statement errors
72+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000);
73+
}
74+
75+
TYPED_TEST(StatementTest, TestSQLExecuteSimpleQuery) {
76+
std::wstring wsql = L"SELECT 1;";
77+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
78+
79+
ASSERT_EQ(SQL_SUCCESS,
80+
SQLPrepare(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
81+
82+
ASSERT_EQ(SQL_SUCCESS, SQLExecute(this->stmt));
83+
84+
// GH-47713 TODO: Uncomment call to SQLFetch SQLGetData after implementation
85+
/*
86+
// Fetch data
87+
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
88+
89+
SQLINTEGER val;
90+
ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
91+
92+
// Verify 1 is returned
93+
EXPECT_EQ(1, val);
94+
95+
ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt));
96+
97+
ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
98+
// Invalid cursor state
99+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000);
100+
*/
101+
}
102+
103+
TYPED_TEST(StatementTest, TestSQLPrepareInvalidQuery) {
104+
std::wstring wsql = L"SELECT;";
105+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
106+
107+
ASSERT_EQ(SQL_ERROR,
108+
SQLPrepare(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
109+
// ODBC provides generic error code HY000 to all statement errors
110+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000);
111+
112+
ASSERT_EQ(SQL_ERROR, SQLExecute(this->stmt));
113+
// Verify function sequence error state is returned
114+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010);
115+
}
116+
117+
} // namespace arrow::flight::sql::odbc

0 commit comments

Comments
 (0)