wireguard-android/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt

256 lines
11 KiB
Kotlin

/*
* Copyright © 2017-2023 WireGuard LLC. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package com.wireguard.android.model
import android.content.BroadcastReceiver
import android.content.Context
import android.content.Intent
import android.os.Build
import android.util.Log
import android.widget.Toast
import androidx.databinding.BaseObservable
import androidx.databinding.Bindable
import com.wireguard.android.Application.Companion.get
import com.wireguard.android.Application.Companion.getBackend
import com.wireguard.android.Application.Companion.getTunnelManager
import com.wireguard.android.BR
import com.wireguard.android.R
import com.wireguard.android.backend.Statistics
import com.wireguard.android.backend.Tunnel
import com.wireguard.android.configStore.ConfigStore
import com.wireguard.android.databinding.ObservableSortedKeyedArrayList
import com.wireguard.android.util.ErrorMessages
import com.wireguard.android.util.UserKnobs
import com.wireguard.android.util.applicationScope
import com.wireguard.config.Config
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
/**
 * Maintains and mediates changes to the set of available WireGuard tunnels.
 *
 * Mutations of the tunnel list are performed inside `withContext(Dispatchers.Main.immediate)`,
 * while calls into the backend and the [ConfigStore] are wrapped in `withContext(Dispatchers.IO)`.
 */
