@@ -222,13 +222,13 @@ private void validateIdentifier(String identifier) {
222222 }
223223 }
224224
225- public void createTable (String tableName , String ... columnDefs ) {
225+ public void createTable (String tableName , List < String > columnDefs ) {
226226 createTable (tableName , columnDefs , null );
227227 }
228228
229- public void createTable (String tableName , String [] columnDefs , String [] constraints ) {
229+ public void createTable (String tableName , List < String > columnDefs , List < String > constraints ) {
230230 validateIdentifier (tableName );
231- if (columnDefs == null || columnDefs .length == 0 ) {
231+ if (columnDefs == null || columnDefs .isEmpty () ) {
232232 throw new IllegalArgumentException ("At least one column definition is required." );
233233 }
234234 for (String colDef : columnDefs ) {
@@ -288,6 +288,103 @@ public void execute(String sql) {
288288 }
289289 }
290290
291+ public int upsertData (String tableName , List <String > conflictColumns , Map <String , Object > insertData ) throws SQLException {
292+ validateIdentifier (tableName );
293+ if (conflictColumns == null || conflictColumns .isEmpty ()) {
294+ throw new IllegalArgumentException ("Conflict columns list cannot be null or empty." );
295+ }
296+ if (insertData == null || insertData .isEmpty ()) {
297+ throw new IllegalArgumentException ("Insert data map cannot be null or empty." );
298+ }
299+
300+ for (String col : conflictColumns ) {
301+ validateIdentifier (col );
302+ }
303+ List <String > insertColumns = new ArrayList <>();
304+ List <Object > insertValues = new ArrayList <>();
305+ for (Map .Entry <String , Object > entry : new LinkedHashMap <>(insertData ).entrySet ()) {
306+ validateIdentifier (entry .getKey ());
307+ insertColumns .add (quoteIdentifier (entry .getKey ()));
308+ insertValues .add (entry .getValue ());
309+ }
310+
311+ String sql ;
312+ List <String > updateAssignments = new ArrayList <>();
313+
314+ switch (databaseType ) {
315+ case SQLITE :
316+ for (String col : insertColumns ) {
317+ String unquotedCol = col ;
318+ if (col .startsWith ("\" " ) && col .endsWith ("\" " )) {
319+ unquotedCol = col .substring (1 , col .length () - 1 );
320+ }
321+ if (!conflictColumns .contains (unquotedCol )) {
322+ updateAssignments .add (col + " = excluded." + col );
323+ }
324+ }
325+ if (updateAssignments .isEmpty ()) {
326+ if (insertColumns .size () == conflictColumns .size ()) {
327+ throw new IllegalArgumentException ("Upsert requires at least one column to update that is not part of the conflict key." );
328+ } else {
329+ throw new IllegalStateException ("Failed to generate update assignments for ON CONFLICT clause." );
330+ }
331+ }
332+
333+ sql = String .format (
334+ "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO UPDATE SET %s" ,
335+ quoteIdentifier (tableName ), // table
336+ String .join (", " , insertColumns ), // (col1, col2)
337+ String .join (", " , Collections .nCopies (insertColumns .size (), "?" )), // (?, ?)
338+ String .join (", " , conflictColumns .stream ().map (this ::quoteIdentifier ).toArray (String []::new )), // conflict cols (quoted)
339+ String .join (", " , updateAssignments ) // col = excluded.col, ...
340+ );
341+ break ;
342+
343+ case MYSQL :
344+ for (String col : insertColumns ) {
345+ String unquotedCol = col ;
346+ if (col .startsWith ("`" ) && col .endsWith ("`" )) { // Basic unquoting if needed
347+ unquotedCol = col .substring (1 , col .length () - 1 );
348+ }
349+ if (!conflictColumns .contains (unquotedCol )) {
350+ updateAssignments .add (col + " = VALUES(" + col + ")" ); // Use quoted identifier
351+ }
352+ }
353+ if (updateAssignments .isEmpty ()) {
354+ if (insertColumns .size () == conflictColumns .size ()) {
355+ throw new IllegalArgumentException ("Upsert requires at least one column to update that is not part of the primary/unique key." );
356+ } else {
357+ throw new IllegalStateException ("Failed to generate update assignments for ON DUPLICATE KEY UPDATE clause." );
358+ }
359+ }
360+
361+ sql = String .format (
362+ "INSERT INTO %s (%s) VALUES (%s) ON DUPLICATE KEY UPDATE %s" ,
363+ quoteIdentifier (tableName ),
364+ String .join (", " , insertColumns ),
365+ String .join (", " , Collections .nCopies (insertColumns .size (), "?" )),
366+ String .join (", " , updateAssignments )
367+ );
368+ break ;
369+
370+ default :
371+ logger .error ("Upsert operation is not supported for database type: " + databaseType );
372+ throw new UnsupportedOperationException ("Upsert not supported for database type: " + databaseType );
373+ }
374+
375+ logger .debug ("Executing upsert: " + sql + " with " + insertValues .size () + " parameters." );
376+
377+ try (PreparedStatement pstmt = getConnection ().prepareStatement (sql )) {
378+ bindParameters (pstmt , insertValues );
379+ int affectedRows = pstmt .executeUpdate ();
380+ logger .debug ("Upsert successful, " + affectedRows + " row(s) affected." );
381+ return affectedRows ;
382+ } catch (SQLException e ) {
383+ logger .error ("Upsert failed for table '" + tableName + "': " + e .getMessage () + " [SQL: " + sql + "]" );
384+ throw e ;
385+ }
386+ }
387+
291388 public int insertData (String tableName , Object ... columnsAndValues ) {
292389 validateIdentifier (tableName );
293390 if (columnsAndValues .length == 0 ) {
@@ -337,7 +434,7 @@ public int insertDataIfEmpty(String tableName, Object... columnsAndValues) {
337434
338435 Map <String , Object > where = new HashMap <>();
339436 where .put (checkColumn , checkValue );
340- List <Map <String , Object >> existing = selectData (tableName , new String []{ checkColumn } , where );
437+ List <Map <String , Object >> existing = selectData (tableName , Collections . singletonList ( checkColumn ) , where );
341438
342439 if (existing .isEmpty ()) {
343440 logger .debug ("No existing record found for " + checkColumn + ". Inserting..." );
@@ -436,14 +533,14 @@ public int deleteData(String tableName, Map<String, Object> whereClause) {
436533 }
437534 }
438535
439- public List <Map <String , Object >> selectData (String tableName , String [] columnsToSelect , Map <String , Object > whereClause ) {
536+ public List <Map <String , Object >> selectData (String tableName , List < String > columnsToSelect , Map <String , Object > whereClause ) {
440537 validateIdentifier (tableName );
441- if (columnsToSelect == null || columnsToSelect .length == 0 ) {
538+ if (columnsToSelect == null || columnsToSelect .isEmpty () ) {
442539 throw new IllegalArgumentException ("Must specify at least one column to select (or '*')." );
443540 }
444541
445542 String selectColsString ;
446- if (columnsToSelect .length == 1 && "*" .equals (columnsToSelect [ 0 ] )) {
543+ if (columnsToSelect .size () == 1 && "*" .equals (columnsToSelect . get ( 0 ) )) {
447544 selectColsString = "*" ;
448545 } else {
449546 List <String > quotedCols = new ArrayList <>();
@@ -504,7 +601,7 @@ public List<Map<String, Object>> selectData(String tableName, String[] columnsTo
504601 return results ;
505602 }
506603
507- public List <Map <String , Object >> selectData (String tableName , String [] columnsToSelect , String whereColumn , Object whereValue ) {
604+ public List <Map <String , Object >> selectData (String tableName , List < String > columnsToSelect , String whereColumn , Object whereValue ) {
508605 Map <String , Object > whereClause = new LinkedHashMap <>();
509606 if (whereColumn != null ) {
510607 validateIdentifier (whereColumn );
@@ -516,7 +613,7 @@ public List<Map<String, Object>> selectData(String tableName, String[] columnsTo
516613 public <T > List <T > selectSingleColumn (String tableName , String columnToSelect , Map <String , Object > whereClause , Class <T > expectedType ) throws SQLException {
517614 validateIdentifier (columnToSelect );
518615
519- List <Map <String , Object >> rawResults = selectData (tableName , new String []{ columnToSelect } , whereClause );
616+ List <Map <String , Object >> rawResults = selectData (tableName , Collections . singletonList ( columnToSelect ) , whereClause );
520617 List <T > results = new ArrayList <>();
521618
522619 for (Map <String , Object > row : rawResults ) {
@@ -787,4 +884,4 @@ private static String padRight(String s, int n) {
787884 }
788885 return sb .toString ();
789886 }
790- }
887+ }
0 commit comments