diff --git a/main.go b/main.go index efa7faa..585dbe8 100644 --- a/main.go +++ b/main.go @@ -41,9 +41,8 @@ type MongoConfig struct { func getMongoConfig() MongoConfig { // Load environment variables from .env file err := godotenv.Load() - if err != nil { - log.Fatal("Error loading .env file") + log.Printf("Warning: Could not load .env file: %v", err) } return MongoConfig{ @@ -61,50 +60,88 @@ func getEnvOrDefault(key, defaultValue string) string { } func connectMongoDB(config MongoConfig) (*mongo.Client, *mongo.Collection, error) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + // Increase connection timeout significantly + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - client, err := mongo.Connect(ctx, options.Client().ApplyURI(config.URI)) + // Configure client options with more robust settings + clientOptions := options.Client(). + ApplyURI(config.URI). + SetMaxPoolSize(10). + SetMinPoolSize(1). + SetMaxConnIdleTime(30 * time.Second). + SetConnectTimeout(10 * time.Second). + SetSocketTimeout(30 * time.Second). + SetServerSelectionTimeout(10 * time.Second) + + client, err := mongo.Connect(ctx, clientOptions) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to connect to MongoDB: %v", err) } - // Test connection - err = client.Ping(ctx, nil) - if err != nil { - return nil, nil, err + // Test connection with retry logic + var pingErr error + for i := 0; i < 3; i++ { + pingCtx, pingCancel := context.WithTimeout(context.Background(), 5*time.Second) + pingErr = client.Ping(pingCtx, nil) + pingCancel() + + if pingErr == nil { + break + } + + if i < 2 { + log.Printf("Ping attempt %d failed, retrying: %v", i+1, pingErr) + time.Sleep(2 * time.Second) + } + } + + if pingErr != nil { + client.Disconnect(context.Background()) + return nil, nil, fmt.Errorf("failed to ping MongoDB after retries: %v", pingErr) } collection := client.Database(config.Database).Collection(config.Collection) - // Create indexes for better performance + // Create indexes with longer timeout and better error handling + err = createIndexes(collection) + if err != nil { + log.Printf("Warning: Could not create indexes: %v", err) + // Don't fail completely, just warn + } + + return client, collection, nil +} + +func createIndexes(collection *mongo.Collection) error { indexes := []mongo.IndexModel{ { Keys: bson.D{{Key: "ip", Value: 1}}, + Options: options.Index().SetBackground(true), }, { Keys: bson.D{{Key: "expires_at", Value: 1}}, + Options: options.Index().SetBackground(true), }, { Keys: bson.D{{Key: "status", Value: 1}}, + Options: options.Index().SetBackground(true), }, { Keys: bson.D{ {Key: "status", Value: 1}, {Key: "expires_at", Value: 1}, }, + Options: options.Index().SetBackground(true), }, } - ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) + // Use longer timeout for index creation + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() - _, err = collection.Indexes().CreateMany(ctx, indexes) - if err != nil { - log.Printf("Warning: Could not create indexes: %v", err) - } - - return client, collection, nil + _, err := collection.Indexes().CreateMany(ctx, indexes) + return err } func parseDuration(duration string) (time.Duration, error) { @@ -169,7 +206,8 @@ func addBanRecord(collection *mongo.Collection, ip, reason, duration string, jso Hostname: getHostname(), } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + // Increase timeout for insert operations + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() _, err = collection.InsertOne(ctx, record) @@ -177,7 +215,8 @@ func addBanRecord(collection *mongo.Collection, ip, reason, duration string, jso } func removeBanRecord(collection *mongo.Collection, ip string) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + // Increase timeout for update operations + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() now := time.Now() @@ -198,67 +237,114 @@ func removeBanRecord(collection *mongo.Collection, ip string) error { } func cleanupExpiredBans(collection *mongo.Collection) error { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + // Significantly increase timeout for cleanup operations + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() - // Find expired bans - filter := bson.M{ - "expires_at": bson.M{"$lte": time.Now()}, - "status": "active", - } + // Process in batches to avoid timeout + batchSize := 100 + processedCount := 0 - cursor, err := collection.Find(ctx, filter) - if err != nil { - return err - } - defer cursor.Close(ctx) - - var expiredBans []BanRecord - if err = cursor.All(ctx, &expiredBans); err != nil { - return err - } - - // Remove expired IPs from BitNinja and mark as expired in MongoDB - for _, ban := range expiredBans { - fmt.Printf("Removing expired ban for IP: %s (banned at: %s)\n", - ban.IP, ban.BannedAt.Format("2006-01-02 15:04:05")) - - // Remove from BitNinja - cmd := exec.Command("bitninjacli", "--blacklist", fmt.Sprintf("--del=%s", ban.IP)) - if err := cmd.Run(); err != nil { - fmt.Printf("Error removing IP %s from BitNinja: %v\n", ban.IP, err) - continue + for { + // Find expired bans in batches + filter := bson.M{ + "expires_at": bson.M{"$lte": time.Now()}, + "status": "active", } - // Mark as expired in MongoDB - updateFilter := bson.M{"_id": ban.ID} - update := bson.M{ - "$set": bson.M{ - "status": "expired", - }, - } - - _, err = collection.UpdateOne(ctx, updateFilter, update) + opts := options.Find().SetLimit(int64(batchSize)) + cursor, err := collection.Find(ctx, filter, opts) if err != nil { - fmt.Printf("Error updating MongoDB for IP %s: %v\n", ban.IP, err) + return fmt.Errorf("error finding expired bans: %v", err) + } + + var expiredBans []BanRecord + if err = cursor.All(ctx, &expiredBans); err != nil { + cursor.Close(ctx) + return fmt.Errorf("error decoding expired bans: %v", err) + } + cursor.Close(ctx) + + if len(expiredBans) == 0 { + break // No more expired bans + } + + // Process this batch + for _, ban := range expiredBans { + fmt.Printf("Removing expired ban for IP: %s (banned at: %s)\n", + ban.IP, ban.BannedAt.Format("2006-01-02 15:04:05")) + + // Remove from BitNinja with timeout + cmd := exec.Command("bitninjacli", "--blacklist", fmt.Sprintf("--del=%s", ban.IP)) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + // Set timeout for external command + cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cmdCancel() + + if err := cmd.Start(); err != nil { + fmt.Printf("Error starting command for IP %s: %v\n", ban.IP, err) + continue + } + + // Wait for command to complete or timeout + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + select { + case err := <-done: + if err != nil { + fmt.Printf("Error removing IP %s from BitNinja: %v\n", ban.IP, err) + continue + } + case <-cmdCtx.Done(): + fmt.Printf("Timeout removing IP %s from BitNinja\n", ban.IP) + if cmd.Process != nil { + cmd.Process.Kill() + } + continue + } + + // Mark as expired in MongoDB + updateFilter := bson.M{"_id": ban.ID} + update := bson.M{ + "$set": bson.M{ + "status": "expired", + }, + } + + updateCtx, updateCancel := context.WithTimeout(context.Background(), 10*time.Second) + _, err = collection.UpdateOne(updateCtx, updateFilter, update) + updateCancel() + + if err != nil { + fmt.Printf("Error updating MongoDB for IP %s: %v\n", ban.IP, err) + } + } + + processedCount += len(expiredBans) + + // Check if we processed less than batch size (last batch) + if len(expiredBans) < batchSize { + break } } - fmt.Printf("Processed %d expired bans\n", len(expiredBans)) + fmt.Printf("Processed %d expired bans\n", processedCount) return nil } -// Nieuwe functie om oude records te verwijderen func purgeOldRecords(collection *mongo.Collection, olderThanDays int) error { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + // Increase timeout for purge operations + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() - // Bereken de datum vanaf wanneer records verwijderd moeten worden + // Calculate cutoff date cutoffDate := time.Now().AddDate(0, 0, -olderThanDays) - // Filter voor records die verwijderd kunnen worden: - // 1. Status "removed" en removed_at ouder dan cutoffDate - // 2. Status "expired" en banned_at ouder dan cutoffDate (voor oude expired records) filter := bson.M{ "$or": []bson.M{ { @@ -272,7 +358,7 @@ func purgeOldRecords(collection *mongo.Collection, olderThanDays int) error { }, } - // Eerst tellen hoeveel records verwijderd gaan worden + // Count records to purge count, err := collection.CountDocuments(ctx, filter) if err != nil { return fmt.Errorf("error counting records to purge: %v", err) @@ -285,8 +371,15 @@ func purgeOldRecords(collection *mongo.Collection, olderThanDays int) error { fmt.Printf("Found %d records older than %d days to purge\n", count, olderThanDays) - // Optioneel: log welke records verwijderd gaan worden - cursor, err := collection.Find(ctx, filter) + // Process in batches if there are many records + if count > 1000 { + fmt.Println("Large number of records detected, processing in batches...") + return purgeInBatches(collection, filter, olderThanDays) + } + + // Log records that will be purged (limit to avoid timeout) + opts := options.Find().SetLimit(100) + cursor, err := collection.Find(ctx, filter, opts) if err != nil { return fmt.Errorf("error finding records to purge: %v", err) } @@ -297,9 +390,13 @@ func purgeOldRecords(collection *mongo.Collection, olderThanDays int) error { return fmt.Errorf("error decoding records to purge: %v", err) } - // Log de records die verwijderd gaan worden - fmt.Println("Records to be purged:") - for _, record := range recordsToPurge { + fmt.Println("Sample of records to be purged:") + for i, record := range recordsToPurge { + if i >= 10 { // Limit output + fmt.Printf("... and %d more records\n", len(recordsToPurge)-10) + break + } + var dateStr string if record.RemovedAt != nil { dateStr = record.RemovedAt.Format("2006-01-02 15:04:05") @@ -310,7 +407,7 @@ func purgeOldRecords(collection *mongo.Collection, olderThanDays int) error { record.IP, record.Status, dateStr, record.Reason) } - // Daadwerkelijk verwijderen + // Actually delete result, err := collection.DeleteMany(ctx, filter) if err != nil { return fmt.Errorf("error purging records: %v", err) @@ -320,13 +417,46 @@ func purgeOldRecords(collection *mongo.Collection, olderThanDays int) error { return nil } +func purgeInBatches(collection *mongo.Collection, filter bson.M, olderThanDays int) error { + batchSize := 1000 + totalDeleted := int64(0) + + for { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + + // Delete in batches + opts := options.Delete().SetHint(bson.D{{Key: "status", Value: 1}}) + result, err := collection.DeleteMany(ctx, filter, opts) + cancel() + + if err != nil { + return fmt.Errorf("error purging batch: %v", err) + } + + totalDeleted += result.DeletedCount + fmt.Printf("Purged batch: %d records\n", result.DeletedCount) + + // If we deleted less than batch size, we're done + if result.DeletedCount < int64(batchSize) { + break + } + + // Small delay between batches + time.Sleep(100 * time.Millisecond) + } + + fmt.Printf("Successfully purged %d total records from database\n", totalDeleted) + return nil +} + func handleAdd(collection *mongo.Collection, ip string, duration string, reason string, jsonObject map[string]interface{}) { // Add to BitNinja blacklist enhancedReason := fmt.Sprintf("%s (Duration: %s, Host: %s)", reason, duration, getHostname()) cmd := exec.Command("bitninjacli", "--blacklist", fmt.Sprintf("--add=%s", ip), fmt.Sprintf("--comment=%s", enhancedReason)) + out, err := cmd.Output() if err != nil { - fmt.Println("Error adding IP to BitNinja:", err) + fmt.Printf("Error adding IP to BitNinja: %v\n", err) return } fmt.Println(string(out)) @@ -349,9 +479,10 @@ func handleAdd(collection *mongo.Collection, ip string, duration string, reason func handleDel(collection *mongo.Collection, ip string, duration string, reason string, jsonObject map[string]interface{}) { // Remove from BitNinja cmd := exec.Command("bitninjacli", "--blacklist", fmt.Sprintf("--del=%s", ip)) + out, err := cmd.Output() if err != nil { - fmt.Println("Error deleting IP from BitNinja:", err) + fmt.Printf("Error deleting IP from BitNinja: %v\n", err) return } fmt.Println(string(out)) @@ -396,12 +527,22 @@ func handlePurge(collection *mongo.Collection, daysStr string) { } func handleList(collection *mongo.Collection) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + // Increase timeout for list operations + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - // Find active bans, sorted by banned_at descending + // First get the accurate total count filter := bson.M{"status": "active"} - opts := options.Find().SetSort(bson.D{{Key: "banned_at", Value: -1}}) + totalCount, err := collection.CountDocuments(ctx, filter) + if err != nil { + fmt.Printf("Error counting active bans: %v\n", err) + return + } + + // Find active bans, sorted by banned_at descending (limited for display) + opts := options.Find(). + SetSort(bson.D{{Key: "banned_at", Value: -1}}). + SetLimit(1000) // Limit results to avoid timeout cursor, err := collection.Find(ctx, filter, opts) if err != nil { @@ -440,11 +581,15 @@ func handleList(collection *mongo.Collection) { ban.Hostname) } - fmt.Printf("\nTotal active bans: %d\n", len(bans)) + fmt.Printf("\nTotal active bans: %d\n", totalCount) + if int64(len(bans)) < totalCount { + fmt.Printf("Showing latest %d entries (sorted by banned date)\n", len(bans)) + } } func handleStats(collection *mongo.Collection) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + // Increase timeout for stats operations + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() // Count by status @@ -535,9 +680,11 @@ func main() { log.Fatal("Failed to connect to MongoDB:", err) } defer func() { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - client.Disconnect(ctx) + if err := client.Disconnect(ctx); err != nil { + log.Printf("Error disconnecting from MongoDB: %v", err) + } }() if len(os.Args) < 2 {