1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package net.sf.michaelo.activedirectory;
17
18 import java.util.Arrays;
19 import java.util.Hashtable;
20 import java.util.NoSuchElementException;
21 import java.util.Objects;
22 import java.util.Scanner;
23 import java.util.concurrent.ThreadLocalRandom;
24 import java.util.logging.Level;
25 import java.util.logging.Logger;
26
27 import javax.naming.Context;
28 import javax.naming.InvalidNameException;
29 import javax.naming.NameNotFoundException;
30 import javax.naming.NamingEnumeration;
31 import javax.naming.NamingException;
32 import javax.naming.directory.Attribute;
33 import javax.naming.directory.Attributes;
34 import javax.naming.directory.DirContext;
35 import javax.naming.directory.InitialDirContext;
36
37 import org.apache.commons.lang3.Validate;
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71 public class ActiveDirectoryDnsLocator {
72
73 private static class SrvRecord implements Comparable<SrvRecord> {
74
75 static final String UNAVAILABLE_SERVICE = ".";
76
77 private int priority;
78 private int weight;
79 private int sum;
80 private int port;
81 private String target;
82
83 public SrvRecord(int priority, int weight, int port, String target) {
84 Validate.inclusiveBetween(0, 0xFFFF, priority, "priority must be between 0 and 65535");
85 Validate.inclusiveBetween(0, 0xFFFF, weight, "weight must be between 0 and 65535");
86 Validate.inclusiveBetween(0, 0xFFFF, port, "port must be between 0 and 65535");
87 Validate.notEmpty(target, "target cannot be null or empty");
88
89 this.priority = priority;
90 this.weight = weight;
91 this.port = port;
92 this.target = target;
93 }
94
95 @Override
96 public boolean equals(Object obj) {
97 if (obj == null || !(obj instanceof SrvRecord))
98 return false;
99
100 SrvRecord that = (SrvRecord) obj;
101
102 return priority == that.priority && weight == that.weight && port == that.port
103 && target.equals(that.target);
104 }
105
106 @Override
107 public int hashCode() {
108 return Objects.hash(priority, weight, port, target);
109 }
110
111 @Override
112 public String toString() {
113 StringBuilder builder = new StringBuilder("SRV RR: ");
114 builder.append(priority).append(' ');
115 builder.append(weight).append(' ');
116 if (sum != 0)
117 builder.append('(').append(sum).append(") ");
118 builder.append(port).append(' ');
119 builder.append(target);
120 return builder.toString();
121 }
122
123 @Override
124 public int compareTo(SrvRecord that) {
125
126 if (priority > that.priority) {
127 return 1;
128 } else if (priority < that.priority) {
129 return -1;
130 } else if (weight == 0 && that.weight != 0) {
131 return -1;
132 } else if (weight != 0 && that.weight == 0) {
133 return 1;
134 } else {
135 return 0;
136 }
137 }
138
139 }
140
141
142
143
144 public static class HostPort {
145
146 private String host;
147 private int port;
148
149 public HostPort(String host, int port) {
150 Validate.notEmpty(host, "host cannot be null or empty");
151 Validate.inclusiveBetween(0, 0xFFFF, port, "port must be between 0 and 65535");
152
153 this.host = host;
154 this.port = port;
155 }
156
157 public String getHost() {
158 return host;
159 }
160
161 public int getPort() {
162 return port;
163 }
164
165 @Override
166 public boolean equals(Object obj) {
167 if (obj == null || !(obj instanceof HostPort))
168 return false;
169
170 HostPort that = (HostPort) obj;
171
172 return host.equals(that.host) && port == that.port;
173 }
174
175 @Override
176 public int hashCode() {
177 return Objects.hash(host, port);
178 }
179
180 @Override
181 public String toString() {
182 StringBuilder builder = new StringBuilder();
183 builder.append(host).append(':').append(port);
184 return builder.toString();
185 }
186
187 }
188
/** Lookup name pattern without a site: service, then domain name. */
private static final String SRV_RR_FORMAT = "_%s._tcp.%s";
/** Lookup name pattern with a site: service, then site, then domain name. */
private static final String SRV_RR_WITH_SITES_FORMAT = "_%s._tcp.%s._sites.%s";

/** Identifier of the attribute requested from, and read out of, the directory context. */
private static final String SRV_RR = "SRV";
private static final String[] SRV_RR_ATTR = { SRV_RR };

private static final Logger LOGGER =
        Logger.getLogger(ActiveDirectoryDnsLocator.class.getName());

/** Environment handed to every {@code InitialDirContext} this locator creates. */
private final Hashtable<String, Object> env;

/**
 * Builds the context environment from the builder's settings.
 *
 * @param builder source of the context factory and the additional properties
 */
private ActiveDirectoryDnsLocator(Builder builder) {
    Hashtable<String, Object> settings = new Hashtable<>();
    // The factory goes in first so that an additional property with the same key wins.
    settings.put(Context.INITIAL_CONTEXT_FACTORY, builder.contextFactory);
    settings.putAll(builder.additionalProperties);
    this.env = settings;
}
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221 public static final class Builder {
222
223
224 private String contextFactory;
225 private Hashtable<String, Object> additionalProperties;
226
227 private boolean done;
228
229
230
231
232 public Builder() {
233
234 contextFactory("com.sun.jndi.dns.DnsContextFactory");
235 additionalProperties = new Hashtable<String, Object>();
236 }
237
238
239
240
241
242
243
244
245
246
247
248
249 public Builder contextFactory(String contextFactory) {
250 check();
251 this.contextFactory = validateAndReturnString("contextFactory", contextFactory);
252 return this;
253 }
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268 public Builder additionalProperty(String name, Object value) {
269 check();
270 Validate.notEmpty(name, "additional property's name cannot be null or empty");
271 this.additionalProperties.put(name, value);
272 return this;
273 }
274
275
276
277
278
279
280
281
282
283
284 public ActiveDirectoryDnsLocator build() {
285
286 ActiveDirectoryDnsLocator serviceLocator = new ActiveDirectoryDnsLocator(this);
287 done = true;
288
289 return serviceLocator;
290 }
291
292 private void check() {
293 if (done)
294 throw new IllegalStateException("cannot modify an already used builder");
295 }
296
297 private String validateAndReturnString(String name, String value) {
298 return Validate.notEmpty(value, "property '%s' cannot be null or empty", name);
299 }
300
301 }
302
303 private SrvRecord[] lookUpSrvRecords(DirContext context, String name) throws NamingException {
304 Attributes attrs = null;
305
306 try {
307 attrs = context.getAttributes(name, SRV_RR_ATTR);
308 } catch (InvalidNameException e) {
309 NamingException ne = new NamingException("name '" + name + "' is invalid");
310 ne.initCause(ne);
311 throw ne;
312 } catch (NameNotFoundException e) {
313 return null;
314 }
315
316 Attribute srvAttr = attrs.get(SRV_RR);
317 if (srvAttr == null)
318 return null;
319
320 NamingEnumeration<?> records = null;
321
322 SrvRecord[] srvRecords = new SrvRecord[srvAttr.size()];
323
324 try {
325 records = srvAttr.getAll();
326
327 int recordCnt = 0;
328 while (records.hasMoreElements()) {
329 String record = (String) records.nextElement();
330 try (Scanner scanner = new Scanner(record)) {
331 scanner.useDelimiter(" ");
332
333 int priority = scanner.nextInt();
334 int weight = scanner.nextInt();
335 int port = scanner.nextInt();
336 String target = scanner.next();
337 SrvRecord srvRecord = new SrvRecord(priority, weight, port, target);
338
339 srvRecords[recordCnt++] = srvRecord;
340 }
341 }
342 } catch (NoSuchElementException e) {
343 throw new IllegalStateException("The supplied SRV RR is invalid", e);
344 } finally {
345 if (records != null)
346 try {
347 records.close();
348 } catch (NamingException e) {
349 ;
350 }
351 }
352
353
354
355
356
357 if (srvRecords.length == 0 || srvRecords.length == 1
358 && srvRecords[0].target.equals(SrvRecord.UNAVAILABLE_SERVICE))
359 return null;
360
361 return srvRecords;
362 }
363
364 private HostPort[] sortByRfc2782(SrvRecord[] srvRecords) {
365 if (srvRecords == null)
366 return null;
367
368
369 Arrays.sort(srvRecords);
370
371 HostPort[] sortedServers = new HostPort[srvRecords.length];
372 for (int i = 0, start = -1, end = -1, hp = 0; i < srvRecords.length; i++) {
373
374 start = i;
375 while (i + 1 < srvRecords.length
376 && srvRecords[i].priority == srvRecords[i + 1].priority) {
377 i++;
378 }
379 end = i;
380
381 for (int repeat = 0; repeat < (end - start) + 1; repeat++) {
382 int sum = 0;
383 for (int j = start; j <= end; j++) {
384 if (srvRecords[j] != null) {
385 sum += srvRecords[j].weight;
386 srvRecords[j].sum = sum;
387 }
388 }
389
390 int r = sum == 0 ? 0 : ThreadLocalRandom.current().nextInt(sum + 1);
391 for (int k = start; k <= end; k++) {
392 SrvRecord srvRecord = srvRecords[k];
393
394 if (srvRecord != null && srvRecord.sum >= r) {
395 String host = srvRecord.target.substring(0, srvRecord.target.length() - 1);
396 sortedServers[hp++] = new HostPort(host, srvRecord.port);
397 srvRecords[k] = null;
398 }
399 }
400 }
401 }
402
403 return sortedServers;
404 }
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426 public HostPort[] locate(String service, String site, String domainName)
427 throws NamingException {
428 Validate.notEmpty(service, "service cannot be null or empty");
429 Validate.notEmpty(domainName, "domainName cannot be null or empty");
430
431 DirContext context = null;
432 try {
433 context = new InitialDirContext(env);
434 } catch (NamingException e) {
435 NamingException ne = new NamingException("failed to create DNS directory context");
436 ne.initCause(e);
437 throw ne;
438 }
439
440 SrvRecord[] srvRecords = null;
441
442 String lookupName;
443 if (site != null && !site.isEmpty())
444 lookupName = String.format(SRV_RR_WITH_SITES_FORMAT, service, site, domainName);
445 else
446 lookupName = String.format(SRV_RR_FORMAT, service, domainName);
447
448 try {
449 LOGGER.log(Level.FINE, "Looking up SRV RRs for ''{0}''", lookupName);
450
451 srvRecords = lookUpSrvRecords(context, lookupName);
452 } catch (NamingException e) {
453 LOGGER.log(Level.SEVERE,
454 String.format("Failed to look up SRV RRs for '%s'", lookupName), e);
455
456 throw e;
457 } finally {
458 try {
459 context.close();
460 } catch (NamingException e) {
461 ;
462 }
463 }
464
465 if (srvRecords == null)
466 LOGGER.log(Level.FINE, "No SRV RRs for ''{0}'' found", lookupName);
467 else {
468 if (LOGGER.isLoggable(Level.FINER))
469 LOGGER.log(Level.FINER, "Found {0} SRV RRs for ''{1}'': {2}", new Object[] {
470 srvRecords.length, lookupName, Arrays.toString(srvRecords) });
471 else
472 LOGGER.log(Level.FINE, "Found {0} SRV RRs for ''{1}''",
473 new Object[] { srvRecords.length, lookupName });
474 }
475
476 HostPort[] servers = sortByRfc2782(srvRecords);
477
478 if (servers == null)
479 return null;
480
481 if (LOGGER.isLoggable(Level.FINER))
482 LOGGER.log(Level.FINER, "Selected {0} servers for ''{1}'': {2}",
483 new Object[] { servers.length, lookupName, Arrays.toString(servers) });
484 else
485 LOGGER.log(Level.FINE, "Selected {0} servers for ''{1}''",
486 new Object[] { servers.length, lookupName });
487
488 return servers;
489 }
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509 public HostPort[] locate(String service, String domainName) throws NamingException {
510 return locate(service, null, domainName);
511 }
512
513 }