Updated timeout handling and overal performance

This commit is contained in:
2025-06-28 10:34:20 +02:00
parent 894e202542
commit 93bc79410e

305
main.go
View File

@ -41,9 +41,8 @@ type MongoConfig struct {
func getMongoConfig() MongoConfig { func getMongoConfig() MongoConfig {
// Load environment variables from .env file // Load environment variables from .env file
err := godotenv.Load() err := godotenv.Load()
if err != nil { if err != nil {
log.Fatal("Error loading .env file") log.Printf("Warning: Could not load .env file: %v", err)
} }
return MongoConfig{ return MongoConfig{
@ -61,50 +60,88 @@ func getEnvOrDefault(key, defaultValue string) string {
} }
func connectMongoDB(config MongoConfig) (*mongo.Client, *mongo.Collection, error) { func connectMongoDB(config MongoConfig) (*mongo.Client, *mongo.Collection, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) // Increase connection timeout significantly
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
client, err := mongo.Connect(ctx, options.Client().ApplyURI(config.URI)) // Configure client options with more robust settings
clientOptions := options.Client().
ApplyURI(config.URI).
SetMaxPoolSize(10).
SetMinPoolSize(1).
SetMaxConnIdleTime(30 * time.Second).
SetConnectTimeout(10 * time.Second).
SetSocketTimeout(30 * time.Second).
SetServerSelectionTimeout(10 * time.Second)
client, err := mongo.Connect(ctx, clientOptions)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, fmt.Errorf("failed to connect to MongoDB: %v", err)
} }
// Test connection // Test connection with retry logic
err = client.Ping(ctx, nil) var pingErr error
if err != nil { for i := 0; i < 3; i++ {
return nil, nil, err pingCtx, pingCancel := context.WithTimeout(context.Background(), 5*time.Second)
pingErr = client.Ping(pingCtx, nil)
pingCancel()
if pingErr == nil {
break
}
if i < 2 {
log.Printf("Ping attempt %d failed, retrying: %v", i+1, pingErr)
time.Sleep(2 * time.Second)
}
}
if pingErr != nil {
client.Disconnect(context.Background())
return nil, nil, fmt.Errorf("failed to ping MongoDB after retries: %v", pingErr)
} }
collection := client.Database(config.Database).Collection(config.Collection) collection := client.Database(config.Database).Collection(config.Collection)
// Create indexes for better performance // Create indexes with longer timeout and better error handling
err = createIndexes(collection)
if err != nil {
log.Printf("Warning: Could not create indexes: %v", err)
// Don't fail completely, just warn
}
return client, collection, nil
}
func createIndexes(collection *mongo.Collection) error {
indexes := []mongo.IndexModel{ indexes := []mongo.IndexModel{
{ {
Keys: bson.D{{Key: "ip", Value: 1}}, Keys: bson.D{{Key: "ip", Value: 1}},
Options: options.Index().SetBackground(true),
}, },
{ {
Keys: bson.D{{Key: "expires_at", Value: 1}}, Keys: bson.D{{Key: "expires_at", Value: 1}},
Options: options.Index().SetBackground(true),
}, },
{ {
Keys: bson.D{{Key: "status", Value: 1}}, Keys: bson.D{{Key: "status", Value: 1}},
Options: options.Index().SetBackground(true),
}, },
{ {
Keys: bson.D{ Keys: bson.D{
{Key: "status", Value: 1}, {Key: "status", Value: 1},
{Key: "expires_at", Value: 1}, {Key: "expires_at", Value: 1},
}, },
Options: options.Index().SetBackground(true),
}, },
} }
ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) // Use longer timeout for index creation
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel() defer cancel()
_, err = collection.Indexes().CreateMany(ctx, indexes) _, err := collection.Indexes().CreateMany(ctx, indexes)
if err != nil { return err
log.Printf("Warning: Could not create indexes: %v", err)
}
return client, collection, nil
} }
func parseDuration(duration string) (time.Duration, error) { func parseDuration(duration string) (time.Duration, error) {
@ -169,7 +206,8 @@ func addBanRecord(collection *mongo.Collection, ip, reason, duration string, jso
Hostname: getHostname(), Hostname: getHostname(),
} }
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // Increase timeout for insert operations
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
_, err = collection.InsertOne(ctx, record) _, err = collection.InsertOne(ctx, record)
@ -177,7 +215,8 @@ func addBanRecord(collection *mongo.Collection, ip, reason, duration string, jso
} }
func removeBanRecord(collection *mongo.Collection, ip string) error { func removeBanRecord(collection *mongo.Collection, ip string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // Increase timeout for update operations
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
now := time.Now() now := time.Now()
@ -198,67 +237,114 @@ func removeBanRecord(collection *mongo.Collection, ip string) error {
} }
func cleanupExpiredBans(collection *mongo.Collection) error { func cleanupExpiredBans(collection *mongo.Collection) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) // Significantly increase timeout for cleanup operations
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel() defer cancel()
// Find expired bans // Process in batches to avoid timeout
filter := bson.M{ batchSize := 100
"expires_at": bson.M{"$lte": time.Now()}, processedCount := 0
"status": "active",
}
cursor, err := collection.Find(ctx, filter) for {
if err != nil { // Find expired bans in batches
return err filter := bson.M{
} "expires_at": bson.M{"$lte": time.Now()},
defer cursor.Close(ctx) "status": "active",
var expiredBans []BanRecord
if err = cursor.All(ctx, &expiredBans); err != nil {
return err
}
// Remove expired IPs from BitNinja and mark as expired in MongoDB
for _, ban := range expiredBans {
fmt.Printf("Removing expired ban for IP: %s (banned at: %s)\n",
ban.IP, ban.BannedAt.Format("2006-01-02 15:04:05"))
// Remove from BitNinja
cmd := exec.Command("bitninjacli", "--blacklist", fmt.Sprintf("--del=%s", ban.IP))
if err := cmd.Run(); err != nil {
fmt.Printf("Error removing IP %s from BitNinja: %v\n", ban.IP, err)
continue
} }
// Mark as expired in MongoDB opts := options.Find().SetLimit(int64(batchSize))
updateFilter := bson.M{"_id": ban.ID} cursor, err := collection.Find(ctx, filter, opts)
update := bson.M{
"$set": bson.M{
"status": "expired",
},
}
_, err = collection.UpdateOne(ctx, updateFilter, update)
if err != nil { if err != nil {
fmt.Printf("Error updating MongoDB for IP %s: %v\n", ban.IP, err) return fmt.Errorf("error finding expired bans: %v", err)
}
var expiredBans []BanRecord
if err = cursor.All(ctx, &expiredBans); err != nil {
cursor.Close(ctx)
return fmt.Errorf("error decoding expired bans: %v", err)
}
cursor.Close(ctx)
if len(expiredBans) == 0 {
break // No more expired bans
}
// Process this batch
for _, ban := range expiredBans {
fmt.Printf("Removing expired ban for IP: %s (banned at: %s)\n",
ban.IP, ban.BannedAt.Format("2006-01-02 15:04:05"))
// Remove from BitNinja with timeout
cmd := exec.Command("bitninjacli", "--blacklist", fmt.Sprintf("--del=%s", ban.IP))
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
// Set timeout for external command
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cmdCancel()
if err := cmd.Start(); err != nil {
fmt.Printf("Error starting command for IP %s: %v\n", ban.IP, err)
continue
}
// Wait for command to complete or timeout
done := make(chan error, 1)
go func() {
done <- cmd.Wait()
}()
select {
case err := <-done:
if err != nil {
fmt.Printf("Error removing IP %s from BitNinja: %v\n", ban.IP, err)
continue
}
case <-cmdCtx.Done():
fmt.Printf("Timeout removing IP %s from BitNinja\n", ban.IP)
if cmd.Process != nil {
cmd.Process.Kill()
}
continue
}
// Mark as expired in MongoDB
updateFilter := bson.M{"_id": ban.ID}
update := bson.M{
"$set": bson.M{
"status": "expired",
},
}
updateCtx, updateCancel := context.WithTimeout(context.Background(), 10*time.Second)
_, err = collection.UpdateOne(updateCtx, updateFilter, update)
updateCancel()
if err != nil {
fmt.Printf("Error updating MongoDB for IP %s: %v\n", ban.IP, err)
}
}
processedCount += len(expiredBans)
// Check if we processed less than batch size (last batch)
if len(expiredBans) < batchSize {
break
} }
} }
fmt.Printf("Processed %d expired bans\n", len(expiredBans)) fmt.Printf("Processed %d expired bans\n", processedCount)
return nil return nil
} }
// Nieuwe functie om oude records te verwijderen
func purgeOldRecords(collection *mongo.Collection, olderThanDays int) error { func purgeOldRecords(collection *mongo.Collection, olderThanDays int) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) // Increase timeout for purge operations
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel() defer cancel()
// Bereken de datum vanaf wanneer records verwijderd moeten worden // Calculate cutoff date
cutoffDate := time.Now().AddDate(0, 0, -olderThanDays) cutoffDate := time.Now().AddDate(0, 0, -olderThanDays)
// Filter voor records die verwijderd kunnen worden:
// 1. Status "removed" en removed_at ouder dan cutoffDate
// 2. Status "expired" en banned_at ouder dan cutoffDate (voor oude expired records)
filter := bson.M{ filter := bson.M{
"$or": []bson.M{ "$or": []bson.M{
{ {
@ -272,7 +358,7 @@ func purgeOldRecords(collection *mongo.Collection, olderThanDays int) error {
}, },
} }
// Eerst tellen hoeveel records verwijderd gaan worden // Count records to purge
count, err := collection.CountDocuments(ctx, filter) count, err := collection.CountDocuments(ctx, filter)
if err != nil { if err != nil {
return fmt.Errorf("error counting records to purge: %v", err) return fmt.Errorf("error counting records to purge: %v", err)
@ -285,8 +371,15 @@ func purgeOldRecords(collection *mongo.Collection, olderThanDays int) error {
fmt.Printf("Found %d records older than %d days to purge\n", count, olderThanDays) fmt.Printf("Found %d records older than %d days to purge\n", count, olderThanDays)
// Optioneel: log welke records verwijderd gaan worden // Process in batches if there are many records
cursor, err := collection.Find(ctx, filter) if count > 1000 {
fmt.Println("Large number of records detected, processing in batches...")
return purgeInBatches(collection, filter, olderThanDays)
}
// Log records that will be purged (limit to avoid timeout)
opts := options.Find().SetLimit(100)
cursor, err := collection.Find(ctx, filter, opts)
if err != nil { if err != nil {
return fmt.Errorf("error finding records to purge: %v", err) return fmt.Errorf("error finding records to purge: %v", err)
} }
@ -297,9 +390,13 @@ func purgeOldRecords(collection *mongo.Collection, olderThanDays int) error {
return fmt.Errorf("error decoding records to purge: %v", err) return fmt.Errorf("error decoding records to purge: %v", err)
} }
// Log de records die verwijderd gaan worden fmt.Println("Sample of records to be purged:")
fmt.Println("Records to be purged:") for i, record := range recordsToPurge {
for _, record := range recordsToPurge { if i >= 10 { // Limit output
fmt.Printf("... and %d more records\n", len(recordsToPurge)-10)
break
}
var dateStr string var dateStr string
if record.RemovedAt != nil { if record.RemovedAt != nil {
dateStr = record.RemovedAt.Format("2006-01-02 15:04:05") dateStr = record.RemovedAt.Format("2006-01-02 15:04:05")
@ -310,7 +407,7 @@ func purgeOldRecords(collection *mongo.Collection, olderThanDays int) error {
record.IP, record.Status, dateStr, record.Reason) record.IP, record.Status, dateStr, record.Reason)
} }
// Daadwerkelijk verwijderen // Actually delete
result, err := collection.DeleteMany(ctx, filter) result, err := collection.DeleteMany(ctx, filter)
if err != nil { if err != nil {
return fmt.Errorf("error purging records: %v", err) return fmt.Errorf("error purging records: %v", err)
@ -320,13 +417,46 @@ func purgeOldRecords(collection *mongo.Collection, olderThanDays int) error {
return nil return nil
} }
func purgeInBatches(collection *mongo.Collection, filter bson.M, olderThanDays int) error {
batchSize := 1000
totalDeleted := int64(0)
for {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// Delete in batches
opts := options.Delete().SetHint(bson.D{{Key: "status", Value: 1}})
result, err := collection.DeleteMany(ctx, filter, opts)
cancel()
if err != nil {
return fmt.Errorf("error purging batch: %v", err)
}
totalDeleted += result.DeletedCount
fmt.Printf("Purged batch: %d records\n", result.DeletedCount)
// If we deleted less than batch size, we're done
if result.DeletedCount < int64(batchSize) {
break
}
// Small delay between batches
time.Sleep(100 * time.Millisecond)
}
fmt.Printf("Successfully purged %d total records from database\n", totalDeleted)
return nil
}
func handleAdd(collection *mongo.Collection, ip string, duration string, reason string, jsonObject map[string]interface{}) { func handleAdd(collection *mongo.Collection, ip string, duration string, reason string, jsonObject map[string]interface{}) {
// Add to BitNinja blacklist // Add to BitNinja blacklist
enhancedReason := fmt.Sprintf("%s (Duration: %s, Host: %s)", reason, duration, getHostname()) enhancedReason := fmt.Sprintf("%s (Duration: %s, Host: %s)", reason, duration, getHostname())
cmd := exec.Command("bitninjacli", "--blacklist", fmt.Sprintf("--add=%s", ip), fmt.Sprintf("--comment=%s", enhancedReason)) cmd := exec.Command("bitninjacli", "--blacklist", fmt.Sprintf("--add=%s", ip), fmt.Sprintf("--comment=%s", enhancedReason))
out, err := cmd.Output() out, err := cmd.Output()
if err != nil { if err != nil {
fmt.Println("Error adding IP to BitNinja:", err) fmt.Printf("Error adding IP to BitNinja: %v\n", err)
return return
} }
fmt.Println(string(out)) fmt.Println(string(out))
@ -349,9 +479,10 @@ func handleAdd(collection *mongo.Collection, ip string, duration string, reason
func handleDel(collection *mongo.Collection, ip string, duration string, reason string, jsonObject map[string]interface{}) { func handleDel(collection *mongo.Collection, ip string, duration string, reason string, jsonObject map[string]interface{}) {
// Remove from BitNinja // Remove from BitNinja
cmd := exec.Command("bitninjacli", "--blacklist", fmt.Sprintf("--del=%s", ip)) cmd := exec.Command("bitninjacli", "--blacklist", fmt.Sprintf("--del=%s", ip))
out, err := cmd.Output() out, err := cmd.Output()
if err != nil { if err != nil {
fmt.Println("Error deleting IP from BitNinja:", err) fmt.Printf("Error deleting IP from BitNinja: %v\n", err)
return return
} }
fmt.Println(string(out)) fmt.Println(string(out))
@ -396,12 +527,22 @@ func handlePurge(collection *mongo.Collection, daysStr string) {
} }
func handleList(collection *mongo.Collection) { func handleList(collection *mongo.Collection) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) // Increase timeout for list operations
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
// Find active bans, sorted by banned_at descending // First get the accurate total count
filter := bson.M{"status": "active"} filter := bson.M{"status": "active"}
opts := options.Find().SetSort(bson.D{{Key: "banned_at", Value: -1}}) totalCount, err := collection.CountDocuments(ctx, filter)
if err != nil {
fmt.Printf("Error counting active bans: %v\n", err)
return
}
// Find active bans, sorted by banned_at descending (limited for display)
opts := options.Find().
SetSort(bson.D{{Key: "banned_at", Value: -1}}).
SetLimit(1000) // Limit results to avoid timeout
cursor, err := collection.Find(ctx, filter, opts) cursor, err := collection.Find(ctx, filter, opts)
if err != nil { if err != nil {
@ -440,11 +581,15 @@ func handleList(collection *mongo.Collection) {
ban.Hostname) ban.Hostname)
} }
fmt.Printf("\nTotal active bans: %d\n", len(bans)) fmt.Printf("\nTotal active bans: %d\n", totalCount)
if int64(len(bans)) < totalCount {
fmt.Printf("Showing latest %d entries (sorted by banned date)\n", len(bans))
}
} }
func handleStats(collection *mongo.Collection) { func handleStats(collection *mongo.Collection) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) // Increase timeout for stats operations
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
// Count by status // Count by status
@ -535,9 +680,11 @@ func main() {
log.Fatal("Failed to connect to MongoDB:", err) log.Fatal("Failed to connect to MongoDB:", err)
} }
defer func() { defer func() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
client.Disconnect(ctx) if err := client.Disconnect(ctx); err != nil {
log.Printf("Error disconnecting from MongoDB: %v", err)
}
}() }()
if len(os.Args) < 2 { if len(os.Args) < 2 {