From 4842469e4bf65dfcb2b6cd15872197bbef8c3582 Mon Sep 17 00:00:00 2001
From: Ajinkya
Date: Sat, 23 Nov 2024 00:45:13 +0530
Subject: [PATCH] added client-side authentication support while saving
 malicious and smart events

---
 .../java/com/akto/auth/grpc/AuthToken.java    |  35 +++
 .../java/com/akto/filters/HttpCallFilter.java |  55 ++--
 .../akto/suspect_data/FlushMessagesTask.java  |  26 +-
 .../java/com/akto/suspect_data/Message.java   |  60 +++--
 .../java/com/akto/traffic/KafkaRunner.java    | 238 +++++++++---------
 5 files changed, 252 insertions(+), 162 deletions(-)
 create mode 100644 apps/api-threat-detection/src/main/java/com/akto/auth/grpc/AuthToken.java

diff --git a/apps/api-threat-detection/src/main/java/com/akto/auth/grpc/AuthToken.java b/apps/api-threat-detection/src/main/java/com/akto/auth/grpc/AuthToken.java
new file mode 100644
index 0000000000..163a626a44
--- /dev/null
+++ b/apps/api-threat-detection/src/main/java/com/akto/auth/grpc/AuthToken.java
@@ -0,0 +1,35 @@
+package com.akto.auth.grpc;
+
+import io.grpc.CallCredentials;
+import io.grpc.Metadata;
+import io.grpc.Status;
+
+import java.util.concurrent.Executor;
+
+import static io.grpc.Metadata.ASCII_STRING_MARSHALLER;
+
+public class AuthToken extends CallCredentials {
+
+    private final String token;
+    public static final Metadata.Key<String> AUTHORIZATION_METADATA_KEY =
+            Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER);
+
+    public AuthToken(String token) {
+        this.token = token;
+    }
+
+    @Override
+    public void applyRequestMetadata(
+            RequestInfo requestInfo, Executor appExecutor, MetadataApplier applier) {
+        appExecutor.execute(
+                () -> {
+                    try {
+                        Metadata headers = new Metadata();
+                        headers.put(AUTHORIZATION_METADATA_KEY, token);
+                        applier.apply(headers);
+                    } catch (Throwable e) {
+                        applier.fail(Status.UNAUTHENTICATED.withCause(e));
+                    }
+                });
+    }
+}
diff --git a/apps/api-threat-detection/src/main/java/com/akto/filters/HttpCallFilter.java b/apps/api-threat-detection/src/main/java/com/akto/filters/HttpCallFilter.java
index ac8a2c36f0..f33792b13c 100644
--- a/apps/api-threat-detection/src/main/java/com/akto/filters/HttpCallFilter.java
+++ b/apps/api-threat-detection/src/main/java/com/akto/filters/HttpCallFilter.java
@@ -2,6 +2,7 @@ import java.util.*;
 
+import com.akto.auth.grpc.AuthToken;
 import com.akto.cache.RedisBackedCounterCache;
 import com.akto.dao.context.Context;
 import com.akto.dao.monitoring.FilterYamlTemplateDao;
@@ -22,11 +23,8 @@ import com.akto.kafka.Kafka;
 import com.akto.log.LoggerMaker;
 import com.akto.log.LoggerMaker.LogDb;
+import com.akto.proto.threat_protection.consumer_service.v1.*;
 import com.akto.proto.threat_protection.consumer_service.v1.ConsumerServiceGrpc.ConsumerServiceBlockingStub;
-import com.akto.proto.threat_protection.consumer_service.v1.ConsumerServiceGrpc;
-import com.akto.proto.threat_protection.consumer_service.v1.MaliciousEvent;
-import com.akto.proto.threat_protection.consumer_service.v1.SaveSmartEventRequest;
-import com.akto.proto.threat_protection.consumer_service.v1.SmartEvent;
 import com.akto.rules.TestPlugin;
 import com.akto.suspect_data.Message;
 import com.akto.test_editor.execution.VariableResolver;
@@ -37,6 +35,7 @@ import io.grpc.Grpc;
 import io.grpc.InsecureChannelCredentials;
 import io.grpc.ManagedChannel;
