diff --git a/Application/ApplicationServiceRegister.cs b/Application/ApplicationServiceRegister.cs index 34cfb30..cc4b0c0 100644 --- a/Application/ApplicationServiceRegister.cs +++ b/Application/ApplicationServiceRegister.cs @@ -1,10 +1,15 @@ using Hangfire; using Hangfire.Storage.SQLite; +using Microsoft.AspNetCore.Authentication.Cookies; +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.DependencyInjection; +using Microsoft.IdentityModel.Tokens; using MTWireGuard.Application.Mapper; using MTWireGuard.Application.Repositories; using MTWireGuard.Application.Services; using System.Reflection; +using System.Text; namespace MTWireGuard.Application { @@ -13,7 +18,7 @@ namespace MTWireGuard.Application public static void AddApplicationServices(this IServiceCollection services) { // Add DBContext - services.AddDbContext(ServiceLifetime.Singleton); + services.AddDbContext(); // Add HangFire services.AddHangfire(config => @@ -26,17 +31,66 @@ namespace MTWireGuard.Application services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddAutoMapper( - (provider, expression) => { + (provider, expression) => + { expression.AddProfile(provider.GetService()); expression.AddProfile(provider.GetService()); expression.AddProfile(provider.GetService()); + expression.AddProfile(provider.GetService()); }, new List()); // Add Mikrotik API Service - services.AddSingleton(); + services.AddScoped(); + // XSRF Protection + services.AddAntiforgery(o => + { + o.HeaderName = "XSRF-TOKEN"; + o.FormFieldName = "XSRF-Validation-Token"; + o.Cookie.Name = "XSRF-Validation"; + }); + + // Authentication and Authorization + services.AddAuthentication(CookieAuthenticationDefaults.AuthenticationScheme).AddCookie(options => + { + options.ExpireTimeSpan = TimeSpan.FromMinutes(15); + options.SlidingExpiration = true; + options.LoginPath = "/Login"; + options.AccessDeniedPath = "/Forbidden"; + options.Cookie.Name = "Authentication"; + options.LogoutPath = "/Logout"; + }); + + services.ConfigureApplicationCookie(configure => + { + configure.Cookie.Name = "MTWireguard"; + }); + + services.AddAuthorization(); + + // Add Razor Pages + services.AddRazorPages().AddRazorPagesOptions(o => + { + //o.Conventions.ConfigureFilter(new IgnoreAntiforgeryTokenAttribute()); + o.Conventions.AuthorizeFolder("/"); + o.Conventions.AllowAnonymousToPage("/Login"); + }); + + // Add Session + services.AddDistributedMemoryCache(); + + services.AddSession(options => + { + options.IdleTimeout = TimeSpan.FromMinutes(1); + options.Cookie.HttpOnly = true; + options.Cookie.IsEssential = true; + }); + + // Add CORS + services.AddCors(); } } } diff --git a/Application/DBContext.cs b/Application/DBContext.cs index 66ef3f9..1cd91da 100644 --- a/Application/DBContext.cs +++ b/Application/DBContext.cs @@ -1,4 +1,5 @@ using Microsoft.EntityFrameworkCore; +using MTWireGuard.Application.Models; using MTWireGuard.Application.Models.Mikrotik; namespace MTWireGuard.Application @@ -6,6 +7,9 @@ namespace MTWireGuard.Application public class DBContext : DbContext { public DbSet Users { get; set; } + public DbSet Servers { get; set; } + public DbSet DataUsages { get; set; } + public DbSet LastKnownTraffic { get; set; } public string DbPath { get; } @@ -15,7 +19,11 @@ namespace MTWireGuard.Application } protected override void OnConfiguring(DbContextOptionsBuilder options) - => options.UseSqlite($"Data Source={DbPath}"); + { + options.UseSqlite($"Data Source={DbPath}", opt => + { + opt.MigrationsAssembly("MTWireGuard.Application"); + }); + } } - } diff --git a/Application/Helper.cs b/Application/Helper.cs index e120bd8..151ec2d 100644 --- a/Application/Helper.cs +++ b/Application/Helper.cs @@ -1,22 +1,54 @@ -using System.IO.Compression; +using AutoMapper; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc.Controllers; +using Microsoft.AspNetCore.Mvc.Rendering; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Routing; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using MTWireGuard.Application.MinimalAPI; +using MTWireGuard.Application.Models; +using MTWireGuard.Application.Repositories; +using System.IO.Compression; using System.Text; using System.Text.RegularExpressions; +using System.Text.Json; namespace MTWireGuard.Application { public class Helper { public static readonly string[] UpperCaseTopics = - { - "dhcp", - "ppp", - "l2tp", - "pptp", - "sstp" - }; + [ + "dhcp", + "ppp", + "l2tp", + "pptp", + "sstp" + ]; - private static readonly string[] SizeSuffixes = - { "bytes", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB" }; + public static string PeersTrafficUsageScript(string apiURL) + { + return $"/tool fetch mode=http url=\"{apiURL}\" http-method=post check-certificate=no http-data=([/interface/wireguard/peers/print show-ids proplist=rx,tx as-value]);"; + } + + public static string PeersLastHandshakeScript(string apiURL) + { + return $"/tool fetch mode=http url=\"{apiURL}\" http-method=post check-certificate=no http-data=([/interface/wireguard/peers/print show-ids proplist=last-handshake as-value]);"; + } + + public static int ParseEntityID(string entityID) + { + return Convert.ToInt32(entityID[1..], 16); + } + + public static string ParseEntityID(int entityID) + { + return $"*{entityID:X}"; + } + + private static readonly string[] SizeSuffixes = ["bytes", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"]; public static string ConvertByteSize(long value, int decimalPlaces = 2) { if (decimalPlaces < 0) { throw new ArgumentOutOfRangeException("decimalPlaces"); } @@ -42,6 +74,171 @@ namespace MTWireGuard.Application adjustedSize, SizeSuffixes[mag]); } + + #region API Section + public static List ParseTrafficUsage(string input) + { + string[] items = input.Split(".id="); + + List objects = items + .Where(item => !string.IsNullOrEmpty(item)) + .Select(item => $"id={item}") + .ToList(); + + return objects + .Select(x => x.Split(';').ToList()) + .Select(arr => + { + var obj = new UsageObject(); + var id = arr.Find(x => x.Contains("id")).Split('=')[1]; + var rx = arr.Find(x => x.Contains("rx")).Split('=')[1] ?? "0"; + var tx = arr.Find(x => x.Contains("tx")).Split('=')[1] ?? "0"; + obj.Id = id; + obj.RX = int.Parse(rx); + obj.TX = int.Parse(tx); + return obj; + }).ToList(); + } + + public static List ParseActivityUpdates(string input) + { + string[] items = input.Split(".id="); + + List objects = items + .Where(item => !string.IsNullOrEmpty(item)) + .Select(item => $"id={item}") + .ToList(); + + return objects + .Select(x => x.Split(';').ToList()) + .Select(arr => + { + var obj = new UserActivityUpdate(); + var id = arr.Find(x => x.Contains("id")).Split('=')[1]; + var handshake = arr.Find(x => x.Contains("handshake"))?.Split('=')[1] ?? "Never"; + obj.Id = ParseEntityID(id); + obj.LastHandshake = handshake; + return obj; + }).ToList(); + } + #endregion + + public static async void HandleUserTraffics(List updates, DBContext dbContext, IMikrotikRepository API) + { + var dataUsages = await dbContext.DataUsages.ToListAsync(); + var existingItems = dataUsages.OrderBy(x => x.CreationTime).ToList(); + var lastKnownTraffics = dbContext.LastKnownTraffic.ToList(); + var users = await dbContext.Users.ToListAsync(); + foreach (var item in updates) + { + var tempUser = users.Find(x => x.Id == item.UserID); + if (tempUser == null) continue; + using var transaction = await dbContext.Database.BeginTransactionAsync(); + try + { + LastKnownTraffic lastKnown = lastKnownTraffics.Find(x => x.UserID == item.UserID); + if (lastKnown == null) continue; + + var old = existingItems.FindLast(oldItem => oldItem.UserID == item.UserID); + if (old == null) + { + await dbContext.DataUsages.AddAsync(item); + tempUser.RX = item.RX + lastKnown.RX; + tempUser.TX = item.TX + lastKnown.TX; + } + else + { + if ((old.RX <= item.RX || old.TX <= item.TX) && + (old.RX != item.RX && old.TX != item.TX)) // Normal Data (and not duplicate) + { + await dbContext.DataUsages.AddAsync(item); + } + else if (old.RX > item.RX || old.TX > item.TX) // Server Reset + { + lastKnown.RX = old.RX; + lastKnown.TX = old.TX; + lastKnown.CreationTime = DateTime.Now; + dbContext.LastKnownTraffic.Update(lastKnown); + //if (lastKnown != null) + //{ + // dbContext.LastKnownTraffic.Update(lastKnown); + //} + //else + //{ + // await dbContext.LastKnownTraffic.AddAsync(lastKnown); + //} + item.ResetNotes = $"System reset detected at: {DateTime.Now}"; + await dbContext.DataUsages.AddAsync(item); + } + if (item.RX > old.RX) tempUser.RX = item.RX + lastKnown.RX; + if (item.TX > old.TX) tempUser.TX = item.TX + lastKnown.TX; + } + if (tempUser.TrafficLimit > 0 && tempUser.RX + tempUser.TX >= tempUser.TrafficLimit) + { + // Disable User + var disable = await API.DisableUser(item.UserID); + if (disable.Code != "200") + { + Console.WriteLine("Failed disabling user"); + } + } + dbContext.Users.Update(tempUser); + await dbContext.SaveChangesAsync(); + transaction.Commit(); + } + catch (Exception ex) + { + transaction.Rollback(); + Console.WriteLine(ex.Message); + } + } + } + + public static string GetProjectVersion() + { + return System.Reflection.Assembly.GetExecutingAssembly().GetName().Version.ToString(); + } + + public static TimeSpan ConvertToTimeSpan(string input) + { + int weeks = 0; + int days = 0; + int hours = 0; + int minutes = 0; + int seconds = 0; + + if (input.Contains('w')) + { + string w = input.Split('w').First(); + weeks = int.Parse(w); + input = input.Remove(0, input.IndexOf('w') + 1); + } + if (input.Contains('d')) + { + string d = input.Split('d').First(); + days = int.Parse(d); + input = input.Remove(0, input.IndexOf('d') + 1); + } + if (input.Contains('h')) + { + string h = input.Split('h').First(); + hours = int.Parse(h); + input = input.Remove(0, input.IndexOf('h') + 1); + } + if (input.Contains('m')) + { + string m = input.Split('m').First(); + minutes = int.Parse(m); + input = input.Remove(0, input.IndexOf('m') + 1); + } + if (input.Contains('s')) + { + string s = input.Split('s').First(); + seconds = int.Parse(s); + } + + return new TimeSpan((weeks * 7) + days, hours, minutes, seconds); + } } public static class StringCompression @@ -73,7 +270,7 @@ namespace MTWireGuard.Application } } - public static class StringExtensions + public static partial class StringExtensions { public static string FirstCharToUpper(this string input) => input switch @@ -83,6 +280,50 @@ namespace MTWireGuard.Application _ => string.Concat(input[0].ToString().ToUpper(), input.AsSpan(1)) }; - public static string RemoveNonNumerics(this string input) => Regex.Replace(input, "[^0-9.]", ""); + public static string RemoveNonNumerics(this string input) => Numerics().Replace(input, ""); + [GeneratedRegex("[^0-9.]")] + private static partial Regex Numerics(); + } + + public static class ViewResultExtensions + { + public static string ToHtml(this ViewResult result, HttpContext httpContext) + { + var feature = httpContext.Features.Get(); + var routeData = feature.RouteData; + var viewName = result.ViewName ?? routeData.Values["action"] as string; + var actionContext = new ActionContext(httpContext, routeData, new ControllerActionDescriptor()); + var options = httpContext.RequestServices.GetRequiredService>(); + var htmlHelperOptions = options.Value.HtmlHelperOptions; + var viewEngineResult = result.ViewEngine?.FindView(actionContext, viewName, true) ?? options.Value.ViewEngines.Select(x => x.FindView(actionContext, viewName, true)).FirstOrDefault(x => x != null); + var view = viewEngineResult.View; + var builder = new StringBuilder(); + + using (var output = new StringWriter(builder)) + { + var viewContext = new ViewContext(actionContext, view, result.ViewData, result.TempData, output, htmlHelperOptions); + + view + .RenderAsync(viewContext) + .GetAwaiter() + .GetResult(); + } + + return builder.ToString(); + } + } + + public static class SessionExtensions + { + public static void Set(this ISession session, string key, T value) + { + session.SetString(key, JsonSerializer.Serialize(value)); + } + + public static T? Get(this ISession session, string key) + { + var value = session.GetString(key); + return value == null ? default : JsonSerializer.Deserialize(value); + } } } diff --git a/Application/MTWireGuard.Application.csproj b/Application/MTWireGuard.Application.csproj index 99e2497..9099ba8 100644 --- a/Application/MTWireGuard.Application.csproj +++ b/Application/MTWireGuard.Application.csproj @@ -1,26 +1,30 @@ - + - net6.0 + net8.0 enable enable + 2.0.0 + 2.0.0 - - - + + + + + - - - + + + all runtime; build; native; contentfiles; analyzers; buildtransitive - - + + diff --git a/Application/SetupValidator.cs b/Application/SetupValidator.cs new file mode 100644 index 0000000..314d7f4 --- /dev/null +++ b/Application/SetupValidator.cs @@ -0,0 +1,140 @@ +using AutoMapper; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using MTWireGuard.Application.Models; +using MTWireGuard.Application.Repositories; +using MTWireGuard.Application.Services; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Sockets; +using System.Text; +using System.Threading.Tasks; +using static System.Runtime.InteropServices.JavaScript.JSType; + +namespace MTWireGuard.Application +{ + public class SetupValidator(IServiceProvider serviceProvider) + { + private IMikrotikRepository api; + + public async Task Validate() + { + var envVariables = ValidateEnvironmentVariables(); + if (envVariables) + { + Console.BackgroundColor = ConsoleColor.Black; + Console.ForegroundColor = ConsoleColor.Red; + Console.WriteLine($"[-] Environment variables are not set!"); + Console.WriteLine($"[!] Please set \"MT_IP\", \"MT_USER\", \"MT_PASS\", \"MT_PUBLIC_IP\" variables in container environment."); + Console.ResetColor(); + Shutdown(); + } + + serviceProvider.GetService().Database.EnsureCreated(); + api = serviceProvider.GetService(); + + var (apiConnection, apiConnectionMessage) = await ValidateAPIConnection(); + if (!apiConnection) + { + Console.BackgroundColor = ConsoleColor.Black; + Console.ForegroundColor = ConsoleColor.Red; + Console.WriteLine($"[-] Error connecting to the router api!"); + Console.WriteLine($"[!] {apiConnectionMessage}"); + Console.ResetColor(); + Shutdown(); + } + + var ip = GetIPAddress(); + if (string.IsNullOrEmpty(ip)) + { + Console.BackgroundColor = ConsoleColor.Black; + Console.ForegroundColor = ConsoleColor.Red; + Console.WriteLine($"[-] Error getting container IP address!"); + Console.ResetColor(); + Shutdown(); + } + var scripts = await api.GetScripts(); + var schedulers = await api.GetSchedulers(); + var trafficScript = scripts.Find(x => x.Name == "SendTrafficUsage"); + var handshakeScript = scripts.Find(x => x.Name == "SendActivityUpdates"); + var trafficScheduler = schedulers.Find(x => x.Name == "TrafficUsage"); + + if (trafficScript == null) + { + var create = await api.CreateScript(new() + { + Name = "SendTrafficUsage", + Policies = ["write", "read", "test"], + DontRequiredPermissions = false, + Source = Helper.PeersTrafficUsageScript($"http://{ip}/api/usage") + }); + var result = create.Code; + } + if (handshakeScript == null) + { + var create = await api.CreateScript(new() + { + Name = "SendActivityUpdates", + Policies = ["write", "read", "test"], + DontRequiredPermissions = false, + Source = Helper.PeersLastHandshakeScript($"http://{ip}/api/activity") + }); + var result = create.Code; + } + if (trafficScheduler == null) + { + var create = await api.CreateScheduler(new() + { + Name = "TrafficUsage", + Interval = new TimeSpan(0, 5, 0), + OnEvent = "SendTrafficUsage", + Policies = ["write", "read", "test"] + }); + var result = create.Code; + } + } + + private static bool ValidateEnvironmentVariables() + { + string? IP = Environment.GetEnvironmentVariable("MT_IP"); + string? USER = Environment.GetEnvironmentVariable("MT_USER"); + string? PASS = Environment.GetEnvironmentVariable("MT_PASS"); + string? PUBLICIP = Environment.GetEnvironmentVariable("MT_PUBLIC_IP"); + + return string.IsNullOrEmpty(IP) || string.IsNullOrEmpty(USER) || string.IsNullOrEmpty(PUBLICIP); + } + + private async Task<(bool status, string? message)> ValidateAPIConnection() + { + try + { + return (await api.TryConnectAsync(), string.Empty); + } + catch (Exception ex) + { + return (false, ex.Message); + } + } + + private static string GetIPAddress() + { + try + { + var name = System.Net.Dns.GetHostName(); + var port = Environment.GetEnvironmentVariable("ASPNETCORE_HTTP_PORTS"); + return System.Net.Dns.GetHostEntry(name).AddressList.FirstOrDefault(x => x.AddressFamily == AddressFamily.InterNetwork).ToString() + $":{port}"; + } + catch (Exception ex) + { + Console.WriteLine(ex.Message); + return string.Empty; + } + } + + private static void Shutdown() + { + Environment.Exit(0); + } + } +}