Skip to content

Commit 6cb005f

Browse files
authored
add kan arch (PaddlePaddle#1139)
* add kan arch * fix:address review comments * update arch.md and fix review comments
1 parent 5d1b6c2 commit 6cb005f

File tree

2 files changed

+386
-0
lines changed

2 files changed

+386
-0
lines changed

docs/zh/api/arch.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
- FNO1d
2222
- Generator
2323
- HEDeepONets
24+
- KAN
2425
- LorenzEmbedding
2526
- MLP
2627
- ModelList

ppsci/arch/kan.py

Lines changed: 385 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,385 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
from typing import Callable
17+
from typing import Tuple
18+
19+
import paddle
20+
21+
from ppsci.arch import base
22+
from ppsci.utils import initializer
23+
24+
"""
25+
This is the Paddle implementation of the Kolmogorov-Arnold Network (KAN),
26+
which is based on the torch implementation [efficient-kan] by Blealtan and akkashdash
27+
please refer to their work (https://github.com/Blealtan/efficient-kan)
28+
Authors: guhaohao0991([email protected])
29+
Date: 2025/04/
30+
"""
31+
32+
33+
class KANLinear(paddle.nn.Layer):
    """One KAN layer: a linear "base" branch plus a learnable B-spline branch.

    The output is ``linear(base_activation(x), base_weight)`` plus a linear
    combination of B-spline basis functions of ``x`` weighted by
    ``spline_weight`` (optionally rescaled per edge by ``spline_scaler``).

    Args:
        in_features (int): Size of the last axis of the input.
        out_features (int): Size of the last axis of the output.
        grid_size (int): Number of grid intervals inside ``grid_range``. Default: 5.
        spline_order (int): Order of the B-spline basis. Default: 3.
        scale_noise (float): Amplitude of the random noise used to initialize
            the spline coefficients. Default: 0.1.
        scale_base (float): Factor applied to the ``a`` argument of the Kaiming
            initializer of ``base_weight``. Default: 1.0.
        scale_spline (float): Factor applied to the ``a`` argument of the Kaiming
            initializer of ``spline_scaler`` when the standalone scaler is
            enabled, otherwise a direct multiplier on the initial spline
            coefficients. Default: 1.0.
        enable_standalone_scale_spline (bool): Whether to create a separate
            learnable ``spline_scaler`` of shape (out_features, in_features).
            Default: True.
        base_activation (Callable): Zero-argument factory (e.g. a layer class)
            that is called once to build the activation used by the base branch.
            Default: paddle.nn.Silu.
        grid_eps (float): Blend factor between the uniform grid (weight
            ``grid_eps``) and the data-adaptive grid (weight ``1 - grid_eps``)
            in ``update_grid``. Default: 0.02.
        grid_range (Tuple[float, float]): Interval covered by the initial grid.
            Default: (-1, 1).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        grid_size: int = 5,
        spline_order: int = 3,
        scale_noise: float = 0.1,
        scale_base: float = 1.0,
        scale_spline: float = 1.0,
        enable_standalone_scale_spline: bool = True,
        base_activation: Callable[
            [], Callable[[paddle.Tensor], paddle.Tensor]
        ] = paddle.nn.Silu,
        grid_eps: float = 0.02,
        grid_range: Tuple[float, float] = (-1, 1),
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot vector, extended by `spline_order` knots on each side
        # and replicated for every input feature:
        # shape [in_features, grid_size + 2 * spline_order + 1].
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                paddle.arange(start=-spline_order, end=grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(shape=[in_features, -1])
            .contiguous()
        )
        self.register_buffer(name="grid", tensor=grid)

        # Parameters are created from uninitialized storage; the real values
        # are written by `reset_parameters()` at the end of the constructor.
        self.base_weight = self.create_parameter(
            shape=[out_features, in_features],
            default_initializer=paddle.nn.initializer.Assign(
                paddle.empty(shape=[out_features, in_features])
            ),
        )
        self.spline_weight = self.create_parameter(
            shape=[out_features, in_features, grid_size + spline_order],
            default_initializer=paddle.nn.initializer.Assign(
                paddle.empty(
                    shape=[out_features, in_features, grid_size + spline_order]
                )
            ),
        )

        if enable_standalone_scale_spline:
            self.spline_scaler = self.create_parameter(
                shape=[out_features, in_features],
                default_initializer=paddle.nn.initializer.Assign(
                    paddle.empty(shape=[out_features, in_features])
                ),
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        # `base_activation` is a factory; instantiate it once here.
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialize ``base_weight``, ``spline_weight`` and ``spline_scaler``."""
        self.base_weight = initializer.kaiming_uniform_(
            tensor=self.base_weight,
            a=math.sqrt(5) * self.scale_base,
            nonlinearity="leaky_relu",
        )
        with paddle.no_grad():
            # Small zero-centered noise sampled at the grid_size + 1 interior knots.
            noise = (
                (
                    paddle.rand(
                        shape=[self.grid_size + 1, self.in_features, self.out_features]
                    )
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )

            # Interior knots only. An explicit end index is used instead of
            # `-self.spline_order`, because `[0:-0]` would be an empty slice
            # when spline_order == 0.
            interior_knots = self.grid.T[
                self.spline_order : self.spline_order + self.grid_size + 1
            ]
            paddle.assign(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(interior_knots, noise),
                output=self.spline_weight.data,
            )

            if self.enable_standalone_scale_spline:
                self.spline_scaler = initializer.kaiming_uniform_(
                    tensor=self.spline_scaler,
                    a=math.sqrt(5) * self.scale_spline,
                    nonlinearity="leaky_relu",
                )

    def b_splines(self, x: paddle.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (paddle.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            paddle.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.shape[1] == self.in_features
        grid: paddle.Tensor = self.grid
        x = x.unsqueeze(axis=-1)
        # Order-0 bases: indicator of the half-open knot interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)

        # Raise the order one step at a time; each step combines two adjacent
        # lower-order bases and shortens the last axis by one.
        for k in range(1, self.spline_order + 1):
            left = (x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)])
            right = (grid[:, k + 1 :] - x) / (grid[:, k + 1 :] - grid[:, 1:-k])
            bases = left * bases[:, :, :-1] + right * bases[:, :, 1:]

        assert tuple(bases.shape) == (
            x.shape[0],
            self.in_features,
            self.grid_size + self.spline_order,
        )

        return bases.contiguous()

    def curve2coeff(self, x: paddle.Tensor, y: paddle.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (paddle.Tensor): Input tensor of shape (batch_size, in_features).
            y (paddle.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            paddle.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.shape[1] == self.in_features
        assert tuple(y.shape) == (x.shape[0], self.in_features, self.out_features)

        # Evaluate the bases once and reuse the result for both the
        # permutation and the transpose.
        bases = self.b_splines(x)
        A = bases.transpose(
            perm=dim2perm(bases.ndim, 0, 1)
        )  # [in_features, batch_size, grid_size + spline_order]
        B = y.transpose(
            perm=dim2perm(y.ndim, 0, 1)
        )  # [in_features, batch_size, out_features]
        solution = paddle.linalg.lstsq(x=A, y=B)[
            0
        ]  # [in_features, grid_size + spline_order, out_features]
        # NOTE(review): presumably lstsq drops the leading axis when it has
        # size 1, so it is restored here -- confirm against the Paddle version in use.
        if A.shape[0] == 1:
            solution = solution.unsqueeze(axis=0)
        result = solution.transpose([2, 0, 1])
        assert tuple(result.shape) == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )

        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline coefficients multiplied by ``spline_scaler`` when it is enabled."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(axis=-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: paddle.Tensor):
        """Apply the layer to ``x`` of shape (batch_size, in_features).

        Returns:
            paddle.Tensor: Sum of the base branch and the spline branch.
        """
        assert x.dim() == 2 and x.shape[1] == self.in_features

        base_output = paddle.nn.functional.linear(
            x=self.base_activation(x), weight=self.base_weight.T
        )

        # `reshape` is used on purpose: per the original author's note,
        # first-order derivatives cannot be computed when `view` is used here.
        spline_output = paddle.nn.functional.linear(
            x=self.b_splines(x).reshape([x.shape[0], -1]).contiguous(),
            weight=self.scaled_spline_weight.reshape(
                [self.out_features, -1]
            ).T.contiguous(),
        )

        return base_output + spline_output

    @paddle.no_grad()
    def update_grid(self, x: paddle.Tensor, margin=0.01):
        """Re-fit the knot grid to ``x`` and re-solve the spline coefficients.

        The per-edge spline outputs on ``x`` are evaluated with the current
        grid, the grid is replaced by a blend of a uniform and a
        quantile-based grid built from ``x``, and ``spline_weight`` is then
        re-fitted by least squares to those outputs on the new grid.

        Args:
            x (paddle.Tensor): Samples of shape (batch_size, in_features).
            margin (float): Padding added on both ends of the uniform grid.
                Default: 0.01.
        """
        assert x.dim() == 2 and x.shape[1] == self.in_features
        batch = x.shape[0]

        splines = self.b_splines(x)  # [batch, in, coeff]
        splines = splines.transpose(perm=[1, 0, 2])  # [in, batch, coeff]
        orig_coeff = self.scaled_spline_weight  # [out, in, coeff]
        orig_coeff = orig_coeff.transpose(perm=[1, 2, 0])  # [in, coeff, out]
        unreduced_spline_output = paddle.bmm(
            x=splines, y=orig_coeff
        )  # [in, batch, out]
        unreduced_spline_output = unreduced_spline_output.transpose(
            perm=[1, 0, 2]
        )  # [batch, in, out]

        # sort each channel individually to collect data distribution
        x_sorted = paddle.sort(x=x, axis=0)
        # grid_size + 1 evenly spaced rows of the sorted samples.
        grid_adaptive = x_sorted[
            paddle.linspace(
                start=0, stop=batch - 1, num=self.grid_size + 1, dtype="int64"
            )
        ]
        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            paddle.arange(dtype="float32", end=self.grid_size + 1).unsqueeze(axis=1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Extend by `spline_order` uniformly spaced knots on both ends.
        grid = paddle.concat(
            x=[
                grid[:1]
                - uniform_step
                * paddle.arange(
                    start=self.spline_order, end=0, step=-1, dtype="float32"
                ).unsqueeze(axis=1),
                grid,
                grid[-1:]
                + uniform_step
                * paddle.arange(
                    start=1, end=self.spline_order + 1, dtype="float32"
                ).unsqueeze(axis=1),
            ],
            axis=0,
        )

        paddle.assign(grid.T, output=self.grid)
        # The grid must be assigned first: curve2coeff evaluates the bases on
        # the *new* grid.
        paddle.assign(
            self.curve2coeff(x, unreduced_spline_output), output=self.spline_weight.data
        )

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        L1 and the entropy loss is for the feature selection, i.e., let the weight of the activation function be small.

        Args:
            regularize_activation (float): Weight of the L1-like term. Default: 1.0.
            regularize_entropy (float): Weight of the entropy term. Default: 1.0.

        Returns:
            paddle.Tensor: Weighted sum of the two terms.
        """
        # Mean absolute spline coefficient per (out, in) edge.
        l1_fake = self.spline_weight.abs().mean(axis=-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        # NOTE(review): an edge whose coefficients are all exactly zero gives
        # p == 0 and log(0) here -- confirm whether that needs guarding.
        regularization_loss_entropy = -paddle.sum(x=p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )
293+
294+
295+
class KAN(base.Arch):
    """Kolmogorov-Arnold Network (KAN).

    A stack of ``KANLinear`` layers with ``tanh`` applied between consecutive
    layers (not after the last one).

    Args:
        layers_hidden (Tuple[int, ...]): Width of every layer boundary, i.e.
            (input size, hidden sizes..., output size). Each consecutive pair
            defines one KAN-linear layer, so at least two entries are required.
        input_keys (Tuple[str, ...]): The keys of the input dictionary.
        output_keys (Tuple[str, ...]): The keys of the output dictionary.
        grid_size (int): The size of the grid used by the spline basis functions. Default: 5.
        spline_order (int): The order of the spline basis functions. Default: 3.
        scale_noise (float): The scaling factor for the noise added to the weights of the KAN-linear layers. Default: 0.1.
        scale_base (float): The scaling factor for the base activation output. Default: 1.0.
        scale_spline (float): The scaling factor for the b-spline output. Default: 1.0.
        base_activation (Callable): Zero-argument factory (e.g. a layer class) that
            builds the base activation of every layer. Default: paddle.nn.Silu.
        grid_eps (float): The epsilon value used to initialize the grid. Default: 0.02.
        grid_range (Tuple[float, float]): The domain range of the grid for b-spline interpolation. Default: (-1, 1).

    Raises:
        ValueError: If ``layers_hidden`` has fewer than two entries.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.KAN(
        ...     layers_hidden=(2, 5, 5, 1),
        ...     input_keys=("x", "y"),
        ...     output_keys=("z",),
        ...     grid_size=5,
        ...     spline_order=3,
        ... )
        >>> input_dict = {"x": paddle.rand([64, 1]),
        ...               "y": paddle.rand([64, 1])}
        >>> output_dict = model(input_dict)
        >>> print(output_dict["z"].shape)
        [64, 1]
    """

    def __init__(
        self,
        layers_hidden: Tuple[int, ...],
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        grid_size: int = 5,
        spline_order: int = 3,
        scale_noise: float = 0.1,
        scale_base: float = 1.0,
        scale_spline: float = 1.0,
        base_activation: Callable[
            [], Callable[[paddle.Tensor], paddle.Tensor]
        ] = paddle.nn.Silu,
        grid_eps: float = 0.02,
        grid_range: Tuple[float, float] = (-1, 1),
    ):
        super().__init__()
        # Without at least one (in, out) pair the layer list would be empty
        # and forward() would silently return its input unchanged.
        if len(layers_hidden) < 2:
            raise ValueError(
                "layers_hidden must contain at least 2 entries "
                f"(input and output size), but got {len(layers_hidden)}."
            )
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.grid_size = grid_size
        self.spline_order = spline_order
        self.layers = paddle.nn.LayerList()
        # One KANLinear per consecutive (in_features, out_features) pair.
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x_dict, update_grid=False):
        """Run the network on a dictionary of input tensors.

        Args:
            x_dict: Dictionary holding one tensor per entry of ``input_keys``.
            update_grid (bool): If True, each layer re-fits its spline grid to
                its current input before being applied. Default: False.

        Returns:
            Dictionary produced by ``split_to_dict`` for ``output_keys``.
        """
        # NOTE(review): other ppsci architectures presumably apply registered
        # input/output transforms in forward(); this one does not -- confirm
        # whether that is intended.
        x = self.concat_to_tensor(x_dict, self.input_keys, axis=-1)
        last_index = len(self.layers) - 1
        for index, layer in enumerate(self.layers):
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
            # tanh between layers only; the final layer output is left as is.
            if index < last_index:
                x = paddle.nn.functional.tanh(x=x)
        out_dic = self.split_to_dict(x, self.output_keys, axis=-1)
        return out_dic

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Return the sum of ``regularization_loss`` over all KAN-linear layers."""
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )
380+
381+
382+
def dim2perm(ndim, dim0, dim1):
    """Return the identity permutation of ``ndim`` axes with two axes swapped.

    Args:
        ndim: Number of axes.
        dim0: Index of the first axis to swap.
        dim1: Index of the second axis to swap.

    Returns:
        list: ``[0, 1, ..., ndim - 1]`` with the entries at positions ``dim0``
        and ``dim1`` exchanged, suitable as a ``perm`` argument for transpose.
    """
    order = [*range(ndim)]
    first_axis = order[dim0]
    second_axis = order[dim1]
    order[dim0] = second_axis
    order[dim1] = first_axis
    return order

0 commit comments

Comments
 (0)