+import io.grpc.stub.StreamObserver;
 import io.lettuce.core.RedisClient;
 
 public class HttpCallFilter {
@@ -58,7 +57,7 @@ public class HttpCallFilter {
     private final WindowBasedThresholdNotifier windowBasedThresholdNotifier;
 
-    private final ConsumerServiceBlockingStub consumerServiceBlockingStub;
+    private final ConsumerServiceGrpc.ConsumerServiceStub consumerServiceStub;
 
     public HttpCallFilter(
             RedisClient redisClient, int sync_threshold_count, int sync_threshold_time) {
@@ -76,7 +75,10 @@ public HttpCallFilter(
         String target = "localhost:8980";
         ManagedChannel channel =
                 Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()).build();
-        this.consumerServiceBlockingStub = ConsumerServiceGrpc.newBlockingStub(channel);
+        this.consumerServiceStub =
+                ConsumerServiceGrpc.newStub(channel)
+                        .withCallCredentials(
+                                new AuthToken(System.getenv("AKTO_THREAT_PROTECTION_BACKEND_TOKEN")));
     }
 
     public void filterFunction(List<HttpResponseParams> responseParams) {
@@ -110,6 +112,7 @@ public void filterFunction(List<HttpResponseParams> responseParams) {
                 List<Rule> rules = new ArrayList<>();
                 rules.add(new Rule("Lfi Rule 1", new Condition(100, 10)));
                 AggregationRules aggRules = new AggregationRules();
+                aggRules.setRule(rules);
 
                 SourceIPKeyGenerator.instance
                         .generate(responseParam)
@@ -131,7 +134,13 @@ public void filterFunction(List<HttpResponseParams> responseParams) {
                                     .setTimestamp(responseParam.getTime())
                                     .build();
 
-                            maliciousSamples.add(new Message(responseParam.getAccountId(), maliciousEvent));
+                            try {
+                                String data = JsonFormat.printer().print(maliciousEvent);
+                                maliciousSamples.add(new Message(responseParam.getAccountId(), data));
+                            } catch (InvalidProtocolBufferException e) {
+                                e.printStackTrace();
+                                return;
+                            }
 
                             for (Rule rule : aggRules.getRule()) {
                                 WindowBasedThresholdNotifier.Result result =
@@ -146,8 +155,24 @@ public void filterFunction(List<HttpResponseParams> responseParams) {
                                             .setDetectedAt(responseParam.getTime())
                                             .setRuleId(rule.getName())
                                             .build();
-                                    this.consumerServiceBlockingStub.saveSmartEvent(
-                                            SaveSmartEventRequest.newBuilder().setEvent(smartEvent).build());
+                                    this.consumerServiceStub.saveSmartEvent(
+                                            SaveSmartEventRequest.newBuilder().setEvent(smartEvent).build(),
+                                            new StreamObserver<SaveSmartEventResponse>() {
+                                                @Override
+                                                public void onNext(SaveSmartEventResponse value) {
+                                                    // Do nothing
+                                                }
+
+                                                @Override
+                                                public void onError(Throwable t) {
+                                                    t.printStackTrace();
+                                                }
+
+                                                @Override
+                                                public void onCompleted() {
+                                                    // Do nothing
+                                                }
+                                            });
                                 }
                             }
                         });
@@ -160,12 +185,12 @@ public void filterFunction(List<HttpResponseParams> responseParams) {
         try {
             maliciousSamples.forEach(
                     sample -> {
-                        try {
-                            String data = JsonFormat.printer().print(null);
-                            kafka.send(data, KAFKA_MALICIOUS_TOPIC);
-                        } catch (InvalidProtocolBufferException e) {
-                            e.printStackTrace();
-                        }
+                        sample
+                                .marshal()
+                                .ifPresent(
+                                        data -> {
+                                            kafka.send(data, KAFKA_MALICIOUS_TOPIC);
+                                        });
                     });
         } catch (Exception e) {
             e.printStackTrace();
diff --git a/apps/api-threat-detection/src/main/java/com/akto/suspect_data/FlushMessagesTask.java b/apps/api-threat-detection/src/main/java/com/akto/suspect_data/FlushMessagesTask.java
index fdec77aef3..257ff8af86 100644
--- a/apps/api-threat-detection/src/main/java/com/akto/suspect_data/FlushMessagesTask.java
+++ b/apps/api-threat-detection/src/main/java/com/akto/suspect_data/FlushMessagesTask.java
@@ -1,21 +1,18 @@
 package com.akto.suspect_data;
 
 import java.time.Duration;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Properties;
+import java.util.*;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 
+import com.akto.auth.grpc.AuthToken;
 import com.akto.proto.threat_protection.consumer_service.v1.ConsumerServiceGrpc;
 import com.akto.proto.threat_protection.consumer_service.v1.ConsumerServiceGrpc.ConsumerServiceStub;
 import com.akto.proto.threat_protection.consumer_service.v1.MaliciousEvent;
 import com.akto.proto.threat_protection.consumer_service.v1.SaveMaliciousEventRequest;
 import com.akto.proto.threat_protection.consumer_service.v1.SaveMaliciousEventResponse;
 
+import io.grpc.ChannelCredentials;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
@@ -45,11 +42,13 @@ private FlushMessagesTask() {
         this.consumer = new KafkaConsumer<>(properties);
 
         String target = System.getenv("AKTO_THREAT_PROTECTION_BACKEND_URL");
-        // TODO: Secure this connection
         ManagedChannel channel =
                 Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()).build();
-        this.asyncStub = ConsumerServiceGrpc.newStub(channel);
+        this.asyncStub =
+                ConsumerServiceGrpc.newStub(channel)
+                        .withCallCredentials(
+                                new AuthToken(System.getenv("AKTO_THREAT_PROTECTION_BACKEND_TOKEN")));
     }
 
     public static FlushMessagesTask instance = new FlushMessagesTask();
@@ -66,7 +65,6 @@ public void run() {
                         processRecords(records);
                     } catch (Exception e) {
                         e.printStackTrace();
-                        consumer.close();
                     }
                 }
             }
@@ -78,9 +76,14 @@ public void processRecords(ConsumerRecords<String, String> records) {
         for (ConsumerRecord<String, String> record : records) {
             try {
                 MaliciousEvent.Builder builder = MaliciousEvent.newBuilder();
-                JsonFormat.parser().merge(record.value(), builder);
+                Message m = Message.unmarshal(record.value()).orElse(null);
+                if (m == null) {
+                    continue;
+                }
+
+                JsonFormat.parser().merge(m.getData(), builder);
                 MaliciousEvent event = builder.build();
-                accWiseMessages.computeIfAbsent(record.key(), k -> new ArrayList<>()).add(event);
+                accWiseMessages.computeIfAbsent(m.getAccountId(), k -> new ArrayList<>()).add(event);
             } catch (InvalidProtocolBufferException e) {
                 e.printStackTrace();
             }
@@ -89,7 +92,6 @@ public void processRecords(ConsumerRecords<String, String> records) {
         for (Map.Entry<String, List<MaliciousEvent>> entry : accWiseMessages.entrySet()) {
             int accountId = Integer.parseInt(entry.getKey());
             List<MaliciousEvent> events = entry.getValue();
-            Context.accountId.set(accountId);
 
             this.asyncStub.saveMaliciousEvent(
                     SaveMaliciousEventRequest.newBuilder().addAllEvents(events).build(),
diff --git a/apps/api-threat-detection/src/main/java/com/akto/suspect_data/Message.java b/apps/api-threat-detection/src/main/java/com/akto/suspect_data/Message.java
index 334590a39f..37904ecebe 100644
--- a/apps/api-threat-detection/src/main/java/com/akto/suspect_data/Message.java
+++ b/apps/api-threat-detection/src/main/java/com/akto/suspect_data/Message.java
@@ -1,33 +1,57 @@
 package com.akto.suspect_data;
 
 import com.akto.proto.threat_protection.consumer_service.v1.MaliciousEvent;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import java.util.Optional;
 
 // Kafka Message Wrapper for suspect data
 public class Message {
-    private String accountId;
-    private MaliciousEvent data;
+  private String accountId;
+  private String data;
 
-    public Message() {
-    }
+  private static ObjectMapper objectMapper = new ObjectMapper();
 
-    public Message(String accountId, MaliciousEvent data) {
-        this.accountId = accountId;
-        this.data = data;
-    }
+  public Message() {}
 
-    public String getAccountId() {
-        return accountId;
-    }
+  public Message(String accountId, String data) {
+    this.accountId = accountId;
+    this.data = data;
+  }
 
-    public void setAccountId(String accountId) {
-        this.accountId = accountId;
-    }
+  public String getAccountId() {
+    return accountId;
+  }
 
-    public MaliciousEvent getData() {
-        return data;
+  public void setAccountId(String accountId) {
+    this.accountId = accountId;
+  }
+
+  public String getData() {
+    return data;
+  }
+
+  public void setData(String data) {
+    this.data = data;
+  }
+
+  public Optional<String> marshal() {
+    try {
+      return Optional.ofNullable(objectMapper.writeValueAsString(this));
+    } catch (Exception e) {
+      e.printStackTrace();
     }
 
-    public void setData(MaliciousEvent data) {
-        this.data = data;
+    return Optional.empty();
+  }
+
+  public static Optional<Message> unmarshal(String message) {
+    try {
+      return Optional.ofNullable(objectMapper.readValue(message, Message.class));
+    } catch (Exception e) {
+      e.printStackTrace();
     }
+
+    return Optional.empty();
+  }
 }
diff --git a/libs/utils/src/main/java/com/akto/traffic/KafkaRunner.java b/libs/utils/src/main/java/com/akto/traffic/KafkaRunner.java
index f31e35a496..571982f9aa 100644
--- a/libs/utils/src/main/java/com/akto/traffic/KafkaRunner.java
+++ b/libs/utils/src/main/java/com/akto/traffic/KafkaRunner.java
@@ -9,6 +9,8 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 
+import com.akto.DaoInit;
+import com.mongodb.ConnectionString;
 import org.apache.commons.lang3.function.FailableFunction;
 import org.apache.kafka.clients.consumer.*;
 import org.apache.kafka.common.Metric;
@@ -25,139 +27,141 @@ import com.akto.runtime.utils.Utils;
 
 public class KafkaRunner {
-    private Consumer<String, String> consumer;
-    private static final LoggerMaker loggerMaker = new LoggerMaker(KafkaRunner.class, LogDb.RUNTIME);
-    private static final DataActor dataActor = DataActorFactory.fetchInstance();
-    private static final String KAFKA_GROUP_ID = "akto-threat-detection";
-
-    public static final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(2);
-
-    private static Properties generateKafkaProperties() {
-        String kafkaBrokerUrl = System.getenv("AKTO_KAFKA_BROKER_URL");
-        int maxPollRecords = Integer.parseInt(
-                System.getenv().getOrDefault("AKTO_KAFKA_MAX_POLL_RECORDS_CONFIG", "100"));
-
-        return Utils.configProperties(kafkaBrokerUrl, KAFKA_GROUP_ID, maxPollRecords);
+  private Consumer<String, String> consumer;
+  private static final LoggerMaker loggerMaker = new LoggerMaker(KafkaRunner.class, LogDb.RUNTIME);
+  private static final DataActor dataActor = DataActorFactory.fetchInstance();
+  private static final String KAFKA_GROUP_ID = "akto-threat-detection";
+
+  public static final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(2);
+
+  private static Properties generateKafkaProperties() {
+    String kafkaBrokerUrl = System.getenv("AKTO_KAFKA_BROKER_URL");
+    int maxPollRecords =
+        Integer.parseInt(System.getenv().getOrDefault("AKTO_KAFKA_MAX_POLL_RECORDS_CONFIG", "100"));
+
+    return Utils.configProperties(kafkaBrokerUrl, KAFKA_GROUP_ID, maxPollRecords);
+  }
+
+  public static void consume(
+      LogDb module,
+      List<String> topics,
+      FailableFunction<ConsumerRecords<String, String>, Void, Exception> recordProcessor) {
+
+    loggerMaker.setDb(module);
+
+    final KafkaRunner main = new KafkaRunner();
+    main.consumer = new KafkaConsumer<>(generateKafkaProperties());
+
+    boolean hybridSaas = RuntimeMode.isHybridDeployment();
+    boolean connected = false;
+    if (hybridSaas) {
+      AccountSettings accountSettings = dataActor.fetchAccountSettings();
+      if (accountSettings != null) {
+        int acc = accountSettings.getId();
+        Context.accountId.set(acc);
+        connected = true;
+      }
+    } else {
+      String mongoURI = System.getenv("AKTO_MONGO_CONN");
+      DaoInit.init(new ConnectionString(mongoURI));
+      Context.accountId.set(1_000_000);
+      connected = true;
     }
 
-    public static void consume(LogDb module, List<String> topics,
-            FailableFunction<ConsumerRecords<String, String>, Void, Exception> recordProcessor) {
-
-        loggerMaker.setDb(module);
-
-        final KafkaRunner main = new KafkaRunner();
-        main.consumer = new KafkaConsumer<>(generateKafkaProperties());
-
-        boolean hybridSaas = RuntimeMode.isHybridDeployment();
-        boolean connected = false;
-        if (!hybridSaas) {
-            throw new RuntimeException("Hybrid mode is required for this module");
-        }
-
-        AccountSettings accountSettings = dataActor.fetchAccountSettings();
-        if (accountSettings != null) {
-            int acc = accountSettings.getId();
-            Context.accountId.set(acc);
-            connected = true;
-        }
-
-        if (connected) {
-            loggerMaker.infoAndAddToDb(String.format("Starting module for account : %d", Context.accountId.get()));
-            AllMetrics.instance.init(module);
-        }
+    if (connected) {
+      loggerMaker.infoAndAddToDb(
+          String.format("Starting module for account : %d", Context.accountId.get()));
+      AllMetrics.instance.init(module);
+    }
 
-        final Thread mainThread = Thread.currentThread();
-        final AtomicBoolean exceptionOnCommitSync = new AtomicBoolean(false);
+    final Thread mainThread = Thread.currentThread();
+    final AtomicBoolean exceptionOnCommitSync = new AtomicBoolean(false);
 
-        Runtime.getRuntime().addShutdownHook(new Thread() {
-            public void run() {
+    Runtime.getRuntime()
+        .addShutdownHook(
+            new Thread() {
+              public void run() {
                main.consumer.wakeup();
                try {
-                    if (!exceptionOnCommitSync.get()) {
-                        mainThread.join();
-                    }
+                  if (!exceptionOnCommitSync.get()) {
+                    mainThread.join();
+                  }
                } catch (InterruptedException e) {
-                    e.printStackTrace();
+                  e.printStackTrace();
                } catch (Error e) {
-                    loggerMaker.errorAndAddToDb("Error in main thread: " + e.getMessage());
+                  loggerMaker.errorAndAddToDb("Error in main thread: " + e.getMessage());
                }
-            }
-        });
-
-        scheduler.scheduleAtFixedRate(() -> {
-            main.logKafkaMetrics(module);
-        }, 0, 1, TimeUnit.MINUTES);
-
+              }
+            });
+
+    scheduler.scheduleAtFixedRate(
+        () -> {
+          main.logKafkaMetrics(module);
+        },
+        0,
+        1,
+        TimeUnit.MINUTES);
+
+    try {
+      main.consumer.subscribe(topics);
+      loggerMaker.infoAndAddToDb(
+          String.format("Consumer subscribed for topics : %s", topics.toString()));
+      while (true) {
+        ConsumerRecords<String, String> records = main.consumer.poll(Duration.ofMillis(10000));
        try {
-            main.consumer.subscribe(topics);
-            loggerMaker.infoAndAddToDb(
-                    String.format("Consumer subscribed for topics : %s", topics.toString()));
-            while (true) {
-                ConsumerRecords<String, String> records = main.consumer.poll(Duration.ofMillis(10000));
-                try {
-                    main.consumer.commitSync();
-                } catch (Exception e) {
-                    throw e;
-                }
-
-                try {
-                    recordProcessor.apply(records);
-                } catch (Exception e) {
-                    loggerMaker.errorAndAddToDb(e, "Error while processing kafka messages " + e);
-                }
-            }
-        } catch (WakeupException ignored) {
-            // nothing to catch. This exception is called from the shutdown hook.
+          main.consumer.commitSync();
        } catch (Exception e) {
-            exceptionOnCommitSync.set(true);
-            Utils.printL(e);
-            loggerMaker.errorAndAddToDb("Error in Kafka consumer: " + e.getMessage());
-            e.printStackTrace();
-            System.exit(0);
-        } finally {
-            main.consumer.close();
+          throw e;
        }
-    }
-    public void logKafkaMetrics(LogDb module) {
        try {
-            Map<MetricName, ? extends Metric> metrics = this.consumer.metrics();
-            for (Map.Entry<MetricName, ? extends Metric> entry : metrics.entrySet()) {
-                MetricName key = entry.getKey();
-                Metric value = entry.getValue();
-
-                if (key.name().equals("records-lag-max")) {
-                    double val = value.metricValue().equals(Double.NaN)
-                            ? 0d
-                            : (double) value.metricValue();
-                    AllMetrics.instance.setKafkaRecordsLagMax((float) val);
-                }
-                if (key.name().equals("records-consumed-rate")) {
-                    double val = value.metricValue().equals(Double.NaN)
-                            ? 0d
-                            : (double) value.metricValue();
-                    AllMetrics.instance.setKafkaRecordsConsumedRate((float) val);
-                }
+          recordProcessor.apply(records);
+        } catch (Exception e) {
+          loggerMaker.errorAndAddToDb(e, "Error while processing kafka messages " + e);
+        }
+      }
+    } catch (WakeupException ignored) {
+      // nothing to catch. This exception is called from the shutdown hook.
+    } catch (Exception e) {
+      exceptionOnCommitSync.set(true);
+      Utils.printL(e);
+      loggerMaker.errorAndAddToDb("Error in Kafka consumer: " + e.getMessage());
+      e.printStackTrace();
+      System.exit(0);
+    } finally {
+      main.consumer.close();
+    }
+  }
+
+  public void logKafkaMetrics(LogDb module) {
+    try {
+      Map<MetricName, ? extends Metric> metrics = this.consumer.metrics();
+      for (Map.Entry<MetricName, ? extends Metric> entry : metrics.entrySet()) {
+        MetricName key = entry.getKey();
+        Metric value = entry.getValue();
+
+        if (key.name().equals("records-lag-max")) {
+          double val = value.metricValue().equals(Double.NaN) ? 0d : (double) value.metricValue();
+          AllMetrics.instance.setKafkaRecordsLagMax((float) val);
+        }
+        if (key.name().equals("records-consumed-rate")) {
+          double val = value.metricValue().equals(Double.NaN) ? 0d : (double) value.metricValue();
+          AllMetrics.instance.setKafkaRecordsConsumedRate((float) val);
+        }
 
-                if (key.name().equals("fetch-latency-avg")) {
-                    double val = value.metricValue().equals(Double.NaN)
-                            ? 0d
-                            : (double) value.metricValue();
-                    AllMetrics.instance.setKafkaFetchAvgLatency((float) val);
-                }
+        if (key.name().equals("fetch-latency-avg")) {
+          double val = value.metricValue().equals(Double.NaN) ? 0d : (double) value.metricValue();
+          AllMetrics.instance.setKafkaFetchAvgLatency((float) val);
+        }
 
-                if (key.name().equals("bytes-consumed-rate")) {
-                    double val = value.metricValue().equals(Double.NaN)
-                            ? 0d
-                            : (double) value.metricValue();
-                    AllMetrics.instance.setKafkaBytesConsumedRate((float) val);
-                }
-            }
-        } catch (Exception e) {
-            loggerMaker.errorAndAddToDb(
-                    e,
-                    String.format(
-                            "Failed to get kafka metrics for %s error: %s", module.name(), e));
+        if (key.name().equals("bytes-consumed-rate")) {
+          double val = value.metricValue().equals(Double.NaN) ? 0d : (double) value.metricValue();
+          AllMetrics.instance.setKafkaBytesConsumedRate((float) val);
        }
+      }
+    } catch (Exception e) {
+      loggerMaker.errorAndAddToDb(
+          e, String.format("Failed to get kafka metrics for %s error: %s", module.name(), e));
    }
+  }