class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
    /** Completed with [tunnelMap] at the end of the initial load started by [onCreate]. */
    private val tunnels = CompletableDeferred<ObservableSortedKeyedArrayList<String, ObservableTunnel>>()
    private val context: Context = get()
    private val tunnelMap: ObservableSortedKeyedArrayList<String, ObservableTunnel> = ObservableSortedKeyedArrayList(TunnelComparator)

    /** Set to `true` by [onTunnelsLoaded]; [restoreState] is a no-op until then. */
    private var haveLoaded = false

    /** Wraps the arguments in a new [ObservableTunnel], adds it to [tunnelMap] and returns it. */
    private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel {
        val tunnel = ObservableTunnel(this, name, config, state)
        tunnelMap.add(tunnel)
        return tunnel
    }

    /** Suspends until the initial load has completed, then returns the managed tunnel list. */
    suspend fun getTunnels(): ObservableSortedKeyedArrayList<String, ObservableTunnel> = tunnels.await()

    /**
     * Persists [config] under [name] in the config store and registers a new tunnel in the DOWN state.
     *
     * @throws IllegalArgumentException if [name] is invalid or a tunnel with that name already exists.
     * @throws NullPointerException if [config] is `null` (the parameter is nullable but asserted non-null).
     */
    suspend fun create(name: String, config: Config?): ObservableTunnel = withContext(Dispatchers.Main.immediate) {
        if (Tunnel.isNameInvalid(name))
            throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))
        if (tunnelMap.containsKey(name))
            throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))
        addToList(name, withContext(Dispatchers.IO) { configStore.create(name, config!!) }, Tunnel.State.DOWN)
    }

    /**
     * Removes [tunnel] from the manager, brings it down if it was up, and deletes its stored config.
     *
     * If any step fails, the tunnel is put back into the list (and restored as the last used tunnel
     * if it was one) before the failure is rethrown. If only the config deletion fails, an attempt is
     * also made to bring a previously running tunnel back up.
     */
    suspend fun delete(tunnel: ObservableTunnel) = withContext(Dispatchers.Main.immediate) {
        val originalState = tunnel.state
        val wasLastUsed = tunnel == lastUsedTunnel
        // Make sure nothing touches the tunnel.
        if (wasLastUsed)
            lastUsedTunnel = null
        tunnelMap.remove(tunnel)
        try {
            if (originalState == Tunnel.State.UP)
                withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) }
            try {
                withContext(Dispatchers.IO) { configStore.delete(tunnel.name) }
            } catch (e: Throwable) {
                // The config is still there, so undo the shutdown performed above.
                if (originalState == Tunnel.State.UP)
                    withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) }
                throw e
            }
        } catch (e: Throwable) {
            // Failure, put the tunnel back.
            tunnelMap.add(tunnel)
            if (wasLastUsed)
                lastUsedTunnel = tunnel
            throw e
        }
    }

    /**
     * The tunnel most recently brought up through [setTunnelState], or `null`.
     *
     * Assigning a different value notifies data-binding observers and asynchronously stores the
     * tunnel's name through [UserKnobs.setLastUsedTunnel].
     */
    @get:Bindable
    var lastUsedTunnel: ObservableTunnel? = null
        private set(value) {
            if (value == field) return
            field = value
            notifyPropertyChanged(BR.lastUsedTunnel)
            applicationScope.launch { UserKnobs.setLastUsedTunnel(value?.name) }
        }

    /** Loads the stored config for [tunnel], hands it to the tunnel via `onConfigChanged`, and returns the result. */
    suspend fun getTunnelConfig(tunnel: ObservableTunnel): Config = withContext(Dispatchers.Main.immediate) {
        tunnel.onConfigChanged(withContext(Dispatchers.IO) { configStore.load(tunnel.name) })!!
    }

    /**
     * Starts the initial load: enumerates stored configs and running tunnel names, then hands both to
     * [onTunnelsLoaded].
     *
     * A failure is only logged; note that in that case [tunnels] is never completed, so callers of
     * [getTunnels] keep waiting.
     */
    fun onCreate() {
        applicationScope.launch {
            try {
                onTunnelsLoaded(withContext(Dispatchers.IO) { configStore.enumerate() }, withContext(Dispatchers.IO) { getBackend().runningTunnelNames })
            } catch (e: Throwable) {
                Log.e(TAG, Log.getStackTraceString(e))
            }
        }
    }

    /**
     * Populates [tunnelMap] from the stored tunnel names, marking those in [running] as UP, then
     * asynchronously restores the last used tunnel, restores previously running tunnels and
     * completes [tunnels].
     */
    private fun onTunnelsLoaded(present: Iterable<String>, running: Collection<String>) {
        for (name in present)
            addToList(name, null, if (running.contains(name)) Tunnel.State.UP else Tunnel.State.DOWN)
        applicationScope.launch {
            val lastUsedName = UserKnobs.lastUsedTunnel.first()
            if (lastUsedName != null)
                lastUsedTunnel = tunnelMap[lastUsedName]
            haveLoaded = true
            restoreState(true)
            tunnels.complete(tunnelMap)
        }
    }

    /** Asks the backend which tunnels are running and pushes UP/DOWN to every managed tunnel; errors are logged. */
    private fun refreshTunnelStates() {
        applicationScope.launch {
            try {
                val running = withContext(Dispatchers.IO) { getBackend().runningTunnelNames }
                for (tunnel in tunnelMap)
                    tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN)
            } catch (e: Throwable) {
                Log.e(TAG, Log.getStackTraceString(e))
            }
        }
    }

    /**
     * Brings up every managed tunnel whose name is recorded in [UserKnobs.runningTunnels].
     *
     * Does nothing until the initial load has finished, and, unless [force] is `true`, does nothing
     * when the restore-on-boot knob is off. Failures while bringing tunnels up are logged, not thrown.
     */
    suspend fun restoreState(force: Boolean) {
        if (!haveLoaded || (!force && !UserKnobs.restoreOnBoot.first()))
            return
        val previouslyRunning = UserKnobs.runningTunnels.first()
        if (previouslyRunning.isEmpty()) return
        withContext(Dispatchers.IO) {
            try {
                // Each tunnel is started in its own async with its own SupervisorJob.
                tunnelMap.filter { previouslyRunning.contains(it.name) }.map { async(Dispatchers.IO + SupervisorJob()) { setTunnelState(it, Tunnel.State.UP) } }
                    .awaitAll()
            } catch (e: Throwable) {
                Log.e(TAG, Log.getStackTraceString(e))
            }
        }
    }

    /** Stores the names of all tunnels currently in the UP state through [UserKnobs.setRunningTunnels]. */
    suspend fun saveState() {
        UserKnobs.setRunningTunnels(tunnelMap.filter { it.state == Tunnel.State.UP }.map { it.name }.toSet())
    }

    /**
     * Applies [config] to the backend (keeping the tunnel's current state), saves it to the config
     * store, and hands the saved config to the tunnel via `onConfigChanged`.
     */
    suspend fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): Config = withContext(Dispatchers.Main.immediate) {
        tunnel.onConfigChanged(withContext(Dispatchers.IO) {
            getBackend().setState(tunnel, tunnel.state, config)
            configStore.save(tunnel.name, config)
        })!!
    }

    /**
     * Renames [tunnel] to [name]: the tunnel is taken out of the list, brought down if it was up,
     * renamed in the config store, and brought back up if it had been up.
     *
     * Whether or not the rename succeeds, the tunnel is added back to the list under whatever name
     * it ends up with, and it is restored as the last used tunnel if it was one. The first failure,
     * if any, is then rethrown.
     *
     * @return the new name as reported by the tunnel.
     * @throws IllegalArgumentException if [name] is invalid or already taken.
     */
    suspend fun setTunnelName(tunnel: ObservableTunnel, name: String): String = withContext(Dispatchers.Main.immediate) {
        if (Tunnel.isNameInvalid(name))
            throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))
        if (tunnelMap.containsKey(name)) {
            throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))
        }
        val originalState = tunnel.state
        val wasLastUsed = tunnel == lastUsedTunnel
        // Make sure nothing touches the tunnel.
        if (wasLastUsed)
            lastUsedTunnel = null
        tunnelMap.remove(tunnel)
        var throwable: Throwable? = null
        var newName: String? = null
        try {
            if (originalState == Tunnel.State.UP)
                withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) }
            withContext(Dispatchers.IO) { configStore.rename(tunnel.name, name) }
            newName = tunnel.onNameChanged(name)
            if (originalState == Tunnel.State.UP)
                withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) }
        } catch (e: Throwable) {
            throwable = e
            // On failure, we don't know what state the tunnel might be in. Fix that.
            // The refresh itself may fail (backend error or cancellation); that must not prevent
            // the tunnel from being put back below, nor replace the original error.
            try {
                getTunnelState(tunnel)
            } catch (stateError: Throwable) {
                // Guard against self-suppression, which Throwable.addSuppressed rejects.
                if (stateError !== e)
                    e.addSuppressed(stateError)
            }
        }
        // Add the tunnel back to the manager, under whatever name it thinks it has.
        tunnelMap.add(tunnel)
        if (wasLastUsed)
            lastUsedTunnel = tunnel
        if (throwable != null)
            throw throwable
        newName!!
    }

    /**
     * Asks the backend to move [tunnel] to [state] and returns the resulting state.
     *
     * A tunnel that ends up UP becomes [lastUsedTunnel]. Whether or not the backend call succeeds,
     * the tunnel is notified of the resulting (or, on failure, its previously known) state and the
     * running set is saved via [saveState]; a backend failure is rethrown afterwards.
     */
    suspend fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): Tunnel.State = withContext(Dispatchers.Main.immediate) {
        var newState = tunnel.state
        var throwable: Throwable? = null
        try {
            newState = withContext(Dispatchers.IO) { getBackend().setState(tunnel, state, tunnel.getConfigAsync()) }
            if (newState == Tunnel.State.UP)
                lastUsedTunnel = tunnel
        } catch (e: Throwable) {
            throwable = e
        }
        tunnel.onStateChanged(newState)
        saveState()
        if (throwable != null)
            throw throwable
        newState
    }

    /**
     * Receives the refresh and remote-control broadcast actions.
     *
     * `REFRESH_TUNNEL_STATES` is always honoured. `SET_TUNNEL_UP` / `SET_TUNNEL_DOWN` are honoured
     * only on API level M or later and only when the allow-remote-control knob is enabled; the
     * target tunnel is named by the `"tunnel"` string extra. Failures are shown in a toast.
     */
    class IntentReceiver : BroadcastReceiver() {
        override fun onReceive(context: Context, intent: Intent?) {
            applicationScope.launch {
                val manager = getTunnelManager()
                if (intent == null) return@launch
                val action = intent.action ?: return@launch
                if ("com.wireguard.android.action.REFRESH_TUNNEL_STATES" == action) {
                    manager.refreshTunnelStates()
                    return@launch
                }
                if (Build.VERSION.SDK_INT < Build.VERSION_CODES.M || !UserKnobs.allowRemoteControlIntents.first())
                    return@launch
                val state = when (action) {
                    "com.wireguard.android.action.SET_TUNNEL_UP" -> Tunnel.State.UP
                    "com.wireguard.android.action.SET_TUNNEL_DOWN" -> Tunnel.State.DOWN
                    else -> return@launch
                }
                val tunnelName = intent.getStringExtra("tunnel") ?: return@launch
                val tunnels = manager.getTunnels()
                val tunnel = tunnels[tunnelName] ?: return@launch
                try {
                    manager.setTunnelState(tunnel, state)
                } catch (e: Throwable) {
                    Toast.makeText(context, ErrorMessages[e], Toast.LENGTH_LONG).show()
                }
            }
        }
    }

    /** Queries the backend for the state of [tunnel], pushes it to the tunnel via `onStateChanged`, and returns the result. */
    suspend fun getTunnelState(tunnel: ObservableTunnel): Tunnel.State = withContext(Dispatchers.Main.immediate) {
        tunnel.onStateChanged(withContext(Dispatchers.IO) { getBackend().getState(tunnel) })
    }

    /** Queries the backend for the statistics of [tunnel], pushes them to the tunnel via `onStatisticsChanged`, and returns the result. */
    suspend fun getTunnelStatistics(tunnel: ObservableTunnel): Statistics = withContext(Dispatchers.Main.immediate) {
        tunnel.onStatisticsChanged(withContext(Dispatchers.IO) { getBackend().getStatistics(tunnel) })!!
    }

    companion object {
        private const val TAG = "WireGuard/TunnelManager"
    }
}