|
| 1 | +using ModelingToolkit: operation, istree, arguments |
| 2 | +# Method of lines discretization scheme |
| 3 | +struct MOLFiniteDifference{T,T2} <: DiffEqBase.AbstractDiscretization |
| 4 | + dxs::T |
| 5 | + time::T2 |
| 6 | + upwind_order::Int |
| 7 | + centered_order::Int |
| 8 | +end |
| 9 | + |
| 10 | +# Constructors. If no order is specified, both upwind and centered differences will be 2nd order |
| 11 | +MOLFiniteDifference(dxs, time; upwind_order = 1, centered_order = 2) = |
| 12 | + MOLFiniteDifference(dxs, time, upwind_order, centered_order) |
| 13 | + |
| 14 | +function SciMLBase.symbolic_discretize(pdesys::ModelingToolkit.PDESystem,discretization::DiffEqOperators.MOLFiniteDifference) |
| 15 | + pdeeqs = pdesys.eqs isa Vector ? pdesys.eqs : [pdesys.eqs] |
| 16 | + t = discretization.time |
| 17 | + nottime = filter(x->x.val != t.val,pdesys.indvars) |
| 18 | + |
| 19 | + # Discretize space |
| 20 | + |
| 21 | + space = map(nottime) do x |
| 22 | + xdomain = pdesys.domain[findfirst(d->x.val == d.variables,pdesys.domain)] |
| 23 | + @assert xdomain.domain isa IntervalDomain |
| 24 | + dx = discretization.dxs[findfirst(dxs->x.val == dxs[1].val,discretization.dxs)][2] |
| 25 | + dx isa Number ? (xdomain.domain.lower:dx:xdomain.domain.upper) : dx |
| 26 | + end |
| 27 | + tdomain = pdesys.domain[findfirst(d->t.val == d.variables,pdesys.domain)] |
| 28 | + @assert tdomain.domain isa IntervalDomain |
| 29 | + tspan = (tdomain.domain.lower,tdomain.domain.upper) |
| 30 | + |
| 31 | + # Build symbolic variables |
| 32 | + indices = CartesianIndices(((axes(s)[1] for s in space)...,)) |
| 33 | + depvars = map(pdesys.depvars) do u |
| 34 | + [Num(Variable{Symbolics.FnType{Tuple{Any}, Real}}(Base.nameof(ModelingToolkit.operation(u.val)),II.I...))(t) for II in indices] |
| 35 | + end |
| 36 | + spacevals = map(y->[Pair(nottime[i],space[i][y.I[i]]) for i in 1:length(nottime)],indices) |
| 37 | + |
| 38 | + # Build symbolic maps |
| 39 | + edges = reduce(vcat,[[vcat([Colon() for j in 1:i-1],1,[Colon() for j in i+1:length(nottime)]), |
| 40 | + vcat([Colon() for j in 1:i-1],size(depvars[1],i),[Colon() for j in i+1:length(nottime)])] for i in 1:length(nottime)]) |
| 41 | + |
| 42 | + #edgeindices = [indices[e...] for e in edges] |
| 43 | + edgevals = reduce(vcat,[[nottime[i]=>first(space[i]),nottime[i]=>last(space[i])] for i in 1:length(space)]) |
| 44 | + edgevars = [[d[e...] for e in edges] for d in depvars] |
| 45 | + edgemaps = [spacevals[e...] for e in edges] |
| 46 | + initmaps = substitute.(pdesys.depvars,[t=>tspan[1]]) |
| 47 | + |
| 48 | + depvarmaps = reduce(vcat,[substitute.((pdesys.depvars[i],),edgevals) .=> edgevars[i] for i in 1:length(pdesys.depvars)]) |
| 49 | + if length(nottime) == 1 |
| 50 | + left_weights(j) = DiffEqOperators.calculate_weights(discretization.upwind_order, 0.0, [space[j][1],space[j][2]]) |
| 51 | + right_weights(j) = DiffEqOperators.calculate_weights(discretization.upwind_order, 0.0, [space[j][end-1],space[j][end]]) |
| 52 | + central_neighbor_idxs(i,j) = [i+CartesianIndex([ifelse(l==j,-1,0) for l in 1:length(nottime)]...),i,i+CartesianIndex([ifelse(l==j,1,0) for l in 1:length(nottime)]...)] |
| 53 | + derivars = [[dot(left_weights(j),[depvars[j][central_neighbor_idxs(CartesianIndex(2),1)[1:2][2]],depvars[j][central_neighbor_idxs(CartesianIndex(2),1)[1:2][1]]]), |
| 54 | + dot(right_weights(j),[depvars[j][central_neighbor_idxs(CartesianIndex(length(space[1])-1),1)[end-1:end][1]],depvars[j][central_neighbor_idxs(CartesianIndex(length(space[1])-1),1)[end-1:end][2]]])] |
| 55 | + for j in 1:length(pdesys.depvars)] |
| 56 | + depvarderivmaps = reduce(vcat,[substitute.((Differential(nottime[j])(pdesys.depvars[i]),),edgevals) .=> derivars[i] |
| 57 | + for i in 1:length(pdesys.depvars) for j in 1:length(nottime)]) |
| 58 | + else |
| 59 | + # TODO: Fix Neumann and Robin on higher dimension |
| 60 | + depvarderivmaps = [] |
| 61 | + end |
| 62 | + |
| 63 | + # Generate initial conditions and bc equations |
| 64 | + u0 = [] |
| 65 | + bceqs = [] |
| 66 | + for bc in pdesys.bcs |
| 67 | + if ModelingToolkit.operation(bc.lhs) isa Sym && t.val ∉ ModelingToolkit.arguments(bc.lhs) |
| 68 | + # initial condition |
| 69 | + # Assume in the form `u(...) ~ ...` for now |
| 70 | + push!(u0,vec(depvars[findfirst(isequal(bc.lhs),initmaps)] .=> substitute.((bc.rhs,),spacevals))) |
| 71 | + else |
| 72 | + # Algebraic equations for BCs |
| 73 | + i = findfirst(x->occursin(x.val,bc.lhs),first.(depvarmaps)) |
| 74 | + |
| 75 | + # TODO: Fix Neumann and Robin on higher dimension |
| 76 | + lhs = length(nottime) == 1 ? substitute(bc.lhs,depvarderivmaps[i]) : bc.lhs |
| 77 | + |
| 78 | + lhs = substitute(lhs,depvarmaps[i]) |
| 79 | + rhs = substitute.((bc.rhs,),edgemaps[i]) |
| 80 | + lhs = lhs isa Vector ? lhs : [lhs] # handle 1D |
| 81 | + push!(bceqs,lhs .~ rhs) |
| 82 | + end |
| 83 | + end |
| 84 | + u0 = reduce(vcat,u0) |
| 85 | + bceqs = reduce(vcat,bceqs) |
| 86 | + |
| 87 | + # Generate PDE Equations |
| 88 | + interior = indices[[2:length(s)-1 for s in space]...] |
| 89 | + eqs = vec(map(Base.product(interior,pdeeqs)) do p |
| 90 | + i,eq = p |
| 91 | + |
| 92 | + # TODO: Number of points in the central_neighbor_idxs should be dependent |
| 93 | + # on discretization.centered_order |
| 94 | + # TODO: Generalize central difference handling to allow higher even order derivatives |
| 95 | + central_neighbor_idxs(i,j) = [i+CartesianIndex([ifelse(l==j,-1,0) for l in 1:length(nottime)]...),i,i+CartesianIndex([ifelse(l==j,1,0) for l in 1:length(nottime)]...)] |
| 96 | + central_weights(i,j) = DiffEqOperators.calculate_weights(2, 0.0, [space[j][i[j]-1],space[j][i[j]],space[j][i[j]+1]]) |
| 97 | + central_deriv_rules = [(Differential(nottime[j])^2)(pdesys.depvars[k]) => dot(central_weights(i,j),depvars[k][central_neighbor_idxs(i,j)]) for j in 1:length(nottime), k in 1:length(pdesys.depvars)] |
| 98 | + valrules = vcat([pdesys.depvars[k] => depvars[k][i] for k in 1:length(pdesys.depvars)], |
| 99 | + [nottime[k] => space[k][i[k]] for k in 1:length(nottime)]) |
| 100 | + |
| 101 | + # TODO: Use rule matching for nonlinear Laplacian |
| 102 | + |
| 103 | + # TODO: upwind rules needs interpolation into `@rule` |
| 104 | + #forward_weights(i,j) = DiffEqOperators.calculate_weights(discretization.upwind_order, 0.0, [space[j][i[j]],space[j][i[j]+1]]) |
| 105 | + #reverse_weights(i,j) = DiffEqOperators.calculate_weights(discretization.upwind_order, 0.0, [space[j][i[j]-1],space[j][i[j]]]) |
| 106 | + #upwinding_rules = [@rule(*(~~a,(Differential(nottime[j]))(u),~~b) => IfElse.ifelse(*(~~a..., ~~b...,)>0, |
| 107 | + # *(~~a..., ~~b..., dot(reverse_weights(i,j),depvars[k][central_neighbor_idxs(i,j)[1:2]])), |
| 108 | + # *(~~a..., ~~b..., dot(forward_weights(i,j),depvars[k][central_neighbor_idxs(i,j)[2:3]])))) |
| 109 | + # for j in 1:length(nottime), k in 1:length(pdesys.depvars)] |
| 110 | + |
| 111 | + substitute(eq.lhs,vcat(vec(central_deriv_rules),valrules)) ~ substitute(eq.rhs,vcat(vec(central_deriv_rules),valrules)) |
| 112 | + end) |
| 113 | + |
| 114 | + # Finalize |
| 115 | + defaults = pdesys.ps === nothing || pdesys.ps === SciMLBase.NullParameters() ? u0 : vcat(u0,pdesys.ps) |
| 116 | + ps = pdesys.ps === nothing || pdesys.ps === SciMLBase.NullParameters() ? Num[] : first.(pdesys.ps) |
| 117 | + sys = ODESystem(vcat(eqs,unique(bceqs)),t,vec(reduce(vcat,vec(depvars))),ps,defaults=Dict(defaults)) |
| 118 | + sys, tspan |
| 119 | +end |
| 120 | + |
| 121 | +function SciMLBase.discretize(pdesys::ModelingToolkit.PDESystem,discretization::DiffEqOperators.MOLFiniteDifference) |
| 122 | + sys, tspan = SciMLBase.symbolic_discretize(pdesys,discretization) |
| 123 | + simpsys = structural_simplify(sys) |
| 124 | + prob = ODEProblem(simpsys,Pair[],tspan) |
| 125 | +end |
| 126 | + |
| 127 | +# Piracy, to be deleted when https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/251 |
| 128 | +# merges |
| 129 | +Base.occursin(needle::ModelingToolkit.SymbolicUtils.Symbolic, haystack::ModelingToolkit.SymbolicUtils.Symbolic) = _occursin(needle, haystack) |
| 130 | +Base.occursin(needle, haystack::ModelingToolkit.SymbolicUtils.Symbolic) = _occursin(needle, haystack) |
| 131 | +Base.occursin(needle::ModelingToolkit.SymbolicUtils.Symbolic, haystack) = _occursin(needle, haystack) |
| 132 | +function _occursin(needle, haystack) |
| 133 | + isequal(needle, haystack) && return true |
| 134 | + |
| 135 | + if istree(haystack) |
| 136 | + args = arguments(haystack) |
| 137 | + for arg in args |
| 138 | + occursin(needle, arg) && return true |
| 139 | + end |
| 140 | + end |
| 141 | + return false |
| 142 | +end |
0 commit